-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugmenter_factory.py
More file actions
39 lines (33 loc) · 1.58 KB
/
augmenter_factory.py
File metadata and controls
39 lines (33 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Dict, Tuple, Type
from const_utils.default_values import AppSettings
from services.augmenter.base_augmenter import BaseAugmenter
from services.augmenter.image_augmenter.umap_dhash import UmapDhashAugmenter
from tests.test_file_operation import settings
class AugmenterFactory:
    """
    Factory that creates augmenter instances for a (datatype, method) pair.

    The datatype is read from the application settings; the method is passed
    explicitly by the caller.

    Attributes:
        _strategies (Dict[Tuple[str, str], Type[BaseAugmenter]]): Registry mapping a
            (datatype, method) key to the augmenter class that handles it.
    """

    # NOTE(review): the method key is "umap_hash" while the registered class is
    # UmapDhashAugmenter (imported from `umap_dhash`) - confirm the key spelling
    # is intentional before relying on it in configuration.
    _strategies: Dict[Tuple[str, str], Type[BaseAugmenter]] = {
        ("image", "umap_hash"): UmapDhashAugmenter,
    }

    @classmethod
    def get_augmenter(cls, settings: AppSettings, method: str) -> BaseAugmenter:
        """
        Build the augmenter registered for ``settings.datatype`` and ``method``.

        The registry is looked up through ``cls`` so that a subclass overriding
        ``_strategies`` is honoured instead of always using the base registry.

        Args:
            settings (AppSettings): Global configuration; its ``datatype`` attribute
                forms the first half of the registry key, and the whole object is
                passed to the augmenter constructor as ``settings``.
            method (str): The augmentation method; second half of the registry key.

        Returns:
            BaseAugmenter: A new instance of the registered augmenter class.

        Raises:
            ValueError: If no augmenter is registered for the (datatype, method) pair.
        """
        datatype = settings.datatype
        augmenter_class = cls._strategies.get((datatype, method))

        # Explicit None check: a missing registry entry is the only failure case.
        if augmenter_class is None:
            raise ValueError(f"No augmenter found for datatype '{datatype}' and method '{method}'")

        return augmenter_class(settings=settings)