Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions devograph/datasets/datasets1.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def normalize_array(self, array):
return array

def edge_feat_embedding(self, x, edge_index):
if edge_index.shape[1] == 0:
return np.empty((0, x.shape[1]), dtype=np.float32)
src, trg = edge_index
sub_x = x[src] - x[trg]
abs_sub = np.abs(sub_x)
Expand Down
48 changes: 48 additions & 0 deletions test/test_datasets1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import sys
import types
from unittest.mock import MagicMock

# hydra is not installed in this env; mock it before importing the module
hydra_mock = types.ModuleType("hydra")
hydra_utils_mock = types.ModuleType("hydra.utils")
hydra_utils_mock.get_original_cwd = MagicMock(return_value="/tmp")
hydra_mock.utils = hydra_utils_mock
sys.modules.setdefault("hydra", hydra_mock)
sys.modules.setdefault("hydra.utils", hydra_utils_mock)

import numpy as np
import torch

from devograph.datasets.datasets1 import CellTrackDataset


def make_dataset():
"""Bypass __init__ and set only the attributes edge_feat_embedding needs."""
ds = object.__new__(CellTrackDataset)
ds.edge_feat_embed_dict = {'p': 1, 'use_normalized_x': True, 'normalized_features': True}
ds.which_preprocess = 'MinMax'
ds.separate_models = False
ds.normalize_cols = np.array([True, True, True])
return ds


def test_edge_feat_embedding_empty_edge_index():
"""No crash and correct shape when graph has zero edges."""
ds = make_dataset()
x = np.random.rand(5, 3).astype(np.float32)
edge_index = torch.empty((2, 0), dtype=torch.long)

result = ds.edge_feat_embedding(x, edge_index)

assert result.shape == (0, 3), f"Expected (0, 3), got {result.shape}"


def test_edge_feat_embedding_nonempty_edge_index():
"""Normal case: returns (E, F) edge features."""
ds = make_dataset()
x = np.random.rand(5, 3).astype(np.float32)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)

result = ds.edge_feat_embedding(x, edge_index)

assert result.shape == (3, 3), f"Expected (3, 3), got {result.shape}"