@@ -1634,6 +1634,14 @@ def _check_runtime_broadcast(node, value, shape):
16341634 if v_static_dim is None and value_dim == 1 and out_dim != 1 :
16351635 raise ValueError (Alloc ._runtime_broadcast_error_msg )
16361636
@staticmethod
def value_is_scalar_zero(x: TensorVariable) -> bool:
    """Return True iff ``x`` is a constant, scalar-shaped tensor equal to zero.

    "Scalar-shaped" means every dimension is broadcastable (length 1), so the
    value holds a single element. Evaluation short-circuits: the broadcastable
    check runs first, then the ``Constant`` check, then the value comparison
    (``unique_value`` is ``None`` for non-uniform constants, and ``None == 0``
    is falsy, so those correctly yield False).
    """
    is_scalar_shaped = all(x.type.broadcastable)
    return is_scalar_shaped and isinstance(x, Constant) and x.unique_value == 0
1644+
16371645 def perform (self , node , inputs , out_ ):
16381646 (out ,) = out_
16391647 v = inputs [0 ]
@@ -1659,6 +1667,7 @@ def c_code(self, node, name, inp, out, sub):
16591667 o_static_shape = node .outputs [0 ].type .shape
16601668 v_ndim = len (v_static_shape )
16611669 o_ndim = len (o_static_shape )
1670+ is_zero = self .value_is_scalar_zero (node .inputs [0 ])
16621671 assert o_ndim == len (inp [1 :])
16631672
16641673 # Declare variables
@@ -1699,16 +1708,18 @@ def c_code(self, node, name, inp, out, sub):
16991708 { fail }
17001709 }}
17011710 }}
1702-
1711+ if ({ int (is_zero )} && (PyArray_IS_C_CONTIGUOUS({ zz } ) || PyArray_IS_F_CONTIGUOUS({ zz } ))){{
1712+ PyArray_FILLWBYTE({ zz } , 0);
1713+ }}
17031714 // This function takes care of broadcasting
1704- if (PyArray_CopyInto({ zz } , { vv } ) == -1)
1715+ else if (PyArray_CopyInto({ zz } , { vv } ) == -1)
17051716 { fail }
17061717 """
17071718
17081719 return code
17091720
def c_code_cache_version(self):
    """Version key for the C-code cache.

    Bumped to 5 because the generated C now contains the scalar-zero
    ``PyArray_FILLWBYTE`` fast path; a new key forces recompilation of
    previously cached modules.
    """
    return (5,)
17121723
def infer_shape(self, fgraph, node, input_shapes):
    """Return the output shape of the Alloc node.

    ``node.inputs`` is ``[value, *shape_scalars]``; the output shape is
    exactly the trailing shape inputs, so no computation is needed.
    """
    shape_inputs = node.inputs[1:]
    return [shape_inputs]
0 commit comments