@@ -101,15 +101,23 @@ def __init__(
101
101
assert w_spec .shape == arg_shape
102
102
assert pxd .NDArrayInfo .from_obj (w_spec ) == self ._ndi
103
103
104
- # TODO: modify for weighted case?
105
104
# Cheap analytical Lipschitz upper bound given by
106
- # \sigma_{\max}(P) <= \max{P.sum(axis=0), P.sum(axis=1) },
105
+ # \sigma_{\max}(P) <= \norm{P}{F },
107
106
# with
108
- # P.sum(axis=0) <= \norm{pitch}{2} * N_ray [very pessimistic]
109
- # P.sum(axis=1) <= \norm{pitch}{2} * \norm{arg_shape}{2}
110
- row_ub = np .linalg .norm (pitch ) * np .linalg .norm (arg_shape )
111
- col_ub = np .linalg .norm (pitch ) * len (n_spec )
112
- self .lipschitz = max (row_ub , col_ub )
107
+ # \norm{P}{F}^{2}
108
+ # <= (max cell weight)^{2} * #non-zero elements
109
+ # = (max cell weight)^{2} * N_ray * (maximum number of cells traversable by a ray)
110
+ # = (max cell weight)^{2} * N_ray * \norm{arg_shape}{2}
111
+ #
112
+ # (max cell weight) =
113
+ # unweighted : \norm{pitch}{2}
114
+ # weighted & (w_min > 0): \norm{pitch}{2}
115
+ # weighted & (w_min < 0): cannot infer
116
+ if weighted and (w_spec .min () < 0 ):
117
+ max_cell_weight = np .inf
118
+ else :
119
+ max_cell_weight = np .linalg .norm (pitch )
120
+ self .lipschitz = max_cell_weight * np .sqrt (N_l * np .linalg .norm (arg_shape ))
113
121
114
122
# Dr.Jit variables. {Have shapes consistent for xrt_[apply,adjoint]().}
115
123
# xrt_[apply,adjoint]() only support D=3 case.
0 commit comments