
Commit eed13f9

Implement preconditioning of gradient in L-BFGS method
1 parent 4ef210b commit eed13f9

File tree

3 files changed: +218 -7 lines changed


varipeps/config.py

Lines changed: 14 additions & 2 deletions

@@ -166,6 +166,15 @@ class VariPEPS_Config:
     optimizer_reuse_env_eps (:obj:`float`):
       Reuse CTMRG environment of previous step if norm of gradient is below
       this threshold.
+    optimizer_use_preconditioning (:obj:`bool`):
+      Use (local) preconditioning method as described in
+      https://arxiv.org/abs/2511.09546.
+    optimizer_precond_gmres_krylov_subspace_size (:obj:`int`):
+      Size of Krylov subspace built up during GMRES method for the inversion
+      of the preconditioner.
+    optimizer_precond_gmres_maxiter (:obj:`int`):
+      Maximal number of outer iterations inside the GMRES method for the
+      inversion of the preconditioner.
     line_search_method (:obj:`Line_Search_Methods`):
       Method used for the line search routine.
     line_search_initial_step_size (:obj:`float`):
@@ -263,19 +272,22 @@ class VariPEPS_Config:
     svd_ad_lorentz_broadening_eps: float = 1e-13
 
     # Optimizer
-    optimizer_method: Optimizing_Methods = Optimizing_Methods.BFGS
+    optimizer_method: Optimizing_Methods = Optimizing_Methods.L_BFGS
     optimizer_max_steps: int = 300
     optimizer_convergence_eps: float = 1e-5
     optimizer_ctmrg_preconverged_eps: float = 1e-5
     optimizer_fail_if_no_step_size_found: bool = False
     optimizer_l_bfgs_maxlen: int = 15
-    optimizer_preconverge_with_half_projectors: bool = True
+    optimizer_preconverge_with_half_projectors: bool = False
     optimizer_preconverge_with_half_projectors_eps: float = 1e-3
     optimizer_autosave_step_count: int = 2
     optimizer_random_noise_eps: float = 1e-4
     optimizer_random_noise_max_retries: int = 5
     optimizer_random_noise_relative_amplitude: float = 1e-1
     optimizer_reuse_env_eps: float = 1e-3
+    optimizer_use_preconditioning: bool = True
+    optimizer_precond_gmres_krylov_subspace_size: int = 30
+    optimizer_precond_gmres_maxiter: int = 3
 
     # Line search
     line_search_method: Line_Search_Methods = Line_Search_Methods.HAGERZHANG
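
The three new options ship with preconditioning enabled by default. A minimal usage sketch, assuming the global varipeps_config object (the same instance optimizer.py imports below) is the intended way to override them at runtime:

from varipeps import varipeps_config

# Defaults from this commit, set explicitly here for illustration.
varipeps_config.optimizer_use_preconditioning = True
varipeps_config.optimizer_precond_gmres_krylov_subspace_size = 30  # GMRES restart length
varipeps_config.optimizer_precond_gmres_maxiter = 3  # outer GMRES restarts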

varipeps/contractions/definitions.py

Lines changed: 69 additions & 0 deletions

@@ -4860,5 +4860,74 @@ def _prepare_defs(cls):
         ],
     }
 
+    precondition_operator: Definition = {
+        "tensors": [["C1", "T1", "C2", "T2", "C3", "T3", "C4", "T4"], "ket_tensor"],
+        "network": [
+            [
+                (2, 12),  # C1
+                (12, 9, -5, 3),  # T1
+                (3, 8),  # C2
+                (10, -4, 4, 8),  # T2
+                (11, 4),  # C3
+                (1, 11, -2, 7),  # T3
+                (1, 5),  # C4
+                (5, -1, 6, 2),  # T4
+            ],
+            (6, 7, -3, 10, 9),  # ket_tensor
+        ],
+    }
+
+    precondition_operator_triangular: Definition = {
+        "tensors": [["C1", "C2", "C3", "C4", "C5", "C6"], "ket_tensor"],
+        "network": [
+            [
+                (12, 5, -1, 1),  # C1
+                (1, 6, -2, 2),  # C2
+                (2, 7, -3, 8),  # C3
+                (8, 9, -4, 3),  # C4
+                (3, 10, -5, 4),  # C5
+                (4, 11, -6, 12),  # C6
+            ],
+            (5, 6, 7, 9, 10, 11, -7),  # ket_tensor
+        ],
+    }
+
+    precondition_operator_split_transfer: Definition = {
+        "tensors": [
+            [
+                "C1",
+                "T1_ket",
+                "T1_bra",
+                "C2",
+                "T2_ket",
+                "T2_bra",
+                "C3",
+                "T3_bra",
+                "T3_ket",
+                "C4",
+                "T4_bra",
+                "T4_ket",
+            ],
+            "ket_tensor",
+        ],
+        "network": [
+            [
+                (1, 2),  # C1
+                (2, 13, 3),  # T1_ket
+                (3, -5, 4),  # T1_bra
+                (4, 5),  # C2
+                (6, 14, 5),  # T2_ket
+                (7, -4, 6),  # T2_bra
+                (8, 7),  # C3
+                (9, 15, 8),  # T3_ket
+                (10, -2, 9),  # T3_bra
+                (10, 11),  # C4
+                (11, 16, 12),  # T4_ket
+                (12, -1, 1),  # T4_bra
+            ],
+            (16, 15, -3, 14, 13),  # ket_tensor
+        ],
+    }
+
 
 Definitions._prepare_defs()
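
These Definition dicts appear to follow the usual ncon-style labeling: positive integers mark pairs of legs contracted against each other, negative integers mark open legs of the result (the PEPS-tensor-shaped output the preconditioner acts on). A toy sketch of how one such pair from precondition_operator would contract, with hypothetical bond dimensions chi and D chosen only for illustration:

import jax.numpy as jnp

chi, D = 4, 3
C1 = jnp.ones((chi, chi))        # legs (2, 12)
T1 = jnp.ones((chi, D, D, chi))  # legs (12, 9, -5, 3)

# Label 12 is shared, so C1's second leg contracts with T1's first leg;
# labels 2, 9, 3 remain to be contracted later and -5 stays open.
C1T1 = jnp.einsum("ab,bcde->acde", C1, T1)
print(C1T1.shape)  # (4, 3, 3, 4)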

varipeps/optimization/optimizer.py

Lines changed: 135 additions & 5 deletions

@@ -20,7 +20,7 @@
 import jax
 from jax import jit
 import jax.numpy as jnp
-from jax.lax import scan
+from jax.lax import scan, cond
 from jax.flatten_util import ravel_pytree
 
 from varipeps import varipeps_config, varipeps_global_state
@@ -32,6 +32,8 @@
 from varipeps.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError
 from varipeps.utils.random import PEPS_Random_Number_Generator
 from varipeps.utils.slurm import SlurmUtils
+from varipeps.contractions import apply_contraction_jitted
+from varipeps.utils.debug_print import debug_print
 
 from .inner_function import (
     calc_ctmrg_expectation,
@@ -143,7 +145,7 @@ def _bfgs_workhorse(
 
 
 @jit
-def _l_bfgs_workhorse(value_tuple, gradient_tuple):
+def _l_bfgs_workhorse(value_tuple, gradient_tuple, t_objs, config):
     gradient_elem_0, gradient_unravel = ravel_pytree(gradient_tuple[0])
     gradient_len = gradient_elem_0.size
 
@@ -155,6 +157,9 @@ def _make_1d(x):
             return jnp.concatenate((jnp.real(x_1d), jnp.imag(x_1d)))
         return x_1d
 
+    gradient_elem_0_1d = _make_1d(gradient_elem_0)
+    norm_grad_square = jnp.sum(gradient_elem_0_1d * gradient_elem_0_1d)
+
     value_arr = jnp.asarray([_make_1d(e) for e in value_tuple])
     gradient_arr = jnp.asarray([_make_1d(e) for e in gradient_tuple])
 
@@ -173,9 +178,69 @@ def first_loop(q, x):
         (pho_arr[:, jnp.newaxis] * s_arr, y_arr),
     )
 
-    gamma = jnp.sum(s_arr[-1] * y_arr[-1]) / jnp.sum(y_arr[-1] * y_arr[-1])
+    def apply_precond(x):
+        if hasattr(t_objs[0], "is_triangular_peps") and t_objs[0].is_triangular_peps:
+            contraction = "precondition_operator_triangular"
+        elif hasattr(t_objs[0], "is_split_transfer") and t_objs[0].is_split_transfer:
+            contraction = "precondition_operator_split_transfer"
+        else:
+            contraction = "precondition_operator"
+
+        if iscomplex:
+            x = x[:gradient_len] + 1j * x[gradient_len:]
+        x = gradient_unravel(x)
+        x = [
+            apply_contraction_jitted(contraction, (te.tensor,), (te,), (xe,))
+            + norm_grad_square * xe
+            for te, xe in zip(t_objs, x, strict=True)
+        ]
+
+        return _make_1d(x)
+
+    if config.optimizer_use_preconditioning:
+        y_precond, _ = jax.scipy.sparse.linalg.gmres(
+            apply_precond,
+            y_arr[0],
+            y_arr[0],
+            restart=config.optimizer_precond_gmres_krylov_subspace_size,
+            maxiter=config.optimizer_precond_gmres_maxiter,
+            solve_method="incremental",
+        )
+
+        def calc_q_precond(y, y_precond, q):
+            q_precond, _ = jax.scipy.sparse.linalg.gmres(
+                apply_precond,
+                q,
+                q,
+                restart=config.optimizer_precond_gmres_krylov_subspace_size,
+                maxiter=config.optimizer_precond_gmres_maxiter,
+                solve_method="incremental",
+            )
+
+            return cond(
+                jnp.sum(q_precond * q) >= 0,
+                lambda y, y_precond, q, q_precond: (y_precond, q_precond),
+                lambda y, y_precond, q, q_precond: (y, q),
+                y,
+                y_precond,
+                q,
+                q_precond,
+            )
+
+        y_precond, q_precond = cond(
+            jnp.sum(y_precond * y_arr[0]) >= 0,
+            calc_q_precond,
+            lambda y, y_precond, q: (y, q),
+            y_arr[0],
+            y_precond,
+            q,
+        )
+    else:
+        y_precond = y_arr[0]
+        q_precond = q
 
-    z_result = gamma * q
+    gamma = jnp.sum(s_arr[0] * y_arr[0]) / jnp.sum(y_arr[0] * y_precond)
+    z_result = gamma * q_precond
 
     def second_loop(z, x):
         pho_y, s, alpha_i = x
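
Between the two L-BFGS loops, this hunk replaces the plain initial Hessian scaling gamma = (s.y)/(y.y) by a preconditioned variant: both the first-loop result q and the gradient difference y_arr[0] are pushed through the inverse of the shifted environment operator via GMRES, and gamma is recomputed against the preconditioned y (the hunk also switches the s_arr[-1]/y_arr[-1] indexing to s_arr[0]/y_arr[0]). A self-contained sketch of that scaling step, with a toy diagonal operator standing in for apply_precond; the operator, sizes, and vectors below are assumptions for illustration only:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

diag = jnp.linspace(1.0, 4.0, 8)

def precond_op(x):
    # stand-in for apply_precond: (E + ||g||^2 * Id) x with a diagonal toy E
    return diag * x + 1e-2 * x

q = jnp.ones(8)        # direction after the first L-BFGS loop
y = jnp.full(8, 0.5)   # gradient difference y_arr[0]
s = jnp.full(8, 0.25)  # corresponding step s_arr[0]

y_precond, _ = gmres(precond_op, y, y, restart=30, maxiter=3, solve_method="incremental")
q_precond, _ = gmres(precond_op, q, q, restart=30, maxiter=3, solve_method="incremental")

gamma = jnp.sum(s * y) / jnp.sum(y * y_precond)
z_result = gamma * q_precond  # seeds the second L-BFGS loop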
@@ -753,9 +818,72 @@ def random_noise(a):
 
             if count == 0 or signal_reset_descent_dir:
                 descent_dir = [-elem for elem in working_gradient]
+
+                if varipeps_config.optimizer_use_preconditioning:
+                    if (
+                        hasattr(
+                            working_unitcell.get_unique_tensors()[0],
+                            "is_triangular_peps",
+                        )
+                        and working_unitcell.get_unique_tensors()[
+                            0
+                        ].is_triangular_peps
+                    ):
+                        contraction = "precondition_operator_triangular"
+                    elif (
+                        hasattr(
+                            working_unitcell.get_unique_tensors()[0],
+                            "is_split_transfer",
+                        )
+                        and working_unitcell.get_unique_tensors()[
+                            0
+                        ].is_split_transfer
+                    ):
+                        contraction = "precondition_operator_split_transfer"
+                    else:
+                        contraction = "precondition_operator"
+
+                    grad_norm_squared = 1e-2 * (
+                        jnp.linalg.norm(jnp.asarray(working_gradient)) ** 2
+                    )
+
+                    tmp_descent_dir = [
+                        jax.scipy.sparse.linalg.gmres(
+                            lambda x: (
+                                apply_contraction_jitted(
+                                    contraction, (te.tensor,), (te,), (x,)
+                                )
+                                + grad_norm_squared * x
+                            ),
+                            xe,
+                            xe,
+                            restart=varipeps_config.optimizer_precond_gmres_krylov_subspace_size,
+                            maxiter=varipeps_config.optimizer_precond_gmres_maxiter,
+                            solve_method="incremental",
+                        )[0]
+                        for te, xe in zip(
+                            working_unitcell.get_unique_tensors(),
+                            descent_dir,
+                            strict=True,
+                        )
+                    ]
+                    if all(
+                        jnp.sum(xe * x2e.conj()) >= 0
+                        for xe, x2e in zip(
+                            descent_dir, tmp_descent_dir, strict=True
+                        )
+                    ):
+                        descent_dir = tmp_descent_dir
+                    else:
+                        tqdm.write("Warning: Non-positive preconditioner")
+                    del contraction
+                    del grad_norm_squared
             else:
                 descent_dir = _l_bfgs_workhorse(
-                    tuple(l_bfgs_x_cache), tuple(l_bfgs_grad_cache)
+                    tuple(l_bfgs_x_cache),
+                    tuple(l_bfgs_grad_cache),
+                    working_unitcell.get_unique_tensors(),
+                    varipeps_config,
                 )
         else:
             raise ValueError("Unknown optimization method.")
@@ -767,6 +895,8 @@ def random_noise(a):
             descent_dir = [-elem for elem in working_gradient]
 
         conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0])
+        if jnp.isinf(conv) or jnp.isnan(conv):
+            conv = 0
        step_conv[random_noise_retries].append(conv)
 
        try:
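
In both code paths above, the preconditioned vector is accepted only if its overlap with the unpreconditioned one stays non-negative, so a non-positive preconditioner degrades gracefully to the plain direction instead of corrupting it. The jitted workhorse performs this check with jax.lax.cond; a minimal sketch of the pattern, with toy vectors and a hypothetical helper name:

import jax.numpy as jnp
from jax.lax import cond

def accept_if_positive(original, preconditioned):
    # keep the preconditioned vector only if the overlap is non-negative
    return cond(
        jnp.sum(preconditioned * original) >= 0,
        lambda o, p: p,  # accept preconditioned result
        lambda o, p: o,  # fall back to the unpreconditioned vector
        original,
        preconditioned,
    )

print(accept_if_positive(jnp.ones(3), -jnp.ones(3)))  # [1. 1. 1.], i.e. fallback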
