-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathroutine.py
More file actions
1066 lines (943 loc) · 36.1 KB
/
routine.py
File metadata and controls
1066 lines (943 loc) · 36.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from functools import partial
import enum
import jax.numpy as jnp
from jax import jit, custom_vjp, vjp, tree_util
from jax.lax import cond, while_loop
import jax.debug as jdebug
import logging
import time
import jax
logger = logging.getLogger("varipeps.ctmrg")
from varipeps import varipeps_config, varipeps_global_state
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
from varipeps.utils.debug_print import debug_print
from .absorption import do_absorption_step, do_absorption_step_split_transfer
from .triangular_absorption import do_absorption_step_triangular
from typing import Sequence, Tuple, List, Optional
@enum.unique
class CTM_Enum(enum.IntEnum):
    """
    Integer labels identifying the individual CTM environment tensors.

    Members are numbered consecutively starting at 1 (the same numbering
    ``enum.auto()`` would produce for this declaration order).
    """

    # Corners of the square-lattice environment.
    C1 = 1
    C2 = 2
    C3 = 3
    C4 = 4
    # Transfer tensors of the square-lattice environment.
    T1 = 5
    T2 = 6
    T3 = 7
    T4 = 8
    # Split (ket/bra) transfer tensors.
    T1_ket = 9
    T1_bra = 10
    T2_ket = 11
    T2_bra = 12
    T3_ket = 13
    T3_bra = 14
    T4_ket = 15
    T4_bra = 16
    # Additional corners and transfer tensors of the triangular environment.
    C5 = 17
    C6 = 18
    T1a = 19
    T1b = 20
    T2a = 21
    T2b = 22
    T3a = 23
    T3b = 24
    T4a = 25
    T4b = 26
    T5a = 27
    T5b = 28
    T6a = 29
    T6b = 30
class CTMRGNotConvergedError(Exception):
    """
    Raised when the CTM routine fails to reach convergence.
    """
class CTMRGGradientNotConvergedError(Exception):
    """
    Exception if the custom rule for the gradient of the CTM routine does
    not converge.
    """

    pass
@partial(jit, static_argnums=(2,), inline=True)
def _calc_corner_svds(
    peps_tensors: List[PEPS_Tensor],
    old_corner_svd: jnp.ndarray,
    tensor_shape: Optional[Tuple[int, int, int]],
) -> jnp.ndarray:
    """
    Collect the singular values of the corners ``C1``-``C4`` of every tensor.

    Args:
      peps_tensors (:obj:`list` of :obj:`~varipeps.peps.PEPS_Tensor`):
        Tensors whose corner matrices are decomposed.
      old_corner_svd (:obj:`jax.numpy.ndarray`):
        Template array; only its shape and dtype are used, and only when
        ``tensor_shape`` is ``None``.
      tensor_shape (:obj:`tuple` of :obj:`int` or ``None``):
        Static shape of the result. If given, the result starts as a float64
        zero array of this shape.
    Returns:
      :obj:`jax.numpy.ndarray`:
        Array whose slice ``[ti, ci, :k]`` holds the ``k`` singular values of
        corner ``C{ci+1}`` of tensor ``ti``; all other entries are zero.
    """
    if tensor_shape is not None:
        svd_buffer = jnp.zeros(tensor_shape, dtype=jnp.float64)
    else:
        svd_buffer = jnp.zeros_like(old_corner_svd)

    for tensor_idx, tensor in enumerate(peps_tensors):
        for corner_idx, corner_name in enumerate(("C1", "C2", "C3", "C4")):
            singular_vals = jnp.linalg.svd(
                getattr(tensor, corner_name), full_matrices=False, compute_uv=False
            )
            # Corners may be smaller than the buffer (e.g. before chi is
            # reached); only the leading entries are written.
            svd_buffer = svd_buffer.at[
                tensor_idx, corner_idx, : singular_vals.shape[0]
            ].set(singular_vals, indices_are_sorted=True, unique_indices=True)

    return svd_buffer
