-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_usage.py
More file actions
271 lines (216 loc) · 10.1 KB
/
example_usage.py
File metadata and controls
271 lines (216 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
"""
Example usage script for SAVOR: Skill Affordance Learning from Visuo-Haptic Perception
for Robot-Assisted Bite Acquisition
This script demonstrates how to use the SAVOR framework for food physical property prediction.
"""
import os
import torch
import numpy as np
from model import SAVORNet
from torch.utils.data import DataLoader
def create_sample_data(batch_size=1, seq_length=10):
    """
    Generate random placeholder tensors for demonstration purposes.

    Args:
        batch_size: Number of samples in the batch.
        seq_length: Length of each sequence.

    Returns:
        Tuple of (rgb_images, depth_images, force_data, pose_data, scores).
    """
    # Visual streams: RGB (3-channel) and depth (1-channel) image sequences.
    rgb = torch.randn(batch_size, seq_length, 3, 224, 224)
    depth = torch.randn(batch_size, seq_length, 1, 224, 224)
    # Haptic streams: 6D force/torque and 6D pose (position + orientation).
    wrench = torch.randn(batch_size, seq_length, 6)
    end_effector_pose = torch.randn(batch_size, seq_length, 6)
    # Property labels: discrete Likert values in [1, 5] for
    # softness, moisture, viscosity plus their initial values (6 entries).
    likert_scores = torch.randint(1, 6, (batch_size, 6, 1, 1)).float()
    return rgb, depth, wrench, end_effector_pose, likert_scores
def demonstrate_model_usage():
    """Walk through a single SAVORNet forward pass and loss computation
    on randomly generated sample data."""
    print("SAVOR Model Usage Demonstration")
    print("=" * 50)

    seq_len = 20
    print("Creating sample data...")
    rgb, depth, force, pose, scores = create_sample_data(seq_length=seq_len)
    print(f" RGB Images shape: {rgb.shape}")
    print(f" Depth Images shape: {depth.shape}")
    print(f" Force data shape: {force.shape}")
    print(f" Pose data shape: {pose.shape}")
    print(f" Scores (food physical property) shape: {scores.shape}")

    print("\nInitializing model...")
    model = SAVORNet(seq_length=seq_len)

    print("\nRunning forward pass...")
    model.eval()
    with torch.no_grad():
        outputs = model(rgb, depth, force, pose)
    print(f" Output shape: {outputs.shape}")
    print(f" Sample predictions: {outputs[0, 0, :].numpy()}")

    print("\nCalculating loss (food physical property)...")
    criterion = torch.nn.CrossEntropyLoss()
    # Keep only the three property scores (softness, moisture, viscosity)
    # and shift the 1-5 Likert labels to 0-4 class indices.
    labels = (scores.squeeze(-1).squeeze(-1)[:, :3] - 1).long()
    # Broadcast the per-sample labels over every timestep: [batch, seq, 3].
    labels = labels.unsqueeze(1).repeat(1, seq_len, 1)

    # Average the cross-entropy over the three food physical properties.
    # Logits are flattened to [batch*seq, 5] against [batch*seq] targets.
    loss = sum(
        criterion(outputs[:, :, prop, :].view(-1, 5), labels[:, :, prop].view(-1))
        for prop in range(3)
    ) / 3
    print(f" Cross-Entropy Loss: {loss.item():.4f}")
    print("\nDemonstration completed successfully!")
def demonstrate_training_loop():
    """Demonstrate a simple training loop.

    Runs two optimization steps of SAVORNet on randomly generated sample
    data, averaging a per-property cross-entropy loss over the three food
    physical properties (softness, moisture, viscosity).
    """
    print("\nTraining Loop Demonstration")
    print("=" * 50)

    # Create sample data with smaller batch size for memory efficiency.
    SEQ_LENGTH = 10
    rgb_images, depth_images, force_data, pose_data, scores = create_sample_data(batch_size=1, seq_length=SEQ_LENGTH)

    # Initialize model and optimizer.
    print("Initializing model and optimizer...")
    model = SAVORNet(seq_length=SEQ_LENGTH)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # The targets never change between epochs, so build them once up front
    # (the original recomputed this inside the loop every epoch).
    scores_expanded = scores.squeeze(-1).squeeze(-1)[:, :3]  # [batch_size, 3]
    # Convert 1-5 Likert values to 0-4 class labels for cross-entropy.
    class_labels = (scores_expanded - 1).long()
    class_labels = class_labels.unsqueeze(1).repeat(1, SEQ_LENGTH, 1)  # [batch_size, seq_length, 3]

    # Training loop.
    model.train()
    for epoch in range(2):  # Just 2 epochs for demo
        optimizer.zero_grad()

        # Forward pass.
        outputs = model(rgb_images, depth_images, force_data, pose_data)

        # Calculate loss for each food physical property and average.
        loss = 0.0
        for attr_idx in range(3):
            attr_output = outputs[:, :, attr_idx, :]  # [batch_size, seq_length, 5]
            attr_labels = class_labels[:, :, attr_idx]  # [batch_size, seq_length]
            # Flatten for cross-entropy: [batch*seq, 5] logits vs [batch*seq] targets.
            loss += criterion(attr_output.view(-1, 5), attr_labels.view(-1))
        loss = loss / 3

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        print(f" Epoch {epoch+1}/2, Loss: {loss.item():.4f}")

    # Clear the CUDA cache to free memory (no-op on CPU-only machines).
    # A plain `if` replaces the original ternary-as-statement idiom.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Training demonstration completed!")
def demonstrate_data_loading():
    """Demonstrate data loading (requires actual data).

    Builds a SavorDataProcessor over the RLDS dataset under ./data and
    loads one batch through a DataLoader to show tensor shapes. Skips
    gracefully when the processor module or the data files are missing.
    """
    # Header fixed: "RLSA" was a typo for the RLDS dataset named below.
    print("\nRLDS Data Loading Demonstration")
    print("=" * 50)

    # Note: This requires actual data files.
    print("Note: This demonstration requires actual data files or RLDS dataset.")
    print(" Please ensure you have the following structure:")
    print(" data/")
    print(" └── savor_rlds/")
    print(" ├── 1.0.0/")
    print(" ├── dataset_info.json")
    print(" ├── features.json")
    print(" └── savor_rlds-train.tfrecord-00000-of-00001")
    try:
        # The processor module is optional; catch only the import failure.
        # (The original bare `except:` would also have swallowed
        # KeyboardInterrupt and SystemExit.)
        try:
            from dataset_processor import SavorDataProcessor
        except ImportError:
            print(" SAVOR SavorDataProcessor not found. Skipping data loading demo.")
            return
        processor = SavorDataProcessor(
            data_dir="./data",
            batch_size=2,
            sequence_length=5,
            max_episodes=3,  # Only process 3 episodes for testing
            augment=False
        )
        print(f" Dataset size: {len(processor.get_data())}")
        if len(processor.get_data()) > 0:
            # Create dataloader.
            dataloader = DataLoader(processor.get_data(), batch_size=2, shuffle=True)
            print(f" Number of batches: {len(dataloader)}")
            # Load one batch to show the tensor shapes, then stop.
            for images, depths, forces, poses, scores in dataloader:
                print(f" Batch shapes:")
                print(f" Images: {images.shape}")
                print(f" Depths: {depths.shape}")
                print(f" Forces: {forces.shape}")
                print(f" Poses: {poses.shape}")
                print(f" Scores (food physical property): {scores.shape}")
                break
        else:
            print(" No data available. Skipping data loading demo.")
    except FileNotFoundError:
        print(" Data files not found. Skipping data loading demo.")
        print(" Please add your data files to the 'data/' directory.")
def demonstrate_rawdata_loading():
    """Check that the expected raw-data folders exist under ./data and
    report how many files each one contains."""
    print("\nRaw Data Loading Demonstration")
    print("=" * 50)
    print("Checking for raw data folders...")
    print(" Expected structure:")
    print(" data/")
    print(" ├── subject1_food_rgb/")
    print(" ├── subject1_food_depth/")
    print(" ├── subject1_food_force/")
    print(" └── subject1_food_pose/")
    try:
        # Bail out early when the data directory itself is absent.
        if not os.path.exists("./data"):
            print(" Data directory not found.")
            print(" Please create the data directory and add your raw data folders.")
            return

        # One folder per sensor modality.
        expected = (
            "subject1_food_rgb",
            "subject1_food_depth",
            "subject1_food_force",
            "subject1_food_pose",
        )
        print("\n[RAW DATA] Checking raw data folders:")
        missing_any = False
        for name in expected:
            path = os.path.join("./data", name)
            if os.path.exists(path):
                # Report the number of regular files in the folder.
                count = sum(
                    1 for entry in os.listdir(path)
                    if os.path.isfile(os.path.join(path, entry))
                )
                print(f"{name}/ - Found ({count} files)")
            else:
                print(f"{name}/ - Missing")
                missing_any = True

        if missing_any:
            print("\nRaw data folders are missing.")
            print("You can create the RLDS dataset from raw data using or write your own dataprocessor in self._load_raw_data().")
            return
        print("\nAll raw data folders found!")
    except Exception as e:
        print(f"[Error]Unexpected error: {e}")
        print("Please check your data setup.")
def main():
    """Run every SAVOR demonstration in sequence."""
    print("SAVOR: Skill Affordance Learning from Visuo-Haptic Perception")
    print(" for Robot-Assisted Bite Acquisition")
    print("=" * 70)
    # Execute the demos in their intended order.
    for demo in (
        demonstrate_model_usage,
        demonstrate_training_loop,
        demonstrate_data_loading,
        demonstrate_rawdata_loading,
    ):
        demo()
    print("\nAll demonstrations completed!")
# Entry point: run all demonstrations when executed as a script.
if __name__ == "__main__":
    main()