-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy pathconftest.py
More file actions
123 lines (99 loc) · 2.98 KB
/
conftest.py
File metadata and controls
123 lines (99 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import tempfile
import shutil
from pathlib import Path
import pytest
import yaml
from unittest.mock import MagicMock
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path`` and delete it afterwards.

    Cleanup sits in ``finally`` so the directory is removed even if the
    fixture generator is closed early. ``ignore_errors=True`` keeps teardown
    from erroring when a test has already deleted the directory (or parts of
    it) itself.
    """
    temp_path = tempfile.mkdtemp()
    try:
        yield Path(temp_path)
    finally:
        shutil.rmtree(temp_path, ignore_errors=True)
@pytest.fixture
def mock_config():
    """Return a MagicMock shaped like a Swin config with MODEL/DATA/TRAIN sections.

    Each section is its own MagicMock carrying the fixed values below. Any
    attribute not listed still resolves to an auto-created child mock
    instead of raising ``AttributeError``.
    """
    section_values = {
        'MODEL': {
            'TYPE': 'swin',
            'NAME': 'swin_base',
            'DROP_PATH_RATE': 0.1,
            'NUM_CLASSES': 1000,
        },
        'DATA': {
            'IMG_SIZE': 224,
            'BATCH_SIZE': 32,
            'DATA_PATH': '/path/to/data',
        },
        'TRAIN': {
            'EPOCHS': 100,
            'BASE_LR': 1e-4,
            'WEIGHT_DECAY': 0.05,
        },
    }
    cfg = MagicMock()
    for section_name, fields in section_values.items():
        section = MagicMock()
        setattr(cfg, section_name, section)
        for field_name, value in fields.items():
            setattr(section, field_name, value)
    return cfg
@pytest.fixture
def sample_yaml_config(temp_dir):
    """Write a small ViT config as YAML into ``temp_dir`` and return its path.

    The document has three top-level mappings (MODEL, DATA, TRAIN) and is
    serialised with ``yaml.dump`` using its default options.
    """
    model_section = {
        'TYPE': 'vit',
        'NAME': 'vit_base',
        'NUM_CLASSES': 1000,
        'DROP_PATH_RATE': 0.1,
    }
    data_section = {
        'IMG_SIZE': 224,
        'BATCH_SIZE': 64,
        'DATA_PATH': '/data/imagenet',
    }
    train_section = {
        'EPOCHS': 300,
        'BASE_LR': 5e-4,
        'WEIGHT_DECAY': 0.05,
    }
    document = {
        'MODEL': model_section,
        'DATA': data_section,
        'TRAIN': train_section,
    }
    destination = temp_dir / 'test_config.yaml'
    destination.write_text(yaml.dump(document))
    return destination
@pytest.fixture
def mock_dataset():
    """Return a MagicMock that behaves like an indexable dataset of 1000 items.

    ``len(ds)`` reports 1000, and every ``ds[i]`` returns the same
    ``(MagicMock(), 0)`` tuple: a placeholder sample paired with label 0.
    """
    placeholder_item = (MagicMock(), 0)
    ds = MagicMock()
    ds.__len__ = MagicMock(return_value=1000)
    ds.__getitem__ = MagicMock(return_value=placeholder_item)
    return ds
@pytest.fixture
def mock_model():
    """Return a MagicMock standing in for a model.

    ``model.forward(...)`` and calling the mock directly (``model(x)``) both
    return the same output mock, whose ``shape`` is ``(32, 1000)``.
    ``model.parameters()`` returns a list of 10 placeholder mocks; it is the
    same list on every call.
    """
    output = MagicMock(shape=(32, 1000))
    model = MagicMock()
    model.forward = MagicMock(return_value=output)
    # PyTorch modules are normally invoked as model(x) (nn.Module.__call__
    # dispatches to forward). Without this line that call would return an
    # unrelated auto-created mock, so model(x).shape would be meaningless.
    model.return_value = output
    model.parameters = MagicMock(return_value=[MagicMock() for _ in range(10)])
    return model
@pytest.fixture
def sample_image_tensor():
    """Return a tensor of shape ``(1, 3, 224, 224)`` for tests.

    The shape is presumably a single NCHW RGB image; confirm against the
    consumers. When torch is installed, the values are standard-normal and
    drawn from a locally seeded generator. They are therefore identical
    across runs, and the global torch RNG state is left untouched. Without
    torch, a MagicMock exposing the same ``shape`` attribute is returned.
    """
    # Only the import is guarded, so unrelated failures are not masked.
    try:
        import torch
    except ImportError:
        # Return a mock if torch is not available during testing
        return MagicMock(shape=(1, 3, 224, 224))
    generator = torch.Generator().manual_seed(0)
    return torch.randn(1, 3, 224, 224, generator=generator)
@pytest.fixture
def captured_output():
    """Redirect ``sys.stdout`` and ``sys.stderr`` into fresh ``StringIO`` buffers.

    Yields ``(stdout_buffer, stderr_buffer)``; read captured text with
    ``.getvalue()``. The original streams are restored on teardown. Because
    the redirects are context managers, restoration also happens if the
    fixture generator is closed early; the previous version would leave the
    streams redirected in that case.

    NOTE: pytest's built-in ``capsys`` fixture provides similar capture.
    """
    from contextlib import redirect_stderr, redirect_stdout
    from io import StringIO

    out_buffer, err_buffer = StringIO(), StringIO()
    with redirect_stdout(out_buffer), redirect_stderr(err_buffer):
        yield out_buffer, err_buffer
@pytest.fixture(autouse=True)
def reset_modules():
    """Drop the project's ``models`` and ``data`` packages from ``sys.modules``.

    Because of ``autouse=True`` this runs before every test, so each test
    re-imports those packages. Only the packages themselves and their dotted
    submodules are removed. The previous bare prefix match also evicted
    unrelated modules whose names merely start with those words (for example
    the stdlib ``dataclasses``), forcing them to be re-imported as new module
    objects.
    """
    import sys

    packages = ('models', 'data')
    submodule_prefixes = tuple(pkg + '.' for pkg in packages)
    # Snapshot the names first, since sys.modules is mutated below.
    stale = [
        name for name in list(sys.modules)
        if name in packages or name.startswith(submodule_prefixes)
    ]
    for name in stale:
        sys.modules.pop(name, None)
    yield