-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtest_save_config.py
More file actions
189 lines (146 loc) · 6.24 KB
/
test_save_config.py
File metadata and controls
189 lines (146 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os
import sys
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger

from chebai.callbacks.save_config import CustomSaveConfigCallback
class DummyModule(LightningModule):
    """Minimal LightningModule stand-in, passed as ``pl_module`` to the callback in tests."""

    def __init__(self):
        super(DummyModule, self).__init__()
        # No real submodule is needed; the attribute is just a placeholder.
        self.layer = None

    def forward(self, x):
        """Identity forward pass: return the input unchanged."""
        output = x
        return output
class TestCustomSaveConfigCallback(unittest.TestCase):
    """Test CustomSaveConfigCallback functionality.

    Each test builds a ``CustomSaveConfigCallback`` around mocked parser/config
    objects and points a mocked ``Trainer`` at a per-test temporary directory.
    It then calls ``save_config`` directly, with ``wandb.save`` patched where
    relevant, so no real training run or W&B upload takes place.
    """

    def setUp(self):
        """Create a per-test temporary directory that is removed afterwards."""
        # TemporaryDirectory + addCleanup guarantees removal even when a test
        # fails; a bare mkdtemp() would leave one directory behind per test.
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        self.temp_dir = tmp.name

    # -- helpers -------------------------------------------------------------

    def _make_callback(self, config_filename="lightning_config.yaml"):
        """Return a callback wired to mocked parser/config objects."""
        return CustomSaveConfigCallback(
            parser=MagicMock(),
            config=MagicMock(),
            config_filename=config_filename,
            overwrite=True,
        )

    def _write_config(self, filename="lightning_config.yaml"):
        """Write a tiny YAML config into the temp directory and return its path."""
        config_path = os.path.join(self.temp_dir, filename)
        with open(config_path, "w", encoding="utf-8") as f:
            f.write("test: config\n")
        return config_path

    def _make_trainer(self, loggers):
        """Return a mocked Trainer with ``log_dir`` set to the temp directory.

        ``loggers`` is assigned verbatim, and ``is_global_zero`` is forced to
        True so the callback behaves as on the main process.
        """
        trainer = MagicMock(spec=Trainer)
        trainer.log_dir = self.temp_dir
        trainer.loggers = loggers
        trainer.is_global_zero = True
        return trainer

    # -- tests ---------------------------------------------------------------

    def test_callback_uploads_config_with_wandb_logger(self):
        """Test that the callback uploads config when WandbLogger is present."""
        callback = self._make_callback()
        config_path = self._write_config()
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])
        pl_module = DummyModule()

        with patch("wandb.save") as mock_wandb_save:
            callback.save_config(trainer, pl_module, "fit")

        # The file is saved relative to the log dir with an immediate upload.
        mock_wandb_save.assert_called_once_with(
            config_path, base_path=self.temp_dir, policy="now"
        )

    def test_callback_skips_upload_without_wandb_logger(self):
        """Test that the callback skips upload when no WandbLogger is present."""
        callback = self._make_callback()
        self._write_config()
        trainer = self._make_trainer([])  # no loggers at all
        pl_module = DummyModule()

        with patch("wandb.save") as mock_wandb_save:
            callback.save_config(trainer, pl_module, "fit")

        mock_wandb_save.assert_not_called()

    def test_callback_handles_missing_config_file(self):
        """Test that the callback handles missing config file gracefully."""
        # Deliberately do NOT write the config file for this test.
        callback = self._make_callback(config_filename="nonexistent_config.yaml")
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])
        pl_module = DummyModule()

        with patch("wandb.save") as mock_wandb_save:
            # Must not raise even though the file does not exist.
            callback.save_config(trainer, pl_module, "fit")

        # Nothing to upload because the file doesn't exist.
        mock_wandb_save.assert_not_called()

    def test_callback_handles_wandb_not_installed(self):
        """Test that the callback handles missing wandb package gracefully."""
        callback = self._make_callback()
        self._write_config()
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])
        pl_module = DummyModule()

        # A None entry in sys.modules makes ``import wandb`` raise
        # ModuleNotFoundError (an ImportError subclass) while every other
        # import keeps working. Wrapping builtins.__import__ in a mock whose
        # side effect calls __import__ again would re-enter the mock and
        # recurse on the first non-wandb import.
        with patch.dict(sys.modules, {"wandb": None}):
            # Should not raise: the callback is expected to catch the
            # ImportError and continue gracefully.
            callback.save_config(trainer, pl_module, "fit")
if __name__ == "__main__":
    # Allow running this file directly (python test_save_config.py);
    # "__main__" is unittest.main's default module, spelled out here.
    unittest.main(module="__main__")