-
Notifications
You must be signed in to change notification settings - Fork 56
Expand file tree
/
Copy pathForwardForward.cs
More file actions
111 lines (89 loc) · 4.55 KB
/
ForwardForward.cs
File metadata and controls
111 lines (89 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.IO;
using System.Collections.Generic;
using System.Diagnostics;
using TorchSharp;
using static TorchSharp.torchvision;
using TorchSharp.Examples;
using TorchSharp.Examples.Utils;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;
namespace CSharpExamples
{
/// <summary>
/// Forward-Forward MNIST classification
///
/// Based on: https://github.com/pytorch/examples/tree/main/mnist_forward_forward
///
/// Implements the Forward-Forward algorithm (Geoffrey Hinton, 2022). Instead of
/// backpropagation, each layer is trained independently using a local contrastive loss.
/// Positive examples have the correct label overlaid, negative examples have wrong labels.
/// </summary>
public class ForwardForward
{
    // Gzip archives expected in the MNIST dataset directory; decompressed on first run.
    private static readonly string[] MnistArchives = {
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    };

    /// <summary>
    /// Trains a Forward-Forward network ([784, 500, 500]) on MNIST and prints the
    /// train and test error rates plus total elapsed time.
    /// </summary>
    /// <param name="epochs">Number of training epochs passed to the network's Train method.</param>
    /// <param name="timeout">Currently unused by this example.</param>
    /// <param name="logdir">Currently unused by this example.</param>
    /// <exception cref="InvalidOperationException">A data reader produced no batches.</exception>
    internal static void Run(int epochs, int timeout, string logdir)
    {
        var device =
            torch.cuda.is_available() ? torch.CUDA :
            torch.mps_is_available() ? torch.MPS :
            torch.CPU;

        Console.WriteLine();
        Console.WriteLine($"\tRunning Forward-Forward MNIST on {device.type} for {epochs} epochs.");
        Console.WriteLine();

        torch.random.manual_seed(1);

        var dataset = "mnist";
        // <Desktop>/../Downloads/mnist -- normally the user's Downloads folder.
        var datasetPath = Path.Join(Environment.GetFolderPath(Environment.SpecialFolder.DesktopDirectory), "..", "Downloads", dataset);
        var targetDir = EnsureDecompressed(datasetPath, Path.Combine(datasetPath, "test_data"));

        Console.WriteLine($"\tLoading data...");

        // Forward-Forward trains on one large batch. The third MNISTReader argument is presumably
        // the batch size, so the first batch holds the first trainSize samples -- 50,000 of the
        // training images, not the full training set.
        int trainSize = 50000;
        int testSize = 10000;

        using (MNISTReader trainReader = new MNISTReader(targetDir, "train", trainSize, device: device),
               testReader = new MNISTReader(targetDir, "t10k", testSize, device: device))
        {
            Stopwatch totalTime = Stopwatch.StartNew();

            var (x, y) = FirstFlattenedBatch(trainReader, "train");
            var (xTe, yTe) = FirstFlattenedBatch(testReader, "t10k");

            Console.WriteLine($"\tCreating Forward-Forward network [784, 500, 500]...");
            var net = new ForwardForwardNet(new int[] { 784, 500, 500 }, device);

            // Positive examples carry the correct label overlaid; negative examples carry a wrong one.
            var xPos = ForwardForwardNet.OverlayLabelOnInput(x, y);
            var yNeg = ForwardForwardNet.GetNegativeLabels(y);
            var xNeg = ForwardForwardNet.OverlayLabelOnInput(x, yNeg);

            Console.WriteLine($"\tTraining...");
            net.Train(xPos, xNeg, epochs, lr: 0.03, logInterval: 10);

            var trainError = ErrorRate(net.Predict(x), y);
            Console.WriteLine($"\tTrain error: {trainError:F4}");

            var testError = ErrorRate(net.Predict(xTe), yTe);
            Console.WriteLine($"\tTest error: {testError:F4}");

            totalTime.Stop();
            Console.WriteLine($"Elapsed time: {totalTime.Elapsed.TotalSeconds:F1} s.");
        }
    }

    // Decompresses the MNIST archives from sourceDir into targetDir unless targetDir already exists.
    // On failure the freshly created targetDir is removed, so the next run retries decompression
    // instead of skipping it because the directory is present.
    private static string EnsureDecompressed(string sourceDir, string targetDir)
    {
        if (Directory.Exists(targetDir)) {
            return targetDir;
        }

        Directory.CreateDirectory(targetDir);
        try {
            foreach (var archive in MnistArchives) {
                Decompress.DecompressGZipFile(Path.Combine(sourceDir, archive), targetDir);
            }
        } catch {
            try {
                Directory.Delete(targetDir, recursive: true);
            } catch (IOException) {
                // Best effort cleanup: report the original decompression failure instead.
            } catch (UnauthorizedAccessException) {
                // Best effort cleanup: report the original decompression failure instead.
            }
            throw;
        }
        return targetDir;
    }

    // Returns the first batch from reader, with images flattened from (N, 1, 28, 28) to (N, 784).
    private static (Tensor Data, Tensor Labels) FirstFlattenedBatch(MNISTReader reader, string split)
    {
        foreach (var (data, target) in reader) {
            return (data.view(data.shape[0], -1), target);
        }
        throw new InvalidOperationException($"The MNIST '{split}' reader produced no batches; check the dataset files.");
    }

    // Fraction of entries where pred != labels (1 - accuracy); disposes the intermediate tensors.
    private static float ErrorRate(Tensor pred, Tensor labels)
    {
        using (var matches = pred.eq(labels))
        using (var matchesAsFloat = matches.to_type(ScalarType.Float32))
        using (var accuracy = matchesAsFloat.mean()) {
            return 1.0f - accuracy.item<float>();
        }
    }
}
}