-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconfig.py
More file actions
497 lines (448 loc) · 20.9 KB
/
config.py
File metadata and controls
497 lines (448 loc) · 20.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
from dataclasses import dataclass
from enum import Enum, IntEnum, auto, unique
from typing import TypeVar, Tuple, Any, Type, NoReturn
import logging
import numpy as np
from jax.tree_util import register_pytree_node_class
# Type variable bound to VariPEPS_Config; used to annotate the tree_unflatten
# classmethod so it returns an instance of the class it is called on.
T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config")
@unique
class Optimizing_Methods(IntEnum):
    """Methods available for the variational optimization of the PEPS network."""

    #: Steepest gradient descent
    STEEPEST = auto()
    #: Conjugate gradient method
    CG = auto()
    #: BFGS method
    BFGS = auto()
    #: L-BFGS method
    L_BFGS = auto()
@unique
class Line_Search_Methods(IntEnum):
    """Methods available for the line search routine of the optimizer."""

    #: Simple line search method
    SIMPLE = auto()
    #: Armijo line search method
    ARMIJO = auto()
    #: Wolfe line search method
    WOLFE = auto()
    #: Hager-Zhang line search method
    HAGERZHANG = auto()
@unique
class Projector_Method(IntEnum):
    """Methods available for calculating the CTMRG projectors."""

    #: Use only half network for projector calculation
    HALF = auto()
    #: Use full network for projector calculation
    FULL = auto()
    #: Use the Fishman method for projector calculation
    FISHMAN = auto()
    #: Use the Fishman method but with half projectors as basis
    HALF_FISHMAN = auto()
@unique
class Wavevector_Type(IntEnum):
    """Interval convention used for the q vectors of spiral PEPS."""

    #: Use interval [0, 2pi) for q vectors
    TWO_PI_POSITIVE_ONLY = auto()
    # NOTE(review): the interval below has length 4pi while the one above has
    # length 2pi; a "symmetric" 2pi interval would be [-pi, pi) -- confirm
    # against the spiral PEPS code which one is actually used.
    #: Use interval [-2pi, 2pi) for q vectors
    TWO_PI_SYMMETRIC = auto()
@unique
class Slurm_Restart_Mode(IntEnum):
    """What to do when a Slurm job reaches its maximal runtime limit."""

    #: Disable automatic restart of slurm job if maximal runtime limit is reached
    DISABLED = auto()
    #: Write file to indicate that restart is needed but no slurm scripts
    WRITE_NEED_RESTART_FILE = auto()
    #: Write slurm restart script but do not submit new slurm job
    WRITE_RESTART_SCRIPT = auto()
    #: Write restart script and start new slurm job with it
    AUTOMATIC_RESTART = auto()
@unique
class LogLevel(IntEnum):
    """Logging levels accepted by the ``log_level_*`` config options.

    The non-zero values are the numeric levels of the standard
    :mod:`logging` module.
    """

    # NOTE(review): OFF has value 0, which equals logging.NOTSET; passed
    # directly to Logger.setLevel it means "inherit from the parent logger",
    # not "disabled" -- confirm the logging setup code special-cases OFF.
    #: Disable logging
    OFF = 0
    #: Same as logging.ERROR
    ERROR = logging.ERROR
    #: Same as logging.WARNING
    WARNING = logging.WARNING
    #: Same as logging.INFO
    INFO = logging.INFO
    #: Same as logging.DEBUG
    DEBUG = logging.DEBUG
