@@ -608,9 +608,12 @@ def FProp(self, theta, inputs, paddings=None):
   def _Compute(self, theta, inputs, paddings, conv_padding):
     """Computes the forward prop (conv, bn, act)."""
     p = self.params
+    del paddings
 
     bn_padding = conv_padding
     if bn_padding is None:
+      batch_time = None
+      batch_time_any_any = None
       bn_padding_expanded = None
     else:
       batch_time = tf.shape(bn_padding)
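The two kinds of additions in this hunk recur throughout the change: `del paddings` explicitly discards an argument the method accepts for interface compatibility but never reads (pylint's `unused-argument`), and assigning `batch_time`/`batch_time_any_any` to `None` on the `if` branch ensures those names are bound on every path before later code references them. A minimal sketch of the idiom with hypothetical names, not lingvo code:

```python
def _compute(values, paddings, factor):
  """Hypothetical helper; `paddings` is kept only for a uniform signature."""
  del paddings  # Explicitly unused; silences pylint's unused-argument.

  if factor is None:
    scaled = None  # Bind the name on every branch so later reads are defined.
  else:
    scaled = [v * factor for v in values]
  return values if scaled is None else scaled
```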
@@ -644,6 +647,7 @@ def _ComputeConvLast(self, theta, inputs, paddings, conv_padding):
     p = self.params
     out = inputs
     out_padding = paddings
+    del conv_padding
 
     if p.batch_norm:
       if out_padding is None:
@@ -1014,6 +1018,7 @@ def _CreateLayerVariables(self):
           init=p.params_init,
           dtype=p.dtype,
           collections=[self.__class__.__name__ + '_vars'])
+      mix_kernel_pc = None
     elif p.use_block_diagonal_matmul:
       w_pc = py_utils.WeightParams(
           shape=(p.bd_num_blocks, p.input_dim // p.bd_num_blocks,
@@ -1039,6 +1044,7 @@ def _CreateLayerVariables(self):
           device_mesh=p.device_mesh,
           tensor_split_dims_mapping=p.weight_split_dims_mapping,
           collections=[self.__class__.__name__ + '_vars'])
+      mix_kernel_pc = None
 
     if p.apply_pruning:
       mask_w_pc = py_utils.WeightParams(w_pc.shape,
@@ -1047,6 +1053,9 @@ def _CreateLayerVariables(self):
       threshold_w_pc = py_utils.WeightParams([],
                                              py_utils.WeightInit.Constant(0.0),
                                              tf.float32)
+    else:
+      mask_w_pc = None
+      threshold_w_pc = None
     if p.has_bias:
       if p.device_mesh is not None:
         bias_split_dims_mapping = [p.weight_split_dims_mapping[1]]
@@ -1059,12 +1068,16 @@ def _CreateLayerVariables(self):
           device_mesh=p.device_mesh,
           tensor_split_dims_mapping=bias_split_dims_mapping,
           collections=[self.__class__.__name__ + '_vars'])
+    else:
+      b_pc = None
     if p.weight_norm:
       g_pc = py_utils.WeightParams(
           shape=[self._internal_output_dim],
           init=py_utils.WeightInit.Constant(0.0),
           dtype=p.dtype,
           collections=[self.__class__.__name__ + '_vars'])
+    else:
+      g_pc = None
 
     weights_var_name = 'w'
     if p.apply_pruning:
@@ -1450,6 +1463,8 @@ def _CreateLayerVariables(self):
           dtype=p.dtype,
           collections=[self.__class__.__name__ + '_vars'],
       )
+    else:
+      b_pc = None
     w_name = 'w'
     self.CreateVariable(w_name, w_pc)
     self.TrackQWeight(w_name, shape=w_pc.shape, feature_axis=-1)
@@ -3442,9 +3457,8 @@ class RotaryPositionalEmbeddingLayer(PositionalEmbeddingLayer):
   The Rotary position embedding is described in https://arxiv.org/abs/2104.09864
   """
 
-  # pylint: disable=arguments-renamed
+  # pylint: disable-next=arguments-renamed
   def FProp(self, theta, inputs, position=None):
-    # pylint: enable=arguments-renamed
     """Generates a JTensor of sinusoids with different frequencies.
 
     Args:
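`# pylint: disable-next=arguments-renamed` (supported since pylint 2.10) scopes the suppression to the single following line, so the old `disable`/`enable` pair around the `def` collapses to one comment, with no risk of the check staying off if the `enable` line is ever lost. A hypothetical example of the warning it targets:

```python
class Base:
  def FProp(self, theta, inputs, position=None):
    return inputs

class Rotary(Base):
  # pylint: disable-next=arguments-renamed
  def FProp(self, theta, inputs, timestep=None):  # Only this def is exempt.
    return inputs
```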
@@ -3759,6 +3773,9 @@ def _CreateLayerVariables(self):
       threshold_pc = py_utils.WeightParams([],
                                            py_utils.WeightInit.Constant(0.0),
                                            tf.float32)
+    else:
+      mask_pc = None
+      threshold_pc = None
 
     for i in range(p.num_shards):
       weights_var_name = f'weight_{i}'
@@ -4315,6 +4332,8 @@ def XentLossFromLogits(self,
                          class_ids=None,
                          class_probabilities=None):
     """Computes cross-entropy, argmax etc. from logits."""
+    del theta
+    del class_weights
     p = self.params
     assert logits is not None
     per_example_argmax = py_utils.ArgMax(logits)
@@ -4819,6 +4838,7 @@ def Params(cls):
     return p
 
   def _Dropout(self, theta, inputs, noise_shape):
+    del theta
     return tf.nn.dropout(
         inputs,
         rate=1 - self.params.keep_prob,
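Note the `rate=1 - self.params.keep_prob` conversion in the call above: this layer's params express dropout as a keep probability, while `tf.nn.dropout` takes the probability of dropping an element. A small standalone check of that call, with hypothetical values:

```python
import tensorflow as tf

keep_prob = 0.9
x = tf.ones([4, 8])
# tf.nn.dropout expects a *drop* rate, so invert the keep probability.
# noise_shape=[4, 1] broadcasts one mask value across each row.
y = tf.nn.dropout(x, rate=1 - keep_prob, noise_shape=[4, 1], seed=1234)
```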
@@ -4827,6 +4847,7 @@ def _Dropout(self, theta, inputs, noise_shape):
 
   @classmethod
   def NumOutputNodes(cls, p):
+    del p
     # The layer does element-wise processing thus is input-shape agnostic.
     return
 