Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chebai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
if exclude_hyperparameter_logging is None:
exclude_hyperparameter_logging = tuple()
self.criterion = criterion
assert out_dim is not None and out_dim > 0, "out_dim must be specified"
assert input_dim is not None and input_dim > 0, "input_dim must be specified"
assert out_dim is not None, "out_dim must be specified"
assert input_dim is not None, "input_dim must be specified"
self.out_dim = out_dim
self.input_dim = input_dim
print(
Expand Down
4 changes: 2 additions & 2 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class ElectraPre(ChebaiBaseNet):
replace_p (float): Probability of replacing tokens during training.
"""

def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
super().__init__(config=config, **kwargs)
def __init__(self, config: Dict[str, Any], **kwargs: Any):
super().__init__(**kwargs)

self.generator_config = ElectraConfig(**config["generator"])
self.generator = ElectraForMaskedLM(self.generator_config)
Expand Down
6 changes: 5 additions & 1 deletion chebai/preprocessing/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
*((d["features"], d["labels"], d.get("ident")) for d in data)
)
missing_labels = [
d.get("missing_labels", [False for _ in y[0]]) for d in data
d.get(
"missing_labels",
[False for _ in y[0]] if y[0] is not None else [False],
)
for d in data
]

if any(x is not None for x in y):
Expand Down
25 changes: 20 additions & 5 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,13 @@ def setup(self, *args, **kwargs) -> None:

rank_zero_info(f"Check for processed data in {self.processed_dir}")
rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
rank_zero_info(f"Looking for files: {self.processed_file_names}")
if any(
not os.path.isfile(os.path.join(self.processed_dir, f))
for f in self.processed_file_names
):
rank_zero_info(
f"Did not find one of: {', '.join(self.processed_file_names)} in {self.processed_dir}"
)
self.setup_processed()

self._after_setup(**kwargs)
Expand Down Expand Up @@ -627,17 +629,17 @@ def raw_file_names_dict(self) -> dict:
raise NotImplementedError

@property
def classes_txt_file_path(self) -> str:
def classes_txt_file_path(self) -> Optional[str]:
"""
Returns the filename for the classes text file.
Returns the filename for the classes text file (for labeled datasets that produce a list of labels).

Returns:
str: The filename for the classes text file.
Optional[str]: The filename for the classes text file.
"""
# This property is also used in the following places:
# - chebai/result/prediction.py: to load class names for CSV column names
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
return os.path.join(self.processed_dir_main, "classes.txt")
return None


class MergedDataset(XYBaseDataModule):
Expand Down Expand Up @@ -1406,3 +1408,16 @@ def processed_file_names_dict(self) -> dict:
if self.n_token_limit is not None:
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
return {"data": "data.pt"}

@property
def classes_txt_file_path(self) -> str:
    """
    Path of the ``classes.txt`` file inside the main processed directory.

    Returns:
        str: Full path to the classes text file.
    """
    # This property is also used in the following places:
    # - chebai/result/prediction.py: to load class names for CSV column names
    # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
    classes_filename = "classes.txt"
    return os.path.join(self.processed_dir_main, classes_filename)
5 changes: 3 additions & 2 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ def _set_processed_data_props(self):
self._num_of_labels = 0
self._feature_vector_size = 0

print(f"Number of labels for loaded data: {self._num_of_labels}")
print(f"Feature vector size: {self._feature_vector_size}")
print(
f"Number of labels and feature vector size set to: {self._num_of_labels} / {self._feature_vector_size} (default values, not used for self-supervised learning)"
)

def _perform_data_preparation(self, *args, **kwargs):
"""
Expand Down