@partial(jit, static_argnums=(3,), inline=True)
def _is_element_wise_converged(
    old_peps_tensors: List[PEPS_Tensor],
    new_peps_tensors: List[PEPS_Tensor],
    eps: float,
    split_transfer: bool = False,
) -> Tuple[bool, float, Optional[List[Tuple[int, CTM_Enum, float]]]]:
    """
    Check element-wise convergence between two sets of CTM environment tensors.

    For every tensor index and every environment tensor, the corners
    ``C1``-``C4`` plus either ``T1``-``T4`` or, when ``split_transfer`` is set,
    ``T1_ket``/``T1_bra`` ... ``T4_ket``/``T4_bra``, the absolute difference
    of the overlapping region of the old and new tensor is computed.

    Args:
      old_peps_tensors (:obj:`list` of :obj:`~varipeps.peps.PEPS_Tensor`):
        Tensors of the previous step.
      new_peps_tensors (:obj:`list` of :obj:`~varipeps.peps.PEPS_Tensor`):
        Tensors of the current step, in the same order.
      eps (:obj:`float`):
        Element-wise tolerance.
      split_transfer (:obj:`bool`):
        Static flag selecting the split ket/bra transfer tensors.
    Returns:
      :obj:`tuple`:
        ``(converged, measure, verbose_data)``. ``converged`` is true iff no
        absolute element difference exceeds ``eps``. ``measure`` is the norm
        of the array of per-tensor difference norms. ``verbose_data`` lists
        ``(ti, CTM_Enum member, max |diff|)`` for every compared tensor.
    """
    if split_transfer:
        tensor_names = (
            "C1",
            "C2",
            "C3",
            "C4",
            "T1_ket",
            "T1_bra",
            "T2_ket",
            "T2_bra",
            "T3_ket",
            "T3_bra",
            "T4_ket",
            "T4_bra",
        )
    else:
        tensor_names = ("C1", "C2", "C3", "C4", "T1", "T2", "T3", "T4")

    result = 0
    measure = jnp.zeros(
        (len(old_peps_tensors), len(tensor_names)), dtype=jnp.float64
    )
    verbose_data = []

    for ti in range(len(old_peps_tensors)):
        old_t = old_peps_tensors[ti]
        new_t = new_peps_tensors[ti]
        for ni, name in enumerate(tensor_names):
            old_tensor = getattr(old_t, name)
            new_tensor = getattr(new_t, name)
            # Compare only the overlapping region, since the bond dimensions
            # may differ between the two steps. Both slices are taken from the
            # SAME tensor name, so their shapes always agree.
            diff = jnp.abs(
                new_tensor[tuple(slice(None, s) for s in old_tensor.shape)]
                - old_tensor[tuple(slice(None, s) for s in new_tensor.shape)]
            )
            result += jnp.sum(diff > eps)
            measure = measure.at[ti, ni].set(
                jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
            )
            verbose_data.append((ti, getattr(CTM_Enum, name), jnp.amax(diff)))

    return result == 0, jnp.linalg.norm(measure), verbose_data
@partial(jit, inline=True)
def _is_element_wise_converged_triangular(
    old_peps_tensors: List[PEPS_Tensor],
    new_peps_tensors: List[PEPS_Tensor],
    eps: float,
):
    """
    Element-wise convergence check for the environment of triangular PEPS.

    Compares the corners ``C1``-``C6`` and the transfer tensors
    ``T1a``-``T6b`` (18 tensors per unit-cell tensor) of the old and new step.
    Only the overlapping region along the first four axes is compared.

    Args:
      old_peps_tensors (:obj:`list` of :obj:`~varipeps.peps.PEPS_Tensor`):
        Tensors of the previous step.
      new_peps_tensors (:obj:`list` of :obj:`~varipeps.peps.PEPS_Tensor`):
        Tensors of the current step, in the same order.
      eps (:obj:`float`):
        Element-wise tolerance.
    Returns:
      :obj:`tuple`:
        ``(converged, measure, verbose_data)``: convergence flag (no absolute
        difference above ``eps``), norm of the per-tensor difference norms
        and a list of ``(ti, CTM_Enum member, max |diff|)`` entries.
    """
    tensor_names = (
        "C1",
        "C2",
        "C3",
        "C4",
        "C5",
        "C6",
        "T1a",
        "T1b",
        "T2a",
        "T2b",
        "T3a",
        "T3b",
        "T4a",
        "T4b",
        "T5a",
        "T5b",
        "T6a",
        "T6b",
    )

    n_exceeding = 0
    diff_norms = jnp.zeros(
        (len(old_peps_tensors), len(tensor_names)), dtype=jnp.float64
    )
    report = []

    for t_idx in range(len(old_peps_tensors)):
        prev_t = old_peps_tensors[t_idx]
        curr_t = new_peps_tensors[t_idx]
        for n_idx, name in enumerate(tensor_names):
            prev_shape = getattr(prev_t, name).shape
            curr_shape = getattr(curr_t, name).shape
            curr_part = getattr(curr_t, name)[
                : prev_shape[0], : prev_shape[1], : prev_shape[2], : prev_shape[3]
            ]
            prev_part = getattr(prev_t, name)[
                : curr_shape[0], : curr_shape[1], : curr_shape[2], : curr_shape[3]
            ]
            delta = jnp.abs(curr_part - prev_part)

            n_exceeding += jnp.sum(delta > eps)
            diff_norms = diff_norms.at[t_idx, n_idx].set(
                jnp.linalg.norm(delta), indices_are_sorted=True, unique_indices=True
            )
            report.append((t_idx, getattr(CTM_Enum, name), jnp.amax(delta)))

    return n_exceeding == 0, jnp.linalg.norm(diff_norms), report
