Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,13 @@ def _read_data(self, raw_data: str) -> List[int]:
print(f"RDKit failed to process {raw_data}")
print(f"\t{e}")
try:
mol = Chem.MolFromSmiles(raw_data.strip())
Comment thread
aditya0by0 marked this conversation as resolved.
if mol is None:
raise ValueError(f"Invalid SMILES: {raw_data}")
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
except ValueError as e:
print(f"could not process {raw_data}")
print(f"\t{e}")
print(f"\tError: {e}")
return None

def _back_to_smiles(self, smiles_encoded):
Expand Down
32 changes: 18 additions & 14 deletions tests/unit/readers/testChemDataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,22 @@ def test_read_data(self) -> None:
"""
Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string.
"""
raw_data = "CC(=O)NC1[Mg-2]"
raw_data = "CC(=O)NC1CC1[Mg-2]"
# Expected output as per the tokens already in the cache, with ")" and "[Mg-2]" getting added to it as new tokens.
expected_output: List[int] = [
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 5, # =
EMBEDDING_OFFSET + 3, # O
EMBEDDING_OFFSET + 1, # N
EMBEDDING_OFFSET + len(self.reader.cache), # (
EMBEDDING_OFFSET + 2, # C
EMBEDDING_OFFSET + 5, # (
EMBEDDING_OFFSET + 3, # =
EMBEDDING_OFFSET + 1, # O
EMBEDDING_OFFSET + len(self.reader.cache), # ) - new token
EMBEDDING_OFFSET + 2, # N
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 4, # 1
EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2]
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 0, # C
EMBEDDING_OFFSET + 4, # 1
EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] - new token
]
result = self.reader._read_data(raw_data)
self.assertEqual(
Expand Down Expand Up @@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None:
Test the _read_data method with an invalid input.
An invalid token should result in a return value of None.
"""
raw_data = "%INVALID%"

result = self.reader._read_data(raw_data)
self.assertIsNone(
result,
"The output for invalid token '%INVALID%' should be None.",
)
# see https://github.com/ChEB-AI/python-chebai/issues/137
raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"]
for raw_data in raw_datas:
result = self.reader._read_data(raw_data)
self.assertIsNone(
result,
f"The output for invalid token '{raw_data}' should be None.",
)

@patch("builtins.open", new_callable=mock_open)
def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None:
Expand Down
Loading