@dataclass
@register_pytree_node_class
class VariPEPS_Config:
    """
    Config class for varipeps module. Normally only the instance :obj:`config`
    created below is used.

    Every assignment to a field is validated: the value must either have
    exactly the annotated type or be convertible to it. Accepted conversions
    are a Python :obj:`int` for a :obj:`float` option, a size-1 array-like
    value (anything with ``dtype``/``size``, e.g. NumPy or JAX arrays and
    scalars) with a matching dtype, and -- for enum options -- an integer
    value or a member name (exact, or upper-cased as fallback). Unknown
    option names raise :obj:`KeyError`, values of the wrong type (including
    members of a different enum) raise :obj:`TypeError`.

    Parameters:
      ad_use_custom_vjp (:obj:`bool`):
        Use custom VJP rule for the CTMRG routine during AD calculation.
      ad_custom_print_steps (:obj:`bool`):
        Print steps of fix-point iteration in custom VJP function.
      ad_custom_verbose_output (:obj:`bool`):
        Print verbose output in custom VJP function.
      ad_custom_convergence_eps (:obj:`float`):
        Convergence criterion for the custom VJP function.
      ad_custom_max_steps (:obj:`int`):
        Maximal number of steps for the fix-point iteration of the custom VJP
        function.
      checkpointing_ncon (:obj:`bool`):
        Enable AD checkpointing for the ncon calls.
      checkpointing_projectors (:obj:`bool`):
        Enable AD checkpointing for the calculation of the projectors.
      ctmrg_convergence_eps (:obj:`float`):
        Convergence criterion for the CTMRG routine.
      ctmrg_enforce_elementwise_convergence (:obj:`bool`):
        Enforce elementwise convergence of the CTM tensors instead of only
        convergence of the singular values of the corners.
      ctmrg_max_steps (:obj:`int`):
        Maximal number of steps for the fix-point iteration of the CTMRG
        routine.
      ctmrg_print_steps (:obj:`bool`):
        Print steps of fix-point iteration in CTMRG routine.
      ctmrg_verbose_output (:obj:`bool`):
        Print verbose output in CTMRG routine.
      ctmrg_truncation_eps (:obj:`float`):
        Cut-off for the singular values relative to the biggest one. Used in
        the calculation of the CTMRG projectors.
      ctmrg_fail_if_not_converged (:obj:`bool`):
        Flag if the CTMRG routine should fail with an error if no convergence
        can be reached within the maximal number of steps.
        If disabled, the result converged so far is returned.
      ctmrg_full_projector_method (:obj:`~varipeps.config.Projector_Method`):
        Set which projector method should be used as default (full) projector
        method during the CTMRG routine. Sensible values are
        :obj:`~varipeps.config.Projector_Method.FULL` or
        :obj:`~varipeps.config.Projector_Method.FISHMAN`.
      ctmrg_increase_truncation_eps (:obj:`bool`):
        Flag if the CTMRG routine should try higher truncation thresholds for
        the SVD based projector methods if the routine does not converge in the
        maximum number of steps.
      ctmrg_increase_truncation_eps_factor (:obj:`float`):
        Factor by which the truncation threshold should be increased.
      ctmrg_increase_truncation_eps_max_value (:obj:`float`):
        Maximal value for the truncation threshold. Do not increase higher than
        this value.
      ctmrg_heuristic_increase_chi (:obj:`bool`):
        Flag if the CTMRG routine should try a higher environment bond
        dimension if the routine found singular values above a threshold
        during the projector calculation of the last absorption step.
      ctmrg_heuristic_increase_chi_threshold (:obj:`float`):
        Threshold for the heuristic environment bond dimension increase.
      ctmrg_heuristic_increase_chi_step_size (:obj:`int`):
        Step size for the heuristic environment bond dimension increase.
      ctmrg_heuristic_decrease_chi (:obj:`bool`):
        Flag if the CTMRG routine should try a lower environment bond
        dimension if the routine found singular values below the SVD threshold
        during the projector calculation of the last absorption step.
      ctmrg_heuristic_decrease_chi_step_size (:obj:`int`):
        Step size for the heuristic environment bond dimension decrease.
      triangular_ctmrg_use_split (:obj:`bool`):
        Flag if the split projector method should be used in the
        triangular CTMRG.
      svd_sign_fix_eps (:obj:`float`):
        Value for numerical stability threshold in sign-fixed SVD.
      svd_ad_use_lorentz_broadening (:obj:`bool`):
        Enable Lorentz broadening in the AD rule for the SVD.
      svd_ad_lorentz_broadening_eps (:obj:`float`):
        Numerical stabilization constant in the Lorentz broadening in the
        AD rule for the SVD.
      optimizer_method (:obj:`Optimizing_Methods`):
        Method used for variational optimization of the PEPS network.
      optimizer_max_steps (:obj:`int`):
        Maximal number of steps of the optimization routine.
      optimizer_convergence_eps (:obj:`float`):
        Convergence criterion for the optimization routine.
      optimizer_ctmrg_preconverged_eps (:obj:`float`):
        Convergence criterion for the optimization routine using the gradient
        calculations with the preconverged environment.
      optimizer_fail_if_no_step_size_found (:obj:`bool`):
        Flag if the optimizer routine should fail with an error if no step size
        can be found before the gradient norm is below the convergence
        threshold. If disabled, the result converged so far is returned.
      optimizer_l_bfgs_maxlen (:obj:`int`):
        Maximal number of previous steps used for the L-BFGS method.
      optimizer_preconverge_with_half_projectors (:obj:`bool`):
        Flag if the optimizer should use only CTM half projectors for the steps
        until some convergence is reached.
      optimizer_preconverge_with_half_projectors_eps (:obj:`float`):
        Convergence criterion for the preconvergence with only the half
        CTM projectors.
      optimizer_autosave_step_count (:obj:`int`):
        Step count after which the optimizer result is automatically saved.
      optimizer_random_noise_eps (:obj:`float`):
        Optimizer should try the best state so far with some random noise if
        the gradient norm is below this threshold.
      optimizer_random_noise_max_retries (:obj:`int`):
        Maximal retries for optimization with random noise.
      optimizer_random_noise_relative_amplitude (:obj:`float`):
        Relative amplitude used for random noise.
      optimizer_reuse_env_eps (:obj:`float`):
        Reuse CTMRG environment of previous step if norm of gradient is below
        this threshold.
      optimizer_use_preconditioning (:obj:`bool`):
        Use (local) preconditioning method as described in
        https://arxiv.org/abs/2511.09546.
      optimizer_precond_gmres_krylov_subspace_size (:obj:`int`):
        Size of Krylov subspace built up during GMRES method for the inversion
        of the preconditioner.
      optimizer_precond_gmres_maxiter (:obj:`int`):
        Maximal number of outer iterations inside the GMRES method for the
        inversion of the preconditioner.
      line_search_method (:obj:`Line_Search_Methods`):
        Method used for the line search routine.
      line_search_initial_step_size (:obj:`float`):
        Initial step size for the line search routine.
      line_search_reduction_factor (:obj:`float`):
        Reduction factor between two line search steps.
      line_search_max_steps (:obj:`int`):
        Maximal number of steps in the line search routine.
      line_search_armijo_const (:obj:`float`):
        Constant used in Armijo line search method.
      line_search_wolfe_const (:obj:`float`):
        Constant used in Wolfe line search method.
      line_search_use_last_step_size (:obj:`bool`):
        Flag if the line search should start from the step size of the
        previous optimizer step.
      line_search_hager_zhang_quad_step (:obj:`bool`):
        Use QuadStep method in Hager-Zhang line search to find initial
        step size.
      line_search_hager_zhang_delta (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_sigma (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_psi_0 (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_psi_1 (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_psi_2 (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_eps (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_theta (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_gamma (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_rho (:obj:`float`):
        Constant used in Hager-Zhang line search method.
      line_search_hager_zhang_eps_use_grad_norm (:obj:`bool`):
        Use norm of gradient multiplied by
        :obj:`VariPEPS_Config.line_search_hager_zhang_eps_grad_norm_factor` to
        calculate eps value in Hager-Zhang line search. If disabled, the fixed
        value from config parameter
        :obj:`VariPEPS_Config.line_search_hager_zhang_eps` is used.
      line_search_hager_zhang_eps_grad_norm_factor (:obj:`float`):
        Factor used for gradient based eps calculation. See parameter
        :obj:`VariPEPS_Config.line_search_hager_zhang_eps_use_grad_norm`
        for details.
      basinhopping_niter (:obj:`int`):
        Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
        See this function for details.
      basinhopping_T (:obj:`float`):
        Value for parameter `T` of :obj:`scipy.optimize.basinhopping`.
        See this function for details.
      basinhopping_niter_success (:obj:`int`):
        Value for parameter `niter_success` of
        :obj:`scipy.optimize.basinhopping`. See this function for details.
      spiral_wavevector_type (:obj:`Wavevector_Type`):
        Type of wavevector to be used (only positive/symmetric interval/...).
      slurm_restart_mode (:obj:`Slurm_Restart_Mode`):
        Mode of operation to restart slurm job if maximal runtime is reached.
      log_level_global (:obj:`LogLevel`):
        Global logging level for the 'varipeps' package logger.
      log_level_optimizer (:obj:`LogLevel`):
        Logging level for 'varipeps.optimizer'.
      log_level_ctmrg (:obj:`LogLevel`):
        Logging level for 'varipeps.ctmrg'.
      log_level_line_search (:obj:`LogLevel`):
        Logging level for 'varipeps.line_search'.
      log_level_expectation (:obj:`LogLevel`):
        Logging level for 'varipeps.expectation'.
      log_to_console (:obj:`bool`):
        Enable standard console logging (StreamHandler).
        Ignored when :obj:`VariPEPS_Config.log_tqdm` is True.
      log_to_file (:obj:`bool`):
        Enable logging to file.
      log_file (:obj:`str`):
        Filename for logging to file (used when
        :obj:`VariPEPS_Config.log_to_file` is True).
      log_tqdm (:obj:`bool`):
        Enable tqdm-based console logging. If True, messages from
        'varipeps.optimizer' update a tqdm progress bar, while other modules
        log via tqdm.write. File logging settings still apply.
    """

    # AD config
    ad_use_custom_vjp: bool = True
    ad_custom_print_steps: bool = False
    ad_custom_verbose_output: bool = False
    ad_custom_convergence_eps: float = 1e-7
    ad_custom_max_steps: int = 75
    checkpointing_ncon: bool = False
    checkpointing_projectors: bool = False

    # CTMRG routine
    ctmrg_convergence_eps: float = 1e-8
    ctmrg_enforce_elementwise_convergence: bool = True
    ctmrg_max_steps: int = 75
    ctmrg_print_steps: bool = False
    ctmrg_verbose_output: bool = False
    ctmrg_truncation_eps: float = 1e-12
    ctmrg_fail_if_not_converged: bool = True
    ctmrg_full_projector_method: Projector_Method = Projector_Method.FISHMAN
    ctmrg_increase_truncation_eps: bool = True
    ctmrg_increase_truncation_eps_factor: float = 100.0
    ctmrg_increase_truncation_eps_max_value: float = 1e-6
    ctmrg_heuristic_increase_chi: bool = True
    ctmrg_heuristic_increase_chi_threshold: float = 1e-6
    ctmrg_heuristic_increase_chi_step_size: int = 2
    ctmrg_heuristic_decrease_chi: bool = True
    ctmrg_heuristic_decrease_chi_step_size: int = 1

    # Triangular CTMRG routine
    triangular_ctmrg_use_split: bool = False

    # SVD
    svd_sign_fix_eps: float = 1e-1
    svd_ad_use_lorentz_broadening: bool = False
    svd_ad_lorentz_broadening_eps: float = 1e-13

    # Optimizer
    optimizer_method: Optimizing_Methods = Optimizing_Methods.L_BFGS
    optimizer_max_steps: int = 300
    optimizer_convergence_eps: float = 1e-5
    optimizer_ctmrg_preconverged_eps: float = 1e-5
    optimizer_fail_if_no_step_size_found: bool = False
    optimizer_l_bfgs_maxlen: int = 15
    optimizer_preconverge_with_half_projectors: bool = False
    optimizer_preconverge_with_half_projectors_eps: float = 1e-3
    optimizer_autosave_step_count: int = 2
    optimizer_random_noise_eps: float = 1e-4
    optimizer_random_noise_max_retries: int = 5
    optimizer_random_noise_relative_amplitude: float = 1e-1
    optimizer_reuse_env_eps: float = 1e-3
    optimizer_use_preconditioning: bool = True
    optimizer_precond_gmres_krylov_subspace_size: int = 30
    optimizer_precond_gmres_maxiter: int = 3

    # Line search
    line_search_method: Line_Search_Methods = Line_Search_Methods.HAGERZHANG
    line_search_initial_step_size: float = 1.0
    line_search_reduction_factor: float = 0.5
    line_search_max_steps: int = 40
    line_search_armijo_const: float = 1e-4
    line_search_wolfe_const: float = 0.9
    line_search_use_last_step_size: bool = False
    line_search_hager_zhang_quad_step: bool = True
    line_search_hager_zhang_delta: float = 0.1
    line_search_hager_zhang_sigma: float = 0.9
    line_search_hager_zhang_psi_0: float = 0.01
    line_search_hager_zhang_psi_1: float = 0.1
    line_search_hager_zhang_psi_2: float = 2.0
    line_search_hager_zhang_eps: float = 1e-6
    line_search_hager_zhang_theta: float = 0.5
    line_search_hager_zhang_gamma: float = 0.66
    line_search_hager_zhang_rho: float = 5.0
    line_search_hager_zhang_eps_use_grad_norm: bool = True
    line_search_hager_zhang_eps_grad_norm_factor: float = 1e-2

    # Basinhopping
    basinhopping_niter: int = 20
    basinhopping_T: float = 0.001
    basinhopping_niter_success: int = 5

    # Spiral PEPS
    spiral_wavevector_type: Wavevector_Type = Wavevector_Type.TWO_PI_POSITIVE_ONLY

    # Slurm
    slurm_restart_mode: Slurm_Restart_Mode = Slurm_Restart_Mode.WRITE_NEED_RESTART_FILE

    # Logging configuration
    log_level_global: LogLevel = LogLevel.INFO
    log_level_optimizer: LogLevel = LogLevel.INFO
    log_level_ctmrg: LogLevel = LogLevel.INFO
    log_level_line_search: LogLevel = LogLevel.INFO
    log_level_expectation: LogLevel = LogLevel.INFO
    log_to_console: bool = True
    log_to_file: bool = False
    log_file: str = "varipeps.log"
    log_tqdm: bool = False  #: Enable tqdm-based console logging

    def update(self, name: str, value: Any) -> None:
        """
        Set config option ``name`` to ``value``.

        Equivalent to ``setattr(config, name, value)``, i.e. the same
        validation and conversion as a plain attribute assignment applies.
        """
        setattr(self, name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Validate, convert if needed, and store a config option.

        Raises:
          KeyError: If ``name`` is not a config option.
          TypeError: If ``value`` cannot be converted to the option's type.
          ValueError: If an integer or string does not identify a member of
            the option's enum type.
        """
        try:
            spec = self.__dataclass_fields__[name]
        except KeyError as e:
            raise KeyError(f"Unknown config option '{name}'.") from e

        # Exact type match needs no conversion; note that this is an identity
        # check, so e.g. bool is *not* accepted for an int option.
        if type(value) is not spec.type:
            value = self._coerce_value(name, spec.type, value)

        super().__setattr__(name, value)

    @staticmethod
    def _coerce_value(name: str, expected: Any, value: Any) -> Any:
        """
        Convert ``value`` to the option type ``expected`` or raise.

        Only called when ``type(value) is not expected``.
        """

        def mismatch() -> TypeError:
            return TypeError(
                f"Type mismatch for option '{name}', got '{type(value)}', expected '{expected}'."
            )

        def scalar_array_of(*kinds: Any) -> bool:
            # True for size-1 array-like values (anything exposing dtype and
            # size, e.g. NumPy/JAX arrays or scalars) whose dtype is a subtype
            # of one of ``kinds``.
            return (
                hasattr(value, "dtype")
                and any(np.issubdtype(value.dtype, kind) for kind in kinds)
                and value.size == 1
            )

        def unwrap(v: Any) -> Any:
            # Extract the single element of a size-1 array with ndim > 0.
            return v.reshape(-1)[0] if getattr(v, "ndim", 0) > 0 else v

        if expected is float:
            if type(value) is int:
                return float(value)
            if scalar_array_of(np.floating, np.integer):
                return float(unwrap(value))
            raise mismatch()

        if expected is int:
            if scalar_array_of(np.integer):
                return int(unwrap(value))
            raise mismatch()

        if expected is bool:
            if scalar_array_of(np.bool_):
                return bool(unwrap(value))
            raise mismatch()

        if not (isinstance(expected, type) and issubclass(expected, Enum)):
            raise mismatch()

        # Enum-typed option.
        if isinstance(value, expected):
            return value
        if isinstance(value, (bool, Enum)):
            # Members of a *different* IntEnum are ints as well; converting
            # them by value would silently map e.g. Line_Search_Methods.WOLFE
            # onto Optimizing_Methods.BFGS. Python bools are rejected for
            # consistency with NumPy bool arrays, which never pass the
            # integer dtype check below.
            raise mismatch()
        if isinstance(value, int) or scalar_array_of(np.integer):
            return expected(int(unwrap(value)))
        if isinstance(value, str):
            # Member name first (exact, then upper-cased so that e.g.
            # "l_bfgs" works), then a numeric string.
            for key in (value, value.upper()):
                if key in expected.__members__:
                    return expected[key]
            try:
                number = int(value)
            except ValueError as e:
                raise ValueError(
                    f"Invalid value '{value}' for option '{name}', expected one "
                    f"of {list(expected.__members__)} or an integer value."
                ) from e
            return expected(number)
        raise mismatch()

    def update_from_config_dict(self, new_config):
        """
        Set every option to ``new_config[name]``.

        Raises :obj:`KeyError` if an option is missing from the mapping; keys
        that are not config options are ignored.
        """
        for k in self.__dataclass_fields__:
            setattr(self, k, new_config[k])

    def update_from_config_object(self, new_config):
        """Copy every option from ``new_config`` via :func:`getattr`."""
        for k in self.__dataclass_fields__:
            setattr(self, k, getattr(new_config, k))

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
        """
        Flatten for JAX: no leaves, all options go into the static aux data.

        The aux data holds a tuple of ``(name, value)`` pairs instead of a dict
        because JAX expects pytree aux data to be hashable; validated option
        values (bool/int/float/str/enum members) are all hashable.
        """
        aux_data = (
            tuple((name, getattr(self, name)) for name in self.__dataclass_fields__),
        )
        return ((), aux_data)

    @classmethod
    def tree_unflatten(
        cls: Type[T_VariPEPS_Config],
        aux_data: Tuple[Any, ...],
        children: Tuple[Any, ...],
    ) -> T_VariPEPS_Config:
        """Rebuild a config from the aux data produced by :meth:`tree_flatten`."""
        (options,) = aux_data
        # dict() accepts both the (name, value) pairs produced by tree_flatten
        # and a plain mapping (the aux data format of earlier versions).
        return cls(**dict(options))
# Shared configuration instance (all options at their defaults).
# ConfigModuleWrapper forwards attribute reads/writes that are not one of its
# slots to this object.
config = VariPEPS_Config()
class ConfigModuleWrapper:
    """
    Module-like proxy around the configuration.

    The names listed in ``__slots__`` (the enum types, the config class and
    the :obj:`config` instance) are copied from this module's globals on
    construction and are write-protected afterwards. Every other attribute
    that does not start with ``__`` is read from and written to
    :obj:`config`, so assignments go through the config's own validation.
    """

    __slots__ = {
        "Optimizing_Methods",
        "Line_Search_Methods",
        "Projector_Method",
        "Wavevector_Type",
        "Slurm_Restart_Mode",
        "LogLevel",
        "VariPEPS_Config",
        "config",
    }

    def __init__(self):
        for e in self.__slots__:
            setattr(self, e, globals()[e])

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails. Dunder names and unset slots
        # must raise AttributeError (hasattr() in __setattr__ relies on this
        # during __init__). object defines no __getattr__ to delegate to, so
        # the error is raised explicitly with a meaningful message.
        if name.startswith("__") or name in self.__slots__:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        return getattr(self.config, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if not name.startswith("__") and name not in self.__slots__:
            # Regular option: forward to the wrapped config object.
            setattr(self.config, name, value)
        elif not hasattr(self, name):
            # First assignment of a slot (from __init__).
            super().__setattr__(name, value)
        else:
            raise AttributeError(f"Attribute '{name}' is write-protected.")
# Single wrapper instance exposing the enums, VariPEPS_Config and `config` as
# write-protected attributes and forwarding all other attributes to `config`.
# NOTE(review): presumably installed in sys.modules elsewhere so that
# `varipeps.config.<option> = ...` goes through it -- confirm.
wrapper = ConfigModuleWrapper()