def print_verbose(verbose_data, *, ad=False):
    """
    Emit the per-tensor convergence details via ``debug_print``.

    Args:
      verbose_data (:term:`iterable`):
        ``(ti, ctm_enum_value, diff)`` triples as produced by the element-wise
        convergence checks.
      ad (:obj:`bool`):
        Use the "Custom VJP" message prefix instead of the "CTMRG" prefix.
    """
    template = (
        "Custom VJP: Verbose: ti {}, CTM tensor {}, Diff {}"
        if ad
        else "CTMRG: Verbose: ti {}, CTM tensor {}, Diff {}"
    )
    for entry in verbose_data:
        tensor_index, enum_value, difference = entry
        debug_print(template, tensor_index, CTM_Enum(enum_value).name, difference)
@jit
def _ctmrg_body_func(carry):
    """
    Perform a single CTMRG absorption step and evaluate the convergence check.

    Args:
      carry (:obj:`tuple`):
        ``(w_tensors, w_unitcell_last_step, converged, last_corner_svd, eps,
        count, elementwise_conv, norm_smallest_S, state, config)``. The
        incoming ``converged`` and ``norm_smallest_S`` values are not read;
        they are replaced by the results of this step.
    Returns:
      :obj:`tuple`:
        A carry with the same layout, containing the absorbed unitcell, the
        new convergence flag, the corner singular values to compare against
        in the next step, ``count + 1`` and this step's ``norm_smallest_S``.
    """
    (
        w_tensors,
        w_unitcell_last_step,
        converged,
        last_corner_svd,
        eps,
        count,
        elementwise_conv,
        norm_smallest_S,
        state,
        config,
    ) = carry

    # Dispatch to the absorption variant matching the unitcell type.
    if w_unitcell_last_step.is_triangular_peps():
        w_unitcell, norm_smallest_S = do_absorption_step_triangular(
            w_tensors, w_unitcell_last_step, config, state
        )
    elif w_unitcell_last_step.is_split_transfer():
        w_unitcell, norm_smallest_S = do_absorption_step_split_transfer(
            w_tensors, w_unitcell_last_step, config, state
        )
    else:
        w_unitcell, norm_smallest_S = do_absorption_step(
            w_tensors, w_unitcell_last_step, config, state
        )

    # Branch used when elementwise convergence is requested: compares every
    # environment tensor of the old and new unitcell. The corner reference
    # is passed through unchanged.
    def elementwise_func(old, new, old_corner, conv_eps, config):
        if w_unitcell_last_step.is_triangular_peps():
            converged, measure, verbose_data = _is_element_wise_converged_triangular(
                old,
                new,
                conv_eps,
            )
            return converged, measure, verbose_data, old_corner
        converged, measure, verbose_data = _is_element_wise_converged(
            old,
            new,
            conv_eps,
            split_transfer=w_unitcell.is_split_transfer(),
        )
        return converged, measure, verbose_data, old_corner

    # Branch used for corner-singular-value convergence. ``cond`` traces both
    # branches and needs identical output structures, so ``verbose_data`` is
    # filled with zero placeholders of the same length the elementwise checks
    # produce (18 / 12 / 8 entries per unique tensor).
    def corner_svd_func(old, new, old_corner, conv_eps, config):
        if w_unitcell_last_step.is_triangular_peps():
            verbose_data = (
                [(jnp.array(0), jnp.array(0), jnp.array(0.0))] * 18 * len(w_tensors)
            )
        elif w_unitcell_last_step.is_split_transfer():
            verbose_data = (
                [(jnp.array(0), jnp.array(0), jnp.array(0.0))] * 12 * len(w_tensors)
            )
        else:
            verbose_data = (
                [(jnp.array(0), jnp.array(0), jnp.array(0.0))] * 8 * len(w_tensors)
            )
        # No reference singular values available (calc_ctmrg_env passes None
        # in elementwise mode): report "not converged" with a NaN measure.
        if old_corner is None:
            return (
                False,
                jnp.nan,
                verbose_data,
                old_corner,
            )
        corner_svd = _calc_corner_svds(new, old_corner, None)
        measure = jnp.linalg.norm(corner_svd - old_corner)
        converged = measure < conv_eps
        return (
            converged,
            measure,
            verbose_data,
            corner_svd,
        )

    # ``elementwise_conv`` selects the branch at runtime.
    converged, measure, verbose_data, corner_svd = cond(
        elementwise_conv,
        elementwise_func,
        corner_svd_func,
        w_unitcell_last_step.get_unique_tensors(),
        w_unitcell.get_unique_tensors(),
        last_corner_svd,
        eps,
        config,
    )

    # NOTE: these Python-level checks run while tracing under ``jit``, so the
    # log level / config flags in effect at compile time decide whether the
    # host callbacks are emitted.
    if logger.isEnabledFor(logging.DEBUG):
        jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
        if config.ctmrg_verbose_output:
            jax.debug.callback(print_verbose, verbose_data, ordered=True)

    count += 1

    return (
        w_tensors,
        w_unitcell,
        converged,
        corner_svd,
        eps,
        count,
        elementwise_conv,
        norm_smallest_S,
        state,
        config,
    )
