"""PyTorch dataset and batching helpers for SubCell image inputs (dataset.py)."""
from typing import Dict, Any, List
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import image_utils
def min_max_norm_fn(x: np.ndarray) -> np.ndarray:
    """Rescale an array with min-max normalization over all of its elements.

    The minimum and maximum are taken over the whole array (no axis is
    given), so every element is scaled by the same offset and range.

    Args:
        x: Array to normalize.

    Returns:
        ``(x - min) / (max - min + 1e-8)``, with the same shape as ``x``.
    """
    # keepdims=True keeps the extrema as arrays that broadcast against x.
    lowest = np.amin(x, keepdims=True)
    highest = np.amax(x, keepdims=True)

    # The small epsilon avoids a division by zero when x is constant.
    value_range = highest - lowest + 1e-8
    shifted = x - lowest
    return shifted / value_range
class SubCellDataset(Dataset):
    """PyTorch Dataset for SubCell image processing.

    Each row of the input CSV describes one sample: one image path column
    per channel (``r_image``, ``y_image``, ``b_image``, ``g_image``), an
    ``output_prefix`` column and, in the old CSV format only, an
    ``output_folder`` column.
    """

    def __init__(self, path_list_file: str, model_channels: str = "rybg") -> None:
        """Read the sample list and check it against the channel configuration.

        Args:
            path_list_file (str): Path to the CSV file containing image paths
            model_channels (str): Channel configuration (rybg, rbg, ybg, bg)

        Raises:
            ValueError: If ``model_channels`` is empty or contains a letter
                with no channel mapping, or if the CSV lacks a column that
                ``__getitem__`` reads for this configuration.
        """
        super().__init__()
        self.model_channels = model_channels

        # Define channel mapping (channel letter -> CSV column with its path)
        self.channel_mapping = {
            "r": "r_image",
            "y": "y_image",
            "b": "b_image",
            "g": "g_image",
        }

        # Fail fast on a bad configuration: otherwise it only surfaces as a
        # bare KeyError (or an empty np.stack) on the first item fetch.
        if not model_channels:
            raise ValueError("model_channels must name at least one channel")
        unknown = [c for c in model_channels if c not in self.channel_mapping]
        if unknown:
            raise ValueError(
                f"Unknown channel(s) {unknown} in model_channels={model_channels!r}; "
                f"valid channels are {sorted(self.channel_mapping)}"
            )

        # Read CSV
        df = pd.read_csv(path_list_file)

        # Remove the '#' from column names if present
        df.columns = df.columns.str.lstrip("#")

        # Every column that __getitem__ indexes must exist in the CSV.
        required = [self.channel_mapping[c] for c in model_channels]
        required.append("output_prefix")
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise ValueError(
                f"CSV file {path_list_file!r} is missing required column(s): {missing}"
            )

        # Detect CSV format (old vs new): only the old one has output_folder
        self.uses_old_format = "output_folder" in df.columns

        # One plain dict per CSV row, keyed by column name
        self.data_list = df.to_dict("records")

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load and preprocess a single image set.

        Args:
            idx: Position of the sample in the CSV row list.

        Returns:
            Dictionary with ``images`` (float32 array of the stacked,
            min-max normalized channels), ``output_prefix``,
            ``original_item`` (the full CSV row) and, for old-format CSVs
            only, ``output_folder``.
        """
        item = self.data_list[idx]

        # Only load the channels named in model_channels, in that order.
        channel_images = [
            image_utils.read_grayscale_image(item[self.channel_mapping[channel_name]])
            for channel_name in self.model_channels
        ]

        # Stack images along a new leading channel axis; presumably the result
        # is (channels, height, width) if each loaded image is 2-D — confirm
        # against image_utils.read_grayscale_image.
        cell_data = np.stack(channel_images, axis=0)

        # Always normalized to the 0-1 range as required by the model.
        cell_data = min_max_norm_fn(cell_data)

        result = {
            "images": cell_data.astype(np.float32),
            "output_prefix": item["output_prefix"],
            "original_item": item,
        }

        # Include output_folder only if using old format
        if self.uses_old_format:
            result["output_folder"] = item["output_folder"]

        return result
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge a list of dataset items into a single batch dictionary.

    Args:
        batch: List of dataset items; each item provides ``images``,
            ``output_prefix`` and ``original_item``, plus ``output_folder``
            when the old CSV format is in use.

    Returns:
        Dictionary with ``images`` as one stacked tensor and
        ``output_prefixes`` / ``original_items`` as plain lists.
        ``output_folders`` is added only when the first item carries an
        ``output_folder`` entry.
    """

    def gather(key: str) -> List[Any]:
        # Collect one field from every sample, preserving batch order.
        return [sample[key] for sample in batch]

    # Stack on the numpy side first, then convert the result to a tensor.
    stacked = np.stack(gather("images"))

    collated = {
        "images": torch.from_numpy(stacked),
        "output_prefixes": gather("output_prefix"),
        "original_items": gather("original_item"),
    }

    # The first sample decides whether output folders are included (old format).
    first_sample = batch[0]
    if "output_folder" in first_sample:
        collated["output_folders"] = gather("output_folder")

    return collated