Skip to content

Commit 2cbf63a

Browse files
[SYSTEMDS-3259] Implementation of the shampoo optimizer (initial prototype)
Add Shampoo optimizer support for neural network training in SystemDS. Includes full-matrix and diagonal preconditioning, momentum updates, and a heuristic variant with delayed preconditioner updates / infrequent root recomputation. Extend the existing NN training scripts with tests to validate correctness and convergence. Tests: src/test/scripts/applications/nn/component/shampoo_test.dml src/test/scripts/applications/nn/component/shampoo_test2.dml src/test/scripts/applications/nn/component/shampoo_test.py Experiments in staging/shampoo_optimizer
1 parent 3f841b7 commit 2cbf63a

29 files changed

Lines changed: 3113 additions & 0 deletions

scripts/nn/optim/shampoo.dml

Lines changed: 499 additions & 0 deletions
Large diffs are not rendered by default.

scripts/staging/shampoo_optimizer/diagram_creation.ipynb

Lines changed: 482 additions & 0 deletions
Large diffs are not rendered by default.
63.8 KB
Loading
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
1.0,2.0366285569366878,0.26598543709489775,1.9111378509607304,0.328
2+
2.0,1.863508219437754,0.3412155143759349,1.8029196566374601,0.3632
3+
3.0,1.7781075000737478,0.37343204670101376,1.7365597405952091,0.3826
4+
4.0,1.7178380710578425,0.3947648329732425,1.6806893259419955,0.4002
5+
5.0,1.669877784216493,0.41048020192787105,1.6390592825854502,0.4138
6+
6.0,1.6350585694820974,0.42428172261924546,1.6146616228664268,0.4242
7+
7.0,1.6090974892669985,0.4327987369120824,1.5910762546497168,0.4314
8+
8.0,1.587529581892735,0.44248223782615925,1.5692068211402614,0.4374
9+
9.0,1.5657862599669354,0.4502518904769819,1.5481026555575144,0.4452
10+
10.0,1.5430813878099718,0.45822149742396545,1.5275131045309027,0.4538
11+
11.0,1.5199649145732006,0.46613397457204586,1.506689447582152,0.4618
12+
12.0,1.4982848263026722,0.4736133039720792,1.488505237593026,0.4718
13+
13.0,1.4789747560435527,0.48032605534319434,1.4727941864822314,0.4766
14+
14.0,1.4621960038234303,0.4855388898121987,1.4603915792006466,0.4802
15+
15.0,1.4475009074972,0.4910233505068971,1.4511544780956973,0.483
16+
16.0,1.4349582625238917,0.49627929200598303,1.4428575600357896,0.485
17+
17.0,1.423895901496466,0.4997927746385241,1.4336665640797934,0.4904
18+
18.0,1.414059771299443,0.5039585757021772,1.4282307999663446,0.4916
19+
19.0,1.404377145998948,0.5074814068472661,1.4222616630433387,0.4968
20+
20.0,1.3953288419933125,0.5116233172677415,1.4168349642383111,0.499
21+
21.0,1.386648730995609,0.5150796701013794,1.4117382901756448,0.5014
22+
22.0,1.37796944077023,0.5181361143426957,1.4065785507454347,0.5026
23+
23.0,1.3696668695223018,0.5211639936845605,1.402195989771859,0.505
24+
24.0,1.3616434672260533,0.5245060869203922,1.3972141118943933,0.5052
25+
25.0,1.3542630371676898,0.5273719253780954,1.3935068999222704,0.5078
26+
26.0,1.3473694290558793,0.5303141100216054,1.3907051889091704,0.5102
27+
27.0,1.3402073426222127,0.5322232840285857,1.3892659032226624,0.508
28+
28.0,1.3336872003973426,0.5347987992354993,1.3865499433195474,0.506
29+
29.0,1.327384157679365,0.537255380588333,1.3873974009402807,0.5034
30+
30.0,1.3216604201787392,0.5391406639521356,1.3836836288514893,0.5054
31+
31.0,1.3161280598647895,0.5419685889978394,1.3832381357516508,0.507
32+
32.0,1.3105416127136869,0.5436824829649327,1.3800900635777025,0.5088
33+
33.0,1.3047664781705213,0.5450535981386072,1.3801428125459738,0.5088
34+
34.0,1.2996126109982393,0.5470817059996675,1.3776929353987932,0.5116
35+
35.0,1.293823188412723,0.5487145795246802,1.375557900607078,0.514
36+
36.0,1.2884030433697722,0.5512854204753199,1.3705886811509542,0.5142
37+
37.0,1.2831253551292954,0.552880380588333,1.3683649044462805,0.5134
38+
38.0,1.2783286448400695,0.553456352833638,1.3651622891751962,0.515
39+
39.0,1.2733306632853874,0.5543226483297324,1.3643001109138162,0.5164
40+
40.0,1.268891600263991,0.5557508933023101,1.3617116557796567,0.5192
41+
41.0,1.264266710657484,0.557293397872694,1.3587595770290883,0.5178
42+
42.0,1.2596929659095248,0.5590452052517866,1.3554118331969724,0.519
43+
43.0,1.2553917233114704,0.560554470666445,1.3546034200212493,0.5224
44+
44.0,1.2516899362081932,0.5619302600963936,1.353135177682238,0.5236
45+
45.0,1.2473519464193588,0.5630967467176333,1.3521054237005408,0.5234
46+
46.0,1.2433013995256341,0.5652105492770483,1.3513343518829966,0.5234
47+
47.0,1.239448787494256,0.5670148121987701,1.3487089014064364,0.5268
48+
48.0,1.2354618536988764,0.5683812531161708,1.3461108297962203,0.5266
49+
49.0,1.2318261814835512,0.5694667192953299,1.3446704091638717,0.5274
50+
50.0,1.2282065364187986,0.5719851047033405,1.344586378891869,0.5296
51+
51.0,1.224133646245753,0.5726992271896294,1.3432699487117177,0.5304
52+
52.0,1.220546875066253,0.5752082640850922,1.3426149725440661,0.5308
53+
53.0,1.217120690696172,0.5763555343194283,1.3426393843545892,0.5318
54+
54.0,1.213568203392273,0.577783779292006,1.3404433750471476,0.5328
55+
55.0,1.2103219298678027,0.5786931818181817,1.3395115771735926,0.5354
56+
56.0,1.206984211848854,0.5799261467508725,1.3372971072601654,0.5356
57+
57.0,1.2037463317947639,0.5816639313611434,1.3349564292855296,0.5356
58+
58.0,1.2002898800179351,0.5821495346518198,1.334121740633766,0.535
59+
59.0,1.196807843220183,0.5832921306298819,1.3337720183945212,0.537
60+
60.0,1.193598377912578,0.5850345894964267,1.3323687113361222,0.5366
61+
61.0,1.1900622506273248,0.5858058417816187,1.331289361658365,0.5352
62+
62.0,1.1871378053459447,0.5867438092072461,1.3306094609968828,0.5378
63+
63.0,1.1841863183111265,0.5882577488781785,1.3306050856582037,0.5376
64+
64.0,1.1808866625784924,0.5883720084759847,1.3297544647820971,0.5372
65+
65.0,1.1784611619616456,0.5891718256606282,1.3293163604280938,0.539
66+
66.0,1.1754922767281581,0.5895146044540468,1.3281268708805458,0.539
67+
67.0,1.1727712033092526,0.591171368622237,1.3243216554218156,0.5398
68+
68.0,1.16986368405737,0.5916948853249128,1.3250558910354457,0.5402
69+
69.0,1.1669975974070492,0.5933277588499252,1.32206820973747,0.541
70+
70.0,1.1640628770424661,0.5947419810536813,1.3229685907464526,0.5398
71+
71.0,1.1614204583284387,0.5954083222536147,1.324079172392963,0.5418
72+
72.0,1.1584428785720262,0.5964034194781452,1.322592790482728,0.54
73+
73.0,1.155897042957851,0.5969840659797241,1.321316808597242,0.5404
74+
74.0,1.1533538946445323,0.5984221788266578,1.3240338764937882,0.5402
75+
75.0,1.150596395153521,0.5990838457703174,1.3250515243843302,0.5414
76+
76.0,1.1479828276607542,0.5997361642014293,1.3213816765518247,0.5408
77+
77.0,1.1456013991649443,0.6007359356822337,1.322209337122594,0.5414
78+
78.0,1.1428948834559163,0.6021402900116337,1.3208159001419877,0.5416
79+
79.0,1.1405840977255453,0.6030258018946318,1.3190405319298553,0.5436
80+
80.0,1.138456276159628,0.6046540011633704,1.319016162754993,0.5428
81+
81.0,1.1360969314308267,0.604801499916902,1.317674437313949,0.5436
82+
82.0,1.1335010399508012,0.6058584011966096,1.3186716090087884,0.5436
83+
83.0,1.13123556274039,0.60646761259764,1.3147843189430144,0.5432
84+
84.0,1.1291076858663522,0.6074959489778959,1.3154507688752388,0.5422
85+
85.0,1.1270488055037702,0.6081815065647332,1.3113865161079785,0.5454
86+
86.0,1.1246272411531353,0.6084671555592488,1.31230086333935,0.543
87+
87.0,1.1226729512819174,0.6089195196941998,1.3106131954976976,0.5432
88+
88.0,1.1202892968081963,0.6099572045870034,1.3103511664799978,0.5434
89+
89.0,1.11780839667264,0.6109523018115339,1.3103589073105326,0.5418
90+
90.0,1.115877186853028,0.6116378593983712,1.3095933331067162,0.5456
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"data_type": "matrix",
3+
"value_type": "double",
4+
"rows": 90,
5+
"cols": 5,
6+
"nnz": 450,
7+
"format": "csv",
8+
"author": "nicol",
9+
"header": false,
10+
"sep": ",",
11+
"created": "2026-01-17 19:35:45 MEZ"
12+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
1.0,0.6600831748999706,0.7919920634920635,0.45846249291763935,0.844
2+
2.0,0.25601354551075184,0.923,0.3060521207758516,0.909
3+
3.0,0.1791216355591798,0.947125,0.21327004292691254,0.933
4+
4.0,0.1399477948175259,0.95775,0.1736507581045767,0.942
5+
5.0,0.10750665179781986,0.9675,0.1505918516055093,0.952
6+
6.0,0.09024092394376199,0.971875,0.15272215195884234,0.949
7+
7.0,0.08109318636628057,0.974,0.18253549984362463,0.943
8+
8.0,0.07343285225967074,0.9765,0.1963133682103849,0.936
9+
9.0,0.06492667107394304,0.97925,0.16139850527851352,0.948
10+
10.0,0.05564659771922558,0.9825,0.15813024537483578,0.949
11+
11.0,0.04885065099874338,0.984625,0.13685280674984188,0.958
12+
12.0,0.04345266142523965,0.986,0.1295635143772901,0.963
13+
13.0,0.03789947918193832,0.988125,0.1280999416145119,0.964
14+
14.0,0.03440348961873195,0.98975,0.12033225446550745,0.967
15+
15.0,0.03113956579003195,0.9905,0.14465871496448326,0.956
16+
16.0,0.02782411652521333,0.992,0.13935065575895655,0.959
17+
17.0,0.025903924198902143,0.992875,0.13805741037482513,0.96
18+
18.0,0.023049034921953607,0.994,0.1390425642677865,0.958
19+
19.0,0.021960489895278133,0.99375,0.14567394365139114,0.953
20+
20.0,0.022683080115491985,0.992875,0.13101678745364734,0.962
21+
21.0,0.01951068056535099,0.99375,0.16387851254291194,0.953
22+
22.0,0.02072305049782118,0.993375,0.13199021229927832,0.963
23+
23.0,0.01573298327963868,0.9955,0.11595083667347869,0.966
24+
24.0,0.016863144556757713,0.994625,0.12464655755935368,0.964
25+
25.0,0.011906263883559153,0.99675,0.09555637996964318,0.975
26+
26.0,0.01387731248733787,0.995,0.07513476435524133,0.973
27+
27.0,0.014348454929750287,0.99525,0.09799803592380817,0.973
28+
28.0,0.014791121236119583,0.995,0.09913795348726741,0.97
29+
29.0,0.01745363921146775,0.995,0.10855312385534872,0.967
30+
30.0,0.016810133122882962,0.994625,0.09143590778251723,0.972
31+
31.0,0.013776919442576465,0.995125,0.09972306275201406,0.973
32+
32.0,0.012081841240356836,0.995875,0.1042540514767386,0.971
33+
33.0,0.009673850822736923,0.997,0.08166328277502062,0.976
34+
34.0,0.010130896530972777,0.99775,0.08359973442180495,0.975
35+
35.0,0.009486136169168992,0.996375,0.09465088299135849,0.969
36+
36.0,0.011784477117724483,0.9964980158730159,0.12253367435178857,0.964
37+
37.0,0.01285823612554023,0.9955,0.0778909504372984,0.973
38+
38.0,0.005531593403143221,0.99875,0.08527911368874061,0.978
39+
39.0,0.004174380412227944,0.99925,0.06256664584964217,0.98
40+
40.0,0.003724836264838745,0.999375,0.06596277050627802,0.977
41+
41.0,0.00373747439026559,0.999,0.106305429921357,0.975
42+
42.0,0.00446614766949216,0.99875,0.06967587883698868,0.983
43+
43.0,0.0033526953383332878,0.999125,0.06330396928658376,0.984
44+
44.0,0.0028230799970961214,0.999375,0.08233284112176356,0.979
45+
45.0,0.0027498272954112166,0.99975,0.06967829332529306,0.983
46+
46.0,0.0021902541798457316,0.9995,0.12228901254433457,0.971
47+
47.0,0.014584827217146804,0.994625,0.11751439804792264,0.973
48+
48.0,0.003940440886971421,0.999125,0.07298610902463333,0.982
49+
49.0,0.002838167116175086,0.999375,0.05795638638785478,0.986
50+
50.0,0.0024289928108280536,0.99925,0.07105060997549256,0.98
51+
51.0,8.843184635752194E-4,1.0,0.06353335404859009,0.984
52+
52.0,6.38977018076913E-4,1.0,0.06250178394164058,0.986
53+
53.0,0.0010020022685720159,0.99975,0.08592793990596928,0.977
54+
54.0,0.003981127656437652,0.99875,0.14202391884937093,0.964
55+
55.0,0.009452855843030253,0.99725,0.11272917235094357,0.969
56+
56.0,0.02030383425854622,0.994,0.22102710131762598,0.961
57+
57.0,0.0201962723872667,0.992625,0.0746092794797957,0.98
58+
58.0,0.012778119867117035,0.99625,0.08488781163569933,0.98
59+
59.0,0.005098323074058392,0.998625,0.0705106763422874,0.983
60+
60.0,0.002205910690743962,0.9995,0.06720935609152047,0.985
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"data_type": "matrix",
3+
"value_type": "double",
4+
"rows": 60,
5+
"cols": 5,
6+
"nnz": 300,
7+
"format": "csv",
8+
"author": "nicol",
9+
"header": false,
10+
"sep": ",",
11+
"created": "2026-01-18 12:40:31 MEZ"
12+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
1.0,2.308857970725247,0.09733879009473159,2.303781414916082,0.1044
2+
2.0,2.304008930124782,0.09928120325743726,2.3033522297439544,0.1044
3+
3.0,2.3037258759929133,0.09828143177663287,2.303188402061223,0.1044
4+
4.0,2.3035719744954597,0.09843931776632873,2.3031009169012147,0.1044
5+
5.0,2.3034703247407293,0.09878209655974737,2.303046278570225,0.1044
6+
6.0,2.303396340692825,0.09849644756523183,2.303008793363021,0.1044
7+
7.0,2.303339197122346,0.09803940917400697,2.3029813853473597,0.1044
8+
8.0,2.303293240514144,0.09798227937510387,2.3029603952712767,0.1044
9+
9.0,2.3032551783927606,0.097525240983879,2.302943743402374,0.1044
10+
10.0,2.303222941068266,0.09741098138607279,2.3029301622777862,0.1044
11+
11.0,2.30319515139901,0.0971538972910088,2.3029188364644875,0.1044
12+
12.0,2.3031708517290452,0.09735385158716968,2.3029092180000568,0.1044
13+
13.0,2.303149352177799,0.09749667608442746,2.30290092519086,0.1044
14+
14.0,2.303130141169306,0.09738241648662123,2.3028936839986205,0.1044
15+
15.0,2.303112830057921,0.09746811118497589,2.3028872925240407,0.1044
16+
16.0,2.303097117460817,0.0978394548778461,2.3028815986504196,0.1044
17+
17.0,2.303082765498494,0.09755380588333055,2.3028764855041284,0.1044
18+
18.0,2.303069583506499,0.09738241648662123,2.302871861726882,0.1044
19+
19.0,2.303057416587966,0.09738241648662123,2.3028676548032334,0.1044
20+
20.0,2.3030461373916387,0.09772519528003988,2.3028638063815388,0.1044
21+
21.0,2.303035640092217,0.09789658467674921,2.302860268927418,0.1044
22+
22.0,2.303025835907424,0.09781088997839454,2.3028570032873428,0.1044
23+
23.0,2.3030166497080473,0.09775376017949143,2.302853976886092,0.1044
24+
24.0,2.303008017419134,0.09758237078278212,2.302851162373504,0.1044
25+
25.0,2.3029998840026313,0.09763950058168522,2.3028485365948654,0.1044
26+
26.0,2.3029922018735722,0.09755380588333055,2.3028460797978676,0.0958
27+
27.0,2.3029849296437037,0.09726815688881502,2.3028437750148405,0.0958
28+
28.0,2.302978031115318,0.09718246219046035,2.3028416075764797,0.0958
29+
29.0,2.3029714744683667,0.09718246219046035,2.302839564725359,0.0958
30+
30.0,2.302965231598419,0.09675398869868705,2.3028376353059916,0.0958
31+
31.0,2.3029592775733714,0.09721102708991192,2.3028358095141903,0.0958
32+
32.0,2.3029535901845293,0.09703963769320259,2.3028340786928005,0.0958
33+
33.0,2.3029481495732043,0.09718246219046035,2.3028324351640124,0.0958
34+
34.0,2.302942937918265,0.09683968339704171,2.3028308720907593,0.0958
35+
35.0,2.3029379391731855,0.09683968339704171,2.3028293833614257,0.0958
36+
36.0,2.302933138843571,0.09695394299484793,2.3028279634933666,0.0958
37+
37.0,2.30292852379795,0.09689681319594481,2.302826607551716,0.0958
38+
38.0,2.3029240821061667,0.09692537809539638,2.3028253110806913,0.0958
39+
39.0,2.3029198029006626,0.09689681319594481,2.302824070045189,0.0958
40+
40.0,2.302915676256911,0.09695394299484793,2.3028228807808824,0.0958
41+
41.0,2.302911693089989,0.09701107279375104,2.3028217399514026,0.0958
42+
42.0,2.3029078450647407,0.09695394299484793,2.302820644511436,0.0958
43+
43.0,2.302904124517459,0.09695394299484793,2.302819591674796,0.0958
44+
44.0,2.302900524387369,0.09695394299484793,2.3028185788866904,0.0958
45+
45.0,2.30289703815653,0.09686824829649326,2.3028176037995474,0.0958
46+
46.0,2.302893659796928,0.09669685889978394,2.3028166642518713,0.0958
47+
47.0,2.3028903837237644,0.09695394299484793,2.302815758249686,0.0958
48+
48.0,2.3028872047540805,0.09695394299484793,2.302814883950203,0.0958
49+
49.0,2.302884118070067,0.09701107279375104,2.3028140396473953,0.0958
50+
50.0,2.302881119186381,0.09709676749210569,2.3028132237592303,0.0958
51+
51.0,2.302878203920975,0.09741098138607279,2.302812434816333,0.0958
52+
52.0,2.302875368369051,0.09726815688881502,2.302811671451892,0.0958
53+
53.0,2.3028726088796594,0.09712533239155725,2.3028109323926556,0.0958
54+
54.0,2.302869922034726,0.09695394299484793,2.302810216450874,0.0958
55+
55.0,2.3028673046301313,0.09701107279375104,2.3028095225170846,0.0958
56+
56.0,2.302864753658669,0.09692537809539638,2.3028088495536134,0.0958
57+
57.0,2.302862266294632,0.0971538972910088,2.3028081965887446,0.0958
58+
58.0,2.3028598398798685,0.0971538972910088,2.302807562711441,0.0958
59+
59.0,2.30285747191109,0.09726815688881502,2.3028069470665855,0.0958
60+
60.0,2.3028551600284204,0.09726815688881502,2.302806348850665,0.0958
61+
61.0,2.3028529020048816,0.09726815688881502,2.302805767307857,0.0958
62+
62.0,2.3028506957368564,0.09741098138607279,2.3028052017264735,0.0958
63+
63.0,2.3028485392353937,0.09732528668771813,2.3028046514357197,0.0958
64+
64.0,2.3028464306182124,0.09743954628552434,2.3028041158027452,0.0958
65+
65.0,2.3028443681023987,0.09769663038058833,2.3028035942299465,0.0958
66+
66.0,2.302842349997736,0.09772519528003988,2.3028030861525015,0.0958
67+
67.0,2.3028403747005197,0.09766806548113678,2.30280259103611,0.0958
68+
68.0,2.3028384406879288,0.09761093568223367,2.3028021083749253,0.0958
69+
69.0,2.3028365465127987,0.09755380588333055,2.3028016376896474,0.0958
70+
70.0,2.302834690798834,0.09758237078278212,2.3028011785257787,0.0958
71+
71.0,2.302832872236157,0.09749667608442746,2.302800730452008,0.0958
72+
72.0,2.3028310895772277,0.09772519528003988,2.302800293058731,0.0958
73+
73.0,2.3028293416330525,0.09769663038058833,2.3027998659566746,0.0958
74+
74.0,2.3028276272696626,0.097782325078943,2.302799448775633,0.0958
75+
75.0,2.302825945404852,0.09786801977729766,2.3027990411632944,0.0958
76+
76.0,2.302824295005185,0.09801084427455542,2.302798642784158,0.0958
77+
77.0,2.3028226750831555,0.09841075286687717,2.3027982533185245,0.0958
78+
78.0,2.3028210846945876,0.09838218796742562,2.3027978724615648,0.0958
79+
79.0,2.3028195229362116,0.09826792836961941,2.3027974999224523,0.0958
80+
80.0,2.302817988943382,0.09806797407345853,2.302797135423554,0.0958
81+
81.0,2.3028164818879793,0.09818223367126475,2.302796778699684,0.0958
82+
82.0,2.3028150009764223,0.09830064816353665,2.302796429497398,0.0958
83+
83.0,2.302813545447827,0.09830064816353665,2.3027960875743467,0.0958
84+
84.0,2.302812114572304,0.09818638856573043,2.3027957526986635,0.0958
85+
85.0,2.302810707649296,0.09815782366627888,2.3027954246483984,0.0958
86+
86.0,2.3028093240061196,0.09818638856573043,2.302795103210983,0.0958
87+
87.0,2.302807962996513,0.09804356406847267,2.3027947881827373,0.0958
88+
88.0,2.3028066239993126,0.09804356406847267,2.3027944793684005,0.0958
89+
89.0,2.3028053064172282,0.09784360977231178,2.3027941765806963,0.0958
90+
90.0,2.3028040096756257,0.09795786937011801,2.3027938796399248,0.0958
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"data_type": "matrix",
3+
"value_type": "double",
4+
"rows": 90,
5+
"cols": 5,
6+
"nnz": 450,
7+
"format": "csv",
8+
"author": "nicol",
9+
"header": false,
10+
"sep": ",",
11+
"created": "2026-01-17 14:31:06 MEZ"
12+
}

0 commit comments

Comments
 (0)