@jit
def _ctmrg_while_wrapper(start_carry):
    """
    Repeat :obj:`_ctmrg_body_func` until convergence or the step limit.

    Args:
      start_carry (:obj:`tuple`):
        Initial carry in the 10-element layout expected by
        :obj:`_ctmrg_body_func`.
    Returns:
      :obj:`tuple`:
        ``(working_unitcell, converged, end_count, norm_smallest_S)`` taken
        from the final carry.
    """

    def keep_iterating(state_tuple):
        # Carry layout: index 2 is the convergence flag, index 5 the step
        # counter and index 9 the config object.
        not_done = jnp.logical_not(state_tuple[2])
        below_limit = state_tuple[5] < state_tuple[9].ctmrg_max_steps
        return not_done & below_limit

    final_carry = while_loop(keep_iterating, _ctmrg_body_func, start_carry)
    return final_carry[1], final_carry[2], final_carry[5], final_carry[7]
def calc_ctmrg_env(
    peps_tensors: Sequence[jnp.ndarray],
    unitcell: PEPS_Unit_Cell,
    *,
    eps: Optional[float] = None,
    enforce_elementwise_convergence: Optional[bool] = None,
    _return_truncation_eps: bool = False,
) -> tuple:
    """
    Calculate the new converged CTMRG tensors for the unit cell. The function
    updates the environment of all iPEPS tensors in the unit cell according to
    the periodic structure.

    Depending on the config, the routine may heuristically increase or
    decrease the environment bond dimension chi and/or increase the SVD
    truncation eps and restart. If the final attempt did not converge, the
    converged result with the largest chi seen so far is used instead.

    Args:
      peps_tensors (:term:`sequence` of :obj:`jax.numpy.ndarray`):
        The sequence of unique PEPS tensors the unitcell consists of.
      unitcell (:obj:`~varipeps.peps.PEPS_Unit_Cell`):
        The unitcell to work on.
    Keyword args:
      eps (:obj:`float`):
        The convergence criterion. Defaults to
        ``varipeps_config.ctmrg_convergence_eps``.
      enforce_elementwise_convergence (:obj:`bool`):
        Enforce elementwise convergence of the CTM tensors instead of only
        convergence of the singular values of the corners. Defaults to
        ``varipeps_config.ctmrg_enforce_elementwise_convergence``.
      _return_truncation_eps (:obj:`bool`):
        Internal flag: additionally return the effective SVD truncation eps
        (``None`` if it was never increased).
    Returns:
      :obj:`tuple`:
        ``(unitcell, norm_smallest_S)``, where ``unitcell`` is a new instance
        of the unitcell with the updated CTMRG tensors and ``norm_smallest_S``
        is the value reported by the last absorption step. If
        ``_return_truncation_eps`` is set, ``(unitcell, truncation_eps,
        norm_smallest_S)`` is returned instead.
    Raises:
      CTMRGNotConvergedError:
        If ``varipeps_config.ctmrg_fail_if_not_converged`` is set, the step
        limit was hit and no converged fallback result exists.
    """
    eps = eps if eps is not None else varipeps_config.ctmrg_convergence_eps
    enforce_elementwise_convergence = (
        enforce_elementwise_convergence
        if enforce_elementwise_convergence is not None
        else varipeps_config.ctmrg_enforce_elementwise_convergence
    )

    # Reference corner singular values for the corner-SVD convergence mode;
    # stays None in elementwise mode.
    init_corner_singular_vals = None
    if enforce_elementwise_convergence:
        # NOTE(review): last_step_tensors is never read below.
        last_step_tensors = unitcell.get_unique_tensors()
    else:
        shape_corner_svd = (
            unitcell.get_len_unique_tensors(),
            4,
            unitcell[0, 0][0][0].chi,
        )
        init_corner_singular_vals = _calc_corner_svds(
            unitcell.get_unique_tensors(), None, shape_corner_svd
        )

    initial_unitcell = unitcell
    working_unitcell = unitcell
    varipeps_global_state.ctmrg_effective_truncation_eps = None

    norm_smallest_S = jnp.nan
    already_tried_chi = {working_unitcell[0, 0][0][0].chi}

    # Best converged result so far (largest chi), used as fallback at the end.
    best_chi = 0
    best_result = None
    best_norm_smallest_S = None
    best_truncation_eps = None
    have_been_increased = False

    # Each pass runs CTMRG once; the heuristics below may change chi or the
    # truncation eps and restart the pass via ``continue``.
    while True:
        tmp_count = 0
        t0 = time.perf_counter()
        corner_singular_vals = None

        # Run single steps while the environment tensors do not yet have
        # their target shapes (chi / interlayer_chi). Presumably this is
        # because lax.while_loop below needs a carry of fixed shape.
        while tmp_count < varipeps_config.ctmrg_max_steps and (
            (
                not working_unitcell.is_triangular_peps()
                and any(
                    getattr(i, j).shape[0] != i.chi or getattr(i, j).shape[1] != i.chi
                    for i in working_unitcell.get_unique_tensors()
                    for j in ("C1", "C2", "C3", "C4")
                )
            )
            or (
                working_unitcell.is_split_transfer()
                and any(
                    getattr(i, j).shape[0] != i.interlayer_chi
                    for i in working_unitcell.get_unique_tensors()
                    for j in ("T1_bra", "T2_ket", "T3_bra", "T4_ket")
                )
            )
            or (
                working_unitcell.is_triangular_peps()
                and any(
                    getattr(i, j).shape[0] != i.chi or getattr(i, j).shape[3] != i.chi
                    for i in working_unitcell.get_unique_tensors()
                    for j in (
                        "C1",
                        "C2",
                        "C3",
                        "C4",
                        "C5",
                        "C6",
                        "T1a",
                        "T1b",
                        "T2a",
                        "T2b",
                        "T3a",
                        "T3b",
                        "T4a",
                        "T4b",
                        "T5a",
                        "T5b",
                        "T6a",
                        "T6b",
                    )
                )
            )
        ):
            (
                _,
                working_unitcell,
                _,
                corner_singular_vals,
                _,
                tmp_count,
                _,
                norm_smallest_S,
                _,
                _,
            ) = _ctmrg_body_func(
                (
                    peps_tensors,
                    working_unitcell,
                    False,
                    init_corner_singular_vals,
                    eps,
                    tmp_count,
                    enforce_elementwise_convergence,
                    jnp.inf,
                    varipeps_global_state,
                    varipeps_config,
                )
            )

        # Shapes are stable now: iterate inside the jitted while loop,
        # continuing the step count of the warm-up phase.
        if tmp_count < varipeps_config.ctmrg_max_steps:
            working_unitcell, converged, end_count, norm_smallest_S = (
                _ctmrg_while_wrapper(
                    (
                        peps_tensors,
                        working_unitcell,
                        False,
                        (
                            corner_singular_vals
                            if corner_singular_vals is not None
                            else init_corner_singular_vals
                        ),
                        eps,
                        tmp_count,
                        enforce_elementwise_convergence,
                        jnp.inf,
                        varipeps_global_state,
                        varipeps_config,
                    )
                )
            )
        else:
            converged = False
            end_count = tmp_count

        if not converged and logger.isEnabledFor(logging.WARNING):
            logger.warning(
                "CTMRG: ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
                time.perf_counter() - t0, end_count, norm_smallest_S
            )
        elif logger.isEnabledFor(logging.INFO):
            logger.info(
                "CTMRG: ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
                time.perf_counter() - t0, end_count, norm_smallest_S
            )

        # Remember the converged result with the largest chi as fallback.
        if converged and (
            working_unitcell[0, 0][0][0].chi > best_chi or best_result is None
        ):
            best_chi = working_unitcell[0, 0][0][0].chi
            best_result = working_unitcell
            best_norm_smallest_S = norm_smallest_S
            best_truncation_eps = varipeps_global_state.ctmrg_effective_truncation_eps

        current_truncation_eps = (
            varipeps_config.ctmrg_truncation_eps
            if varipeps_global_state.ctmrg_effective_truncation_eps is None
            else varipeps_global_state.ctmrg_effective_truncation_eps
        )

        # Heuristic: a large smallest-SVD norm suggests chi is too small.
        # NOTE(review): init_corner_singular_vals is computed once for the
        # initial chi; in corner-SVD mode after change_chi its last axis may
        # no longer match the new corner sizes. Confirm this path is safe.
        if (
            varipeps_config.ctmrg_heuristic_increase_chi
            and norm_smallest_S > varipeps_config.ctmrg_heuristic_increase_chi_threshold
            and working_unitcell[0, 0][0][0].chi < working_unitcell[0, 0][0][0].max_chi
        ):
            new_chi = (
                working_unitcell[0, 0][0][0].chi
                + varipeps_config.ctmrg_heuristic_increase_chi_step_size
            )
            if new_chi > working_unitcell[0, 0][0][0].max_chi:
                new_chi = working_unitcell[0, 0][0][0].max_chi

            # Each chi is tried at most once to avoid oscillating.
            if not new_chi in already_tried_chi:
                working_unitcell = working_unitcell.change_chi(new_chi)
                initial_unitcell = initial_unitcell.change_chi(new_chi)

                if logger.isEnabledFor(logging.INFO):
                    logger.info(
                        "Increasing chi to %d since smallest SVD Norm was %.3e.",
                        new_chi,
                        norm_smallest_S,
                    )

                already_tried_chi.add(new_chi)

                have_been_increased = True

                continue
        # Heuristic: shrink chi if the spectrum is already below the
        # truncation eps, or if the run failed without a prior increase.
        elif varipeps_config.ctmrg_heuristic_decrease_chi and (
            (
                norm_smallest_S < current_truncation_eps
                and working_unitcell[0, 0][0][0].chi > 2
            )
            or (
                not converged
                and not have_been_increased
                and norm_smallest_S
                < varipeps_config.ctmrg_heuristic_increase_chi_threshold
            )
        ):
            new_chi = (
                working_unitcell[0, 0][0][0].chi
                - varipeps_config.ctmrg_heuristic_decrease_chi_step_size
            )
            if new_chi < 2:
                new_chi = 2

            if not new_chi in already_tried_chi:
                working_unitcell = working_unitcell.change_chi(new_chi)

                if logger.isEnabledFor(logging.INFO):
                    logger.info(
                        "Decreasing chi to %d since smallest SVD Norm was %.3e or routine did not converge.",
                        new_chi,
                        norm_smallest_S,
                    )

                already_tried_chi.add(new_chi)

                continue

        # Last resort when the step limit was hit: loosen the SVD truncation
        # eps (up to the configured maximum) and restart from the initial
        # unitcell with a fresh set of tried chi values.
        if (
            varipeps_config.ctmrg_increase_truncation_eps
            and end_count == varipeps_config.ctmrg_max_steps
            and not converged
        ):
            new_truncation_eps = (
                current_truncation_eps
                * varipeps_config.ctmrg_increase_truncation_eps_factor
            )
            if (
                new_truncation_eps
                <= varipeps_config.ctmrg_increase_truncation_eps_max_value
            ):
                if logger.isEnabledFor(logging.INFO):
                    logger.info(
                        "Increasing SVD truncation eps to %.1e.",
                        new_truncation_eps,
                    )
                varipeps_global_state.ctmrg_effective_truncation_eps = (
                    new_truncation_eps
                )
                working_unitcell = initial_unitcell
                already_tried_chi = {working_unitcell[0, 0][0][0].chi}
                continue

        break

    if _return_truncation_eps:
        last_truncation_eps = varipeps_global_state.ctmrg_effective_truncation_eps
    # Reset the global override so it does not leak into later calls.
    varipeps_global_state.ctmrg_effective_truncation_eps = None

    # Fall back to the best converged result if the final attempt failed.
    if not converged and best_result is not None:
        working_unitcell = best_result
        norm_smallest_S = best_norm_smallest_S
        converged = True
        last_truncation_eps = best_truncation_eps

    if (
        varipeps_config.ctmrg_fail_if_not_converged
        and end_count == varipeps_config.ctmrg_max_steps
        and not converged
    ):
        raise CTMRGNotConvergedError

    if _return_truncation_eps:
        return working_unitcell, last_truncation_eps, norm_smallest_S

    return working_unitcell, norm_smallest_S
