New regression and classification datasets for ontology pre-training #130
Merged
Commits (61 total; changes shown from 54 commits)
ed2ca6c schnamo: create new class for solubility data
ead3007 schnamo: adjusting new class
5956183 schnamo: add solubility yml file
c3afeed schnamo: adjusting solubility class to correctly download solubility data
d57b073 schnamo: make it compatible with classification problem
0faca31 schnamo: onehotencoding for solubility labels
4000215 schnamo: adjust to regression, add yml files for regression
0709188 schnamo: adjust prediction to regression
f8bd06a schnamo: refactor code
21fbde4 schnamo: regression fix, yml files for mae loss
f3bfe08 schnamo: take out kinect dataset
e26925d schnamo: adjust learning rate
0f2f85f schnamo: adjustments for new solu dataset
d0da5c2 schnamo: Merge branch 'dev' of https://github.com/schnamo/python-chebai into s…
0d94b44 schnamo: working on evaluation script, added a bunch of things earlier for so…
45228ba schnamo: further adjusting evaluation function for regression
dbf8532 schnamo: regression adjustments
fa97f45 schnamo: fix union expression
8b91dce schnamo: fix tuple issue to make it backwards compatible
677d6ec schnamo: wandb
2c159e8 schnamo: fix issue with solubility dataset read in
b537b7f MGlauer: Fix missing label handling
a99e438 schnamo: add more datasets
754de12 schnamo: Merge commit 'b537b7fd776e6afc535e05a111a0bc6a493ec8e9' of https://gi…
9b084cb schnamo: merge branches part 2
326e9a2 schnamo: add more datasets
c272f45 schnamo: adjust metrics for classifications, add BBBP
dc9e104 schnamo: more datasets
9a3967d schnamo: bug fixes and different loss and electra params
1bc8736 schnamo: changes to missing labels: negate labels as well as logits, add them …
4885960 schnamo: try different splits, remove debugging comments
baa085f schnamo: Merge branch 'dev' of https://github.com/schnamo/python-chebai into s…
ba01607 schnamo: fix issue with input args
f74964c schnamo: add missing configs
59064af schnamo: add HIV dataset handling
93d47eb schnamo: add MUV dataset
87babcc schnamo: debugging
ebe049e schnamo: final updates
188f32f schnamo: add focal loss
dccc2e3 schnamo: add focal loss
d57016f schnamo: format for lint
41c0b1c schnamo: lint fix
4c993a2 schnamo: lint fix
4aa1771 schnamo: add regression to readme
d411c9e schnamo: fix union expression
18d8e02 schnamo: fix tuple issue to make it backwards compatible
af7df07: Merge branch 'sol_final' into dev and adjustments to new logic, adjus…
ed1d4b4: adjust to current dev branch
dca60a3: adjust all regression tasks to new logic
b6f0d23: adjust classification tasks for new logic
9b29411: lightning cli issue
d56e226 schnamo: black-lint fix
fc444e0 schnamo: fix load from checkpoint issues for pretrained models
5304b3e schnamo: adding decoding of encoded tokens function
426f1b0 schnamo: remove print statements from debugging
b0b3113 schnamo: Merge branch 'dev' of https://github.com/ChEB-AI/python-chebai into dev
fb6fdb7 schnamo: lint fixes
81f8025 schnamo: ruff fixes
9a24fd7 schnamo: black fixes
e67e4eb schnamo: lint fixes
01f9f5f schnamo: adjust unit tests for missing labels in tox21
New file (152 lines added):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# from https://github.com/itakurah/Focal-loss-PyTorch


class FocalLoss(nn.Module):
    def __init__(
        self,
        gamma=2,
        alpha=None,
        reduction="mean",
        task_type="binary",
        num_classes=None,
    ):
        """
        Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.

        :param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
        :param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
        :param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
        :param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
        :param num_classes: Number of classes (only required for multi-class classification)
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Handle alpha for class balancing in multi-class tasks
        if (
            task_type == "multi-class"
            and alpha is not None
            and isinstance(alpha, (list, torch.Tensor))
        ):
            assert (
                num_classes is not None
            ), "num_classes must be specified for multi-class classification"
            if isinstance(alpha, list):
                self.alpha = torch.Tensor(alpha)
            else:
                self.alpha = alpha

    def forward(self, inputs, targets):
        """
        Forward pass to compute the Focal Loss based on the specified task type.

        :param inputs: Predictions (logits) from the model.
            Shape:
            - binary/multi-label: (batch_size, num_classes)
            - multi-class: (batch_size, num_classes)
        :param targets: Ground truth labels.
            Shape:
            - binary: (batch_size,)
            - multi-label: (batch_size, num_classes)
            - multi-class: (batch_size,)
        """
        if self.task_type == "binary":
            return self.binary_focal_loss(inputs, targets)
        elif self.task_type == "multi-class":
            return self.multi_class_focal_loss(inputs, targets)
        elif self.task_type == "multi-label":
            return self.multi_label_focal_loss(inputs, targets)
        else:
            raise ValueError(
                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'."
            )

    def binary_focal_loss(self, inputs, targets):
        """Focal loss for binary classification."""
        probs = torch.sigmoid(inputs)
        targets = targets.float()

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weighting
        loss = focal_weight * bce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def multi_class_focal_loss(self, inputs, targets):
        """Focal loss for multi-class classification."""
        if self.alpha is not None:
            alpha = self.alpha.to(inputs.device)

        # Convert logits to probabilities with softmax
        probs = F.softmax(inputs, dim=1)

        # One-hot encode the targets
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()

        # Compute cross-entropy for each class
        ce_loss = -targets_one_hot * torch.log(probs)

        # Compute focal weight
        p_t = torch.sum(probs * targets_one_hot, dim=1)  # p_t for each sample
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided (per-class weighting)
        if self.alpha is not None:
            alpha_t = alpha.gather(0, targets)
            ce_loss = alpha_t.unsqueeze(1) * ce_loss

        # Apply focal loss weight
        loss = focal_weight.unsqueeze(1) * ce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

    def multi_label_focal_loss(self, inputs, targets):
        """Focal loss for multi-label classification."""
        probs = torch.sigmoid(inputs)

        # Compute binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

        # Compute focal weight
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha if provided
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        # Apply focal loss weight
        loss = focal_weight * bce_loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
```
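The core of the implementation above is the per-element modulating factor (1 - p_t)^gamma applied on top of the cross-entropy term. A minimal sketch of that binary-case computation on plain tensors (the tensor values here are illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

# Illustrative logits and binary targets (not from the PR)
gamma = 2.0
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

# Per-element BCE, then down-weight easy examples by (1 - p_t)^gamma
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
probs = torch.sigmoid(logits)
p_t = probs * targets + (1 - probs) * (1 - targets)
focal = ((1 - p_t) ** gamma * bce).mean()

# Since (1 - p_t)^gamma <= 1, focal loss never exceeds plain BCE
plain = F.binary_cross_entropy_with_logits(logits, targets)
print(float(focal), float(plain))
```

With gamma = 0 the modulating factor is identically 1 and the loss reduces to ordinary BCE; increasing gamma progressively suppresses the contribution of well-classified (high p_t) examples.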
```diff
@@ -983,4 +983,4 @@ p
 [FH2+]
 [ClH2+]
 [BrH2+]
-[IH2+]
+[IH2+]
```
This will be a problem for merging. I have added new SMILES tokens on a different branch (from PubChem), so the new PubChem-pretrained model (and all models based on that) will depend on those tokens.
Are the tokens you added here actually used by a model, or are those just artifacts?
I have removed the part in question and will open an issue to look into what is going on with this.