@@ -89,13 +89,31 @@ def _preprocessing(self):
8989 g_idx_trivial = torch .tensor (
9090 g_idx_trivial , dtype = torch .int32 , device = self .g_idx .device
9191 )
92- assert torch .equal (
93- self .g_idx , g_idx_trivial
94- ), "Non-trivial tensor g_idx is not supported"
92+ sort_zeros = not (torch .equal (self .g_idx , g_idx_trivial ))
9593 self .qzeros = self .qzeros .cpu ()
9694 zeros = self .unpack_zeros_from_cuda_old_format ()
97- new_qzeros = pack_tensor (zeros )
98- self .qzeros = new_qzeros .to (orig_device )
95+ if sort_zeros :
96+ zeros_group_1 = torch .zeros (
97+ (self .infeatures , self .outfeatures ),
98+ dtype = zeros .dtype ,
99+ device = zeros .device ,
100+ )
101+ scales = self .scales .cpu ()
102+ scale_group_1 = torch .zeros (
103+ (self .infeatures , self .outfeatures ),
104+ dtype = scales .dtype ,
105+ device = scales .device ,
106+ )
107+ for i in range (self .infeatures ):
108+ zeros_group_1 [i ] = zeros [self .g_idx [i ]]
109+ scale_group_1 [i ] = self .scales [self .g_idx [i ]]
110+ self .qzeros = pack_tensor (zeros_group_1 ).to (orig_device )
111+ self .scales = scale_group_1 .to (orig_device )
112+ self .groupsize = 1
113+ self .g_idx = None
114+ else :
115+ new_qzeros = pack_tensor (zeros )
116+ self .qzeros = new_qzeros .to (orig_device )
99117
100118 @classmethod
101119 def new (cls , bits , groupsize , infeatures , outfeatures , bias ):
0 commit comments