Skip to content

Commit c6df3d8

Browse files
committed
generate gif in cifar10 evaluation
1 parent 9069a1d commit c6df3d8

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

tasks/class_cifar10/eval_class_cifar10.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@
2525

2626

2727
from ncalab import ( # noqa: E402
28+
Animator,
2829
ClassificationNCAModel,
2930
fix_random_seed,
3031
get_compute_device,
3132
)
3233

34+
FIGURE_PATH = TASK_PATH / "figures"
35+
FIGURE_PATH.mkdir(exist_ok=True)
36+
3337
T = transforms.Compose(
3438
[
3539
v2.ToImage(),
@@ -52,7 +56,7 @@ def eval_class_cifar10(
5256
root=TASK_PATH / "data", train=False, download=True, transform=T
5357
)
5458
loader_test = torch.utils.data.DataLoader(
55-
testset, batch_size=128, shuffle=False, num_workers=2
59+
testset, batch_size=8, shuffle=False, num_workers=2
5660
)
5761

5862
class_names = [
@@ -125,6 +129,12 @@ def eval_class_cifar10(
125129
macro_auc_ = macro_auc.compute().item()
126130
micro_f1_ = micro_f1.compute().item()
127131

132+
seed = next(iter(loader_test))[0].to(device)
133+
out_path = FIGURE_PATH / "classification_cifar10.gif"
134+
animator = Animator(nca, seed, interval=100, steps=42, hidden=True, show_input=True)
135+
animator.save(out_path)
136+
print(f"Animation saved to: {out_path}")
137+
128138
print()
129139
print(
130140
f"ACC macro: {macro_acc_:.3f} ACC micro: {micro_acc_:.3f} AUC macro: {macro_auc_:.3f} F1 micro: {micro_f1_:.3f}"

0 commit comments

Comments
 (0)