@custom_vjp
def calc_ctmrg_env_custom_rule(
    peps_tensors: Sequence[jnp.ndarray],
    unitcell: PEPS_Unit_Cell,
    _return_truncation_eps: bool = False,
) -> tuple:
    """
    Wrapper function of :obj:`~varipeps.ctmrg.routine.calc_ctmrg_env` which
    enables the use of the custom VJP for the calculation of the gradient.

    Elementwise convergence of the CTM tensors is always enforced here.

    Args:
      peps_tensors (:term:`sequence` of :obj:`jax.numpy.ndarray`):
        The sequence of unique PEPS tensors the unitcell consists of.
      unitcell (:obj:`~varipeps.peps.PEPS_Unit_Cell`):
        The unitcell to work on.
      _return_truncation_eps (:obj:`bool`):
        Internal flag forwarded to
        :obj:`~varipeps.ctmrg.routine.calc_ctmrg_env`.
    Returns:
      :obj:`tuple`:
        ``(unitcell, norm_smallest_S)`` with a new unitcell instance holding
        the updated converged CTMRG tensors, or
        ``(unitcell, truncation_eps, norm_smallest_S)`` if
        ``_return_truncation_eps`` is set.
    """
    return calc_ctmrg_env(
        peps_tensors,
        unitcell,
        enforce_elementwise_convergence=True,
        _return_truncation_eps=_return_truncation_eps,
    )
