11"""
2- Description :
2+ Description :
33Author : Boxin Zhang
44Version : 0.1.0
5- Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
5+ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
66"""
77
88from torch import nn
2727from ktransformers .util .utils import InferenceState
2828from transformers .configuration_utils import PretrainedConfig
2929import torch
30+ from ktransformers .util .torch_auto_backend import CUDA
3031
3132# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
3233class RotaryEmbedding (BaseInjectedModule , DeepseekV2RotaryEmbedding ):
@@ -37,8 +38,8 @@ def __init__(
         config: PretrainedConfig,
         orig_module: nn.Module,
         # device: str = "cuda",
-        generate_device: str = "cuda",
-        prefill_device: str = "cuda",
+        generate_device: str = CUDA,
+        prefill_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
@@ -67,16 +68,16 @@ def __init__(
         config: PretrainedConfig,
         orig_module: nn.Module,
         # device: str = "cuda",
-        generate_device: str = "cuda",
-        prefill_device: str = "cuda",
+        generate_device: str = CUDA,
+        prefill_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
-    
+
     @torch.no_grad()
     def forward(self, x, position_ids):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -91,7 +92,7 @@ def forward(self, x, position_ids):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos()
         sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
     def load(self):
         self._init(
@@ -117,8 +118,8 @@ def __init__(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        generate_device: str = "cuda",
-        prefill_device: str = "cuda",
+        generate_device: str = CUDA,
+        prefill_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
@@ -155,8 +156,8 @@ def __init__(
         config: PretrainedConfig,
         orig_module: nn.Module,
         # device: str = "cuda",
-        generate_device: str = "cuda",
-        prefill_device: str = "cuda",
+        generate_device: str = CUDA,
+        prefill_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
@@ -225,16 +226,16 @@ def __init__(
         config: PretrainedConfig,
         orig_module: nn.Module,
         # device: str = "cuda",
-        generate_device: str = "cuda",
-        prefill_device: str = "cuda",
+        generate_device: str = CUDA,
+        prefill_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
-    
+
     def load(self):
         kwargs = {
             key: self.config.rope_scaling[key]
@@ -270,7 +271,7 @@ def forward(self, x, position_ids):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos()*self._mscale
         sin = emb.sin()*self._mscale
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
     def _init(
         self,
@@ -332,8 +333,8 @@ def __init__(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        prefill_device: str = "cuda",
-        generate_device: str = "cuda",
+        prefill_device: str = CUDA,
+        generate_device: str = CUDA,
         **kwargs,
     ):
         BaseInjectedModule.__init__(
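
Note: every hunk above makes the same substitution, replacing the hardcoded "cuda" device literal with a CUDA constant imported from ktransformers.util.torch_auto_backend, so the default generate/prefill devices can track whichever accelerator backend the host actually exposes. Below is a minimal sketch of what such a torch_auto_backend module could look like; the detection order and the probed backends are assumptions for illustration, not the module's actual contents.

import torch

def _detect_backend() -> str:
    # Prefer NVIDIA CUDA when present; otherwise probe another torch
    # accelerator backend and fall back to CPU. The probed backends
    # here are illustrative, not the real module's list.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

# Operator modules import this constant instead of the "cuda" literal,
# e.g. `generate_device: str = CUDA`, so a single detection at import
# time sets the default device string for every injected module.
CUDA = _detect_backend()

Centralizing the device string this way means adding a new backend touches one file instead of every default argument in the operator classes.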