-
Notifications
You must be signed in to change notification settings - Fork 301
Expand file tree
/
Copy pathactivation_function_class.py
More file actions
37 lines (25 loc) · 1019 Bytes
/
activation_function_class.py
File metadata and controls
37 lines (25 loc) · 1019 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""Classes for configuring activation functions."""
import dataclasses
from enum import Enum
import torch
from hydra.core.config_store import ConfigStore
class ActivationFunctionClass(Enum):
    """Closed set of activation functions that a config may select.

    The ``value`` of every member is the ``torch.nn`` module *class* itself
    (not an instance), so the caller decides when and how to construct it.
    """

    # Definition order is the iteration order of the enum.
    TanH = torch.nn.Tanh
    ReLU = torch.nn.ReLU
    LeakyReLU = torch.nn.LeakyReLU
@dataclasses.dataclass
class Config:
    """Structured config that picks a single activation function class.

    Attributes are documented with comments next to each field below.
    """

    # Enum member naming the activation function to use.
    activation_function_class: ActivationFunctionClass
    # Dotted path to ``make``; presumably consumed by Hydra's instantiate
    # machinery, which calls the target with the remaining fields as keyword
    # arguments -- confirm against the caller.
    _target_: str = "imitation_cli.utils.activation_function_class.Config.make"

    @staticmethod
    def make(activation_function_class: ActivationFunctionClass) -> type:
        """Unwrap an enum member into the class it stands for.

        Args:
            activation_function_class: The enum member to unwrap.

        Returns:
            The member's ``value``, i.e. the activation function class itself
            (not an instance of it).
        """
        return activation_function_class.value
# Ready-made config instances, one for each supported activation function.
TanH = Config(activation_function_class=ActivationFunctionClass.TanH)
ReLU = Config(activation_function_class=ActivationFunctionClass.ReLU)
LeakyReLU = Config(activation_function_class=ActivationFunctionClass.LeakyReLU)
def register_configs(group: str):
    """Store one ``Config`` node per activation function in the config store.

    Args:
        group: Config group under which every node is stored.
    """
    config_store = ConfigStore.instance()
    for member in ActivationFunctionClass:
        # Each node is named after the lowercased member name,
        # e.g. ``LeakyReLU`` is stored as "leakyrelu".
        config_store.store(
            group=group,
            name=member.name.lower(),
            node=Config(member),
        )