def calc_ctmrg_env_fwd(
    peps_tensors: Sequence[jnp.ndarray],
    unitcell: PEPS_Unit_Cell,
    _return_truncation_eps: bool = False,
) -> Tuple[
    Tuple[PEPS_Unit_Cell, jnp.ndarray],
    Tuple[Sequence[jnp.ndarray], PEPS_Unit_Cell, PEPS_Unit_Cell, Optional[float]],
]:
    """
    Internal helper function of custom VJP to calculate the values in
    the forward sweep.

    The CTMRG routine is always run with ``_return_truncation_eps=True`` so
    that the effective truncation eps can be stored as a residual.

    NOTE(review): the ``_return_truncation_eps`` argument itself is ignored,
    and the primal output is always ``(new_unitcell, norm_smallest_S)``. This
    matches the primal function only when it is called with the default
    ``False``.

    Returns:
      :obj:`tuple`:
        ``((new_unitcell, norm_smallest_S), residuals)``, where ``residuals``
        is ``(peps_tensors, new_unitcell, unitcell, last_truncation_eps)``.
    """
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Custom VJP: Starting forward CTMRG calculation.")

    new_unitcell, last_truncation_eps, norm_smallest_S = calc_ctmrg_env_custom_rule(
        peps_tensors, unitcell, _return_truncation_eps=True
    )

    return (new_unitcell, norm_smallest_S), (
        peps_tensors,
        new_unitcell,
        unitcell,
        last_truncation_eps,
    )
def _ctmrg_rev_while_body(carry):
    """
    One iteration of the fixed-point loop for the adjoint of the CTMRG environment.

    Applies the environment VJP ``vjp_env`` to the current adjoint estimate,
    adds ``initial_bar`` tensor-wise and checks element-wise convergence
    against the previous estimate.

    Args:
      carry (:obj:`tuple`):
        ``(vjp_env, initial_bar, bar_fixed_point_last_step, converged, count,
        config, state)``. The incoming ``converged`` value is not read; it is
        replaced by this step's result.
    Returns:
      :obj:`tuple`:
        The carry in the same layout with the new adjoint estimate, the new
        convergence flag and ``count + 1``.
    """
    (
        vjp_env,
        initial_bar,
        bar_fixed_point_last_step,
        converged,
        count,
        config,
        state,
    ) = carry

    # The absorption step returns (unitcell, norm_smallest_S); only the
    # unitcell receives a cotangent, the scalar output gets zero.
    new_env_bar = vjp_env((bar_fixed_point_last_step, jnp.array(0, dtype=jnp.float64)))[
        0
    ]

    # New estimate = initial cotangent + VJP of the previous estimate, summed
    # tensor by tensor. ``checks=False`` presumably skips consistency checks
    # in PEPS_Tensor.__add__. TODO confirm.
    bar_fixed_point = bar_fixed_point_last_step.replace_unique_tensors(
        [
            t_old.__add__(t_new, checks=False)
            for t_old, t_new in zip(
                initial_bar.get_unique_tensors(),
                new_env_bar.get_unique_tensors(),
                strict=True,
            )
        ]
    )

    if bar_fixed_point_last_step.is_triangular_peps():
        converged, measure, verbose_data = _is_element_wise_converged_triangular(
            bar_fixed_point_last_step.get_unique_tensors(),
            bar_fixed_point.get_unique_tensors(),
            config.ad_custom_convergence_eps,
        )
    else:
        converged, measure, verbose_data = _is_element_wise_converged(
            bar_fixed_point_last_step.get_unique_tensors(),
            bar_fixed_point.get_unique_tensors(),
            config.ad_custom_convergence_eps,
            split_transfer=bar_fixed_point.is_split_transfer(),
        )

    count += 1

    # NOTE: these Python-level checks run at trace time, so the settings in
    # effect when the loop is traced decide whether callbacks are emitted.
    if logger.isEnabledFor(logging.DEBUG):
        jax.debug.callback(lambda cnt, msr: logger.debug(f"Custom VJP: Step {cnt}: {msr}"), count, measure, ordered=True)
        if config.ad_custom_verbose_output:
            jax.debug.callback(print_verbose, verbose_data, ordered=True, ad=True)

    return vjp_env, initial_bar, bar_fixed_point, converged, count, config, state
@jit
def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, state):
if new_unitcell.is_triangular_peps():
_, vjp_peps_tensors = vjp(
lambda t: do_absorption_step_triangular(t, new_unitcell, config, state),
peps_tensors,
)
vjp_env = tree_util.Partial(
vjp(
lambda u: do_absorption_step_triangular(peps_tensors, u, config, state),
new_unitcell,
)[1]
)
elif new_unitcell.is_split_transfer():
_, vjp_peps_tensors = vjp(
lambda t: do_absorption_step_split_transfer(t, new_unitcell, config, state),
peps_tensors,
)
vjp_env = tree_util.Partial(
vjp(
lambda u: do_absorption_step_split_transfer(
peps_tensors, u, config, state
),
new_unitcell,
)[1]
)
else:
_, vjp_peps_tensors = vjp(
lambda t: do_absorption_step(t, new_unitcell, config, state), peps_tensors
)
vjp_env = tree_util.Partial(
vjp(
lambda u: do_absorption_step(peps_tensors, u, config, state),