-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy path mnist.js
More file actions
28 lines (24 loc) · 708 Bytes
/
mnist.js
File metadata and controls
28 lines (24 loc) · 708 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// Trains a small fully-connected classifier on MNIST through the Keras-style
// API exposed by this package. It evaluates the model on the test split and
// saves the trained model as mnist.h5 in the same directory as this script.
//
// NOTE(review): fit/evaluate/save are called without awaiting. This assumes
// the binding runs them synchronously — confirm. If they return Promises,
// the calls below must be awaited in order.
const path = require('path');
const tf = require('../');

// NOTE(review): the TensorFlow quickstart rescales pixels to [0, 1] before
// training. Confirm whether dataset.train.x / dataset.test.x are already
// normalized by this loader.
const dataset = tf.keras.datasets.mnist();

// Layer stack:
// - Flatten each 28x28 input.
// - One 128-unit ReLU hidden layer, followed by dropout.
// - A 10-unit output layer with no activation option; the loss below
//   treats its outputs as logits.
const model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten({ input_shape: [28, 28] }),
  tf.keras.layers.Dense(128, { activation: 'relu' }),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10),
]);
model.summary();

// from_logits: true because the final Dense layer is given no activation.
const lossFn = tf.keras.losses.SparseCategoricalCrossentropy({ from_logits: true });
model.compile({
  optimizer: 'adam',
  loss: lossFn,
  metrics: ['accuracy'],
});
console.log('compiled model');

model.fit(dataset.train.x, dataset.train.y, { epochs: 5 });
console.log('train done');

model.evaluate(dataset.test.x, dataset.test.y, { verbose: 2 });

// Build the output path with path.join instead of string concatenation so
// the platform's separator and normalization rules are applied.
model.save(path.join(__dirname, 'mnist.h5'));