@@ -16,14 +16,14 @@ __global__ void rms_norm_kernel(T* __restrict__ out,
                                 const T* __restrict__ input,
                                 const T* __restrict__ weight,
                                 const float epsilon,
-                                int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+                                int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;
 
   __shared__ float s_variance;
   float variance = 0.0f;
 
-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     const float x = input[bidx * n + i];
     variance += x * x;
   }
@@ -33,8 +33,8 @@ __global__ void rms_norm_kernel(T* __restrict__ out,
   }
   __syncthreads();
 
-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float x = input[idx];
     out[idx] = (T)(x * s_variance) * weight[i];
   }
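Note on the int to int64_t switch: the flattened offset bidx * n + i is evaluated in the operands' type, so with 32-bit int it wraps once a tensor holds more than 2^31 elements; with n declared as int64_t the multiply is promoted to 64 bits. A quick standalone check (the shape below is hypothetical, chosen only to cross the 32-bit limit, and is not taken from the patch):

    #include <cstdint>

    int main() {
      const int64_t rows = 131072, cols = 32768;  // hypothetical shape with 2^32 elements in total
      // With 32-bit indexing, the offset of the last row wraps past INT_MAX (signed overflow, UB):
      // int bad = static_cast<int>(rows - 1) * static_cast<int>(cols);
      const int64_t good = (rows - 1) * cols;     // 4,294,934,528 -- representable only in 64 bits
      return good > 0 ? 0 : 1;
    }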
@@ -47,10 +47,10 @@ void rms_norm(torch::Tensor& out,
   DCHECK(input.is_contiguous()) << "input tensor must be contiguous";
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";
 
-  const int n = input.size(1);
+  const int64_t n = input.size(1);
 
   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     rms_norm_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
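The explicit template argument on std::min is what keeps the launcher compiling after n becomes int64_t: std::min(n, 1024) would have to deduce a single type from an int64_t and an int and fails. A minimal illustration (not part of the patch; the value 4096 is made up):

    #include <algorithm>
    #include <cstdint>

    int main() {
      const int64_t n = 4096;                       // hypothetical hidden size
      // std::min(n, 1024);                         // ill-formed: T deduced as both int64_t and int
      const int threads = std::min<int>(n, 1024);   // explicit <int> converts both arguments
      return threads == 1024 ? 0 : 1;
    }

The narrowing of n to int is harmless here because the hidden dimension comfortably fits in an int and the result is capped at 1024 threads per block anyway.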
@@ -73,15 +73,15 @@ __global__ void rms_norm_residual_kernel(T* __restrict__ out,
                                          const T* __restrict__ input,
                                          const T* __restrict__ weight,
                                          const float epsilon,
-                                         int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+                                         int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;
 
   __shared__ float s_variance;
   float variance = 0.0f;
 
-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float r = residual[idx];
     const float x = r + input[idx];
     residual[idx] = x;
@@ -93,8 +93,8 @@ __global__ void rms_norm_residual_kernel(T* __restrict__ out,
   }
   __syncthreads();
 
-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float x = residual[idx];
     out[idx] = (T)(x * s_variance) * weight[i];
   }
@@ -109,10 +109,10 @@ void rms_norm_residual(torch::Tensor& out,
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";
   DCHECK(residual.is_contiguous()) << "residual tensor must be contiguous";
 
-  const int n = input.size(1);
+  const int64_t n = input.size(1);
 
   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_residual_kernel", [&] {
     rms_norm_residual_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
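For reference, what the residual variant computes per row: fold the input into the residual in place, then RMS-normalize the sum. A rough single-threaded CPU sketch of the same math, assuming s_variance in the kernel is the usual rsqrtf(sum_sq / n + epsilon) computed in the lines the hunks above elide (names and the float-only signature are illustrative, not the kernel's API):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    void rms_norm_residual_row(std::vector<float>& out,
                               std::vector<float>& residual,
                               const std::vector<float>& input,
                               const std::vector<float>& weight,
                               float epsilon) {
      const int64_t n = static_cast<int64_t>(input.size());
      float sum_sq = 0.0f;
      for (int64_t i = 0; i < n; ++i) {
        residual[i] += input[i];              // write the updated residual back, as the kernel does
        sum_sq += residual[i] * residual[i];
      }
      const float inv_rms = 1.0f / std::sqrt(sum_sq / n + epsilon);
      for (int64_t i = 0; i < n; ++i) {
        out[i] = residual[i] * inv_rms * weight[i];
      }
    }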
@@ -133,17 +133,17 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
                                   const T* __restrict__ weight,
                                   const T* __restrict__ bias,
                                   const float epsilon,
-                                  int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+                                  int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;
 
   __shared__ float s_mean;
   __shared__ float s_variance;
   float mean = 0.0f;
   float variance = 0.0f;
 
   // calculate mean of the input.
-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     mean += input[bidx * n + i];
   }
   mean = block_reduce_sum<float>(mean);
@@ -153,7 +153,7 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   __syncthreads();
 
   // calculate variance of the input.
-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     const float x = input[bidx * n + i] - s_mean;
     variance += x * x;
   }
@@ -163,8 +163,8 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
   __syncthreads();
 
-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     float local_out = (input[idx] - s_mean) * s_variance * weight[i];
     if (bias != nullptr) {
       local_out += bias[i];
@@ -181,10 +181,10 @@ void layer_norm(torch::Tensor& out,
   DCHECK(input.is_contiguous()) << "input tensor must be contiguous";
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";
 
-  const int n = input.size(1);
+  const int64_t n = input.size(1);
 
   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "layer_norm_kernel", [&] {
     layer_norm_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
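And the per-row math of layer_norm_kernel, again as a rough CPU reference, assuming s_variance holds rsqrtf(variance / n + epsilon) computed in the elided lines (names and the float-only signature are illustrative):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    void layer_norm_row(std::vector<float>& out,
                        const std::vector<float>& input,
                        const std::vector<float>& weight,
                        const std::vector<float>* bias,   // optional, mirroring the kernel's nullptr check
                        float epsilon) {
      const int64_t n = static_cast<int64_t>(input.size());
      float mean = 0.0f;
      for (int64_t i = 0; i < n; ++i) mean += input[i];
      mean /= n;
      float variance = 0.0f;
      for (int64_t i = 0; i < n; ++i) {
        const float x = input[i] - mean;
        variance += x * x;
      }
      const float inv_std = 1.0f / std::sqrt(variance / n + epsilon);
      for (int64_t i = 0; i < n; ++i) {
        float y = (input[i] - mean) * inv_std * weight[i];
        if (bias != nullptr) y += (*bias)[i];
        out[i] = y;
      }
    }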