@@ -20,7 +20,7 @@
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jax.lax import scan
+from jax.lax import scan, cond
 from jax.flatten_util import ravel_pytree
 
 from varipeps import varipeps_config, varipeps_global_state
@@ -32,6 +32,8 @@
 from varipeps.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError
 from varipeps.utils.random import PEPS_Random_Number_Generator
 from varipeps.utils.slurm import SlurmUtils
+from varipeps.contractions import apply_contraction_jitted
+from varipeps.utils.debug_print import debug_print
 
 from .inner_function import (
     calc_ctmrg_expectation,
@@ -143,7 +145,7 @@ def _bfgs_workhorse(
 
 
 @jit
-def _l_bfgs_workhorse(value_tuple, gradient_tuple):
+def _l_bfgs_workhorse(value_tuple, gradient_tuple, t_objs, config):
     gradient_elem_0, gradient_unravel = ravel_pytree(gradient_tuple[0])
     gradient_len = gradient_elem_0.size
 
@@ -155,6 +157,9 @@ def _make_1d(x):
             return jnp.concatenate((jnp.real(x_1d), jnp.imag(x_1d)))
         return x_1d
 
+    gradient_elem_0_1d = _make_1d(gradient_elem_0)
+    norm_grad_square = jnp.sum(gradient_elem_0_1d * gradient_elem_0_1d)
+
     value_arr = jnp.asarray([_make_1d(e) for e in value_tuple])
     gradient_arr = jnp.asarray([_make_1d(e) for e in gradient_tuple])
 
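The two added lines flatten the first gradient pytree into a real 1-D vector and take its squared norm; this value is reused below as the diagonal shift of the preconditioning operator. A minimal, self-contained sketch of the same flattening idea, using an illustrative complex-valued pytree that is not part of varipeps:

    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree

    # Hypothetical complex "gradient" pytree, just for illustration.
    grad = {"a": jnp.array([1.0 + 2.0j, 3.0 - 1.0j]), "b": jnp.array([0.5j])}

    grad_1d, unravel = ravel_pytree(grad)  # complex 1-D vector plus inverse map
    n = grad_1d.size
    # Real representation: real parts followed by imaginary parts.
    grad_real = jnp.concatenate((jnp.real(grad_1d), jnp.imag(grad_1d)))

    norm_grad_square = jnp.sum(grad_real * grad_real)  # equals sum of |g_i|^2
    # unravel(grad_real[:n] + 1j * grad_real[n:]) restores the original pytree,
    # mirroring what apply_precond does with its 1-D input further down.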
@@ -173,9 +178,69 @@ def first_loop(q, x):
         (pho_arr[:, jnp.newaxis] * s_arr, y_arr),
     )
 
-    gamma = jnp.sum(s_arr[-1] * y_arr[-1]) / jnp.sum(y_arr[-1] * y_arr[-1])
+    def apply_precond(x):
+        if hasattr(t_objs[0], "is_triangular_peps") and t_objs[0].is_triangular_peps:
+            contraction = "precondition_operator_triangular"
+        elif hasattr(t_objs[0], "is_split_transfer") and t_objs[0].is_split_transfer:
+            contraction = "precondition_operator_split_transfer"
+        else:
+            contraction = "precondition_operator"
+
+        if iscomplex:
+            x = x[:gradient_len] + 1j * x[gradient_len:]
+        x = gradient_unravel(x)
+        x = [
+            apply_contraction_jitted(contraction, (te.tensor,), (te,), (xe,))
+            + norm_grad_square * xe
+            for te, xe in zip(t_objs, x, strict=True)
+        ]
+
+        return _make_1d(x)
+
+    if config.optimizer_use_preconditioning:
+        y_precond, _ = jax.scipy.sparse.linalg.gmres(
+            apply_precond,
+            y_arr[0],
+            y_arr[0],
+            restart=config.optimizer_precond_gmres_krylov_subspace_size,
+            maxiter=config.optimizer_precond_gmres_maxiter,
+            solve_method="incremental",
+        )
+
+        def calc_q_precond(y, y_precond, q):
+            q_precond, _ = jax.scipy.sparse.linalg.gmres(
+                apply_precond,
+                q,
+                q,
+                restart=config.optimizer_precond_gmres_krylov_subspace_size,
+                maxiter=config.optimizer_precond_gmres_maxiter,
+                solve_method="incremental",
+            )
+
+            return cond(
+                jnp.sum(q_precond * q) >= 0,
+                lambda y, y_precond, q, q_precond: (y_precond, q_precond),
+                lambda y, y_precond, q, q_precond: (y, q),
+                y,
+                y_precond,
+                q,
+                q_precond,
+            )
+
+        y_precond, q_precond = cond(
+            jnp.sum(y_precond * y_arr[0]) >= 0,
+            calc_q_precond,
+            lambda y, y_precond, q: (y, q),
+            y_arr[0],
+            y_precond,
+            q,
+        )
+    else:
+        y_precond = y_arr[0]
+        q_precond = q
 
-    z_result = gamma * q
+    gamma = jnp.sum(s_arr[0] * y_arr[0]) / jnp.sum(y_arr[0] * y_precond)
+    z_result = gamma * q_precond
 
     def second_loop(z, x):
         pho_y, s, alpha_i = x
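This hunk changes how the initial Hessian guess enters the L-BFGS two-loop recursion. Instead of scaling q by gamma alone, a preconditioning operator (one of the "precondition_operator*" contractions plus a norm_grad_square shift, applied tensor-wise in apply_precond) is inverted with GMRES for both y_arr[0] and q. Each solve is accepted only if the preconditioned vector keeps a non-negative overlap with its unpreconditioned counterpart; otherwise cond falls back to the plain vectors. gamma is then jnp.sum(s_arr[0] * y_arr[0]) / jnp.sum(y_arr[0] * y_precond), and the recursion continues from z_result = gamma * q_precond. A small runnable sketch of this solve-then-accept pattern, with a stand-in symmetric positive-definite matrix in place of the varipeps contraction (every name in the sketch is illustrative):

    import jax
    import jax.numpy as jnp
    from jax.lax import cond

    n = 16
    B = jax.random.normal(jax.random.PRNGKey(0), (n, n))
    M = B @ B.T + n * jnp.eye(n)  # stand-in SPD preconditioner

    def apply_precond(x):
        # Matrix-free application of M, the form GMRES expects.
        return M @ x

    q = jax.random.normal(jax.random.PRNGKey(1), (n,))

    # Solve M @ q_precond = q, using q itself as the starting guess.
    q_precond, _ = jax.scipy.sparse.linalg.gmres(
        apply_precond, q, q, restart=10, maxiter=5, solve_method="incremental"
    )

    # Keep the preconditioned vector only if its overlap with q is non-negative,
    # otherwise fall back to the unpreconditioned vector.
    q_used = cond(
        jnp.sum(q_precond * q) >= 0,
        lambda qp, q_plain: qp,
        lambda qp, q_plain: q_plain,
        q_precond,
        q,
    )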
@@ -753,9 +818,72 @@ def random_noise(a):
 
                 if count == 0 or signal_reset_descent_dir:
                     descent_dir = [-elem for elem in working_gradient]
+
+                    if varipeps_config.optimizer_use_preconditioning:
+                        if (
+                            hasattr(
+                                working_unitcell.get_unique_tensors()[0],
+                                "is_triangular_peps",
+                            )
+                            and working_unitcell.get_unique_tensors()[
+                                0
+                            ].is_triangular_peps
+                        ):
+                            contraction = "precondition_operator_triangular"
+                        elif (
+                            hasattr(
+                                working_unitcell.get_unique_tensors()[0],
+                                "is_split_transfer",
+                            )
+                            and working_unitcell.get_unique_tensors()[
+                                0
+                            ].is_split_transfer
+                        ):
+                            contraction = "precondition_operator_split_transfer"
+                        else:
+                            contraction = "precondition_operator"
+
+                        grad_norm_squared = 1e-2 * (
+                            jnp.linalg.norm(jnp.asarray(working_gradient)) ** 2
+                        )
+
+                        tmp_descent_dir = [
+                            jax.scipy.sparse.linalg.gmres(
+                                lambda x: (
+                                    apply_contraction_jitted(
+                                        contraction, (te.tensor,), (te,), (x,)
+                                    )
+                                    + grad_norm_squared * x
+                                ),
+                                xe,
+                                xe,
+                                restart=varipeps_config.optimizer_precond_gmres_krylov_subspace_size,
+                                maxiter=varipeps_config.optimizer_precond_gmres_maxiter,
+                                solve_method="incremental",
+                            )[0]
+                            for te, xe in zip(
+                                working_unitcell.get_unique_tensors(),
+                                descent_dir,
+                                strict=True,
+                            )
+                        ]
+                        if all(
+                            jnp.sum(xe * x2e.conj()) >= 0
+                            for xe, x2e in zip(
+                                descent_dir, tmp_descent_dir, strict=True
+                            )
+                        ):
+                            descent_dir = tmp_descent_dir
+                        else:
+                            tqdm.write("Warning: Non-positive preconditioner")
+                        del contraction
+                        del grad_norm_squared
                 else:
                     descent_dir = _l_bfgs_workhorse(
-                        tuple(l_bfgs_x_cache), tuple(l_bfgs_grad_cache)
+                        tuple(l_bfgs_x_cache),
+                        tuple(l_bfgs_grad_cache),
+                        working_unitcell.get_unique_tensors(),
+                        varipeps_config,
                     )
             else:
                 raise ValueError("Unknown optimization method.")
@@ -767,6 +895,8 @@ def random_noise(a):
                 descent_dir = [-elem for elem in working_gradient]
 
             conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0])
+            if jnp.isinf(conv) or jnp.isnan(conv):
+                conv = 0
             step_conv[random_noise_retries].append(conv)
 
             try:
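The last two hunks touch the main optimization loop. Whenever the descent direction is reset to the plain negative gradient, the same preconditioning contraction is applied tensor-by-tensor through per-tensor GMRES solves, damped by 1e-2 times the squared gradient norm, and the preconditioned direction is only adopted if every tensor keeps a non-negative overlap with the original one; otherwise a warning is printed and the unpreconditioned direction is kept. The L-BFGS branch now also passes the unique tensors and the config object through to _l_bfgs_workhorse, and conv is zeroed out if the gradient norm comes back inf or NaN. The feature is driven by three attributes read from varipeps_config (optimizer_use_preconditioning, optimizer_precond_gmres_krylov_subspace_size, optimizer_precond_gmres_maxiter); assuming the accompanying config change exposes them under exactly these names, enabling the preconditioner could look like:

    from varipeps import varipeps_config

    # Hypothetical values; the attribute names are taken from this diff and are
    # assumed to exist on varipeps_config once the matching config change lands.
    varipeps_config.optimizer_use_preconditioning = True
    varipeps_config.optimizer_precond_gmres_krylov_subspace_size = 20
    varipeps_config.optimizer_precond_gmres_maxiter = 5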