Skip to content

Commit 4f794e8

Browse files
AsakusaRinneOceania2018
authored andcommitted
Refine the resnet example.
1 parent 69889a3 commit 4f794e8

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ using static Tensorflow.KerasApi;
131131
using Tensorflow;
132132
using Tensorflow.NumPy;
133133

134-
var layers = new LayersApi();
134+
var layers = keras.layers;
135135
// input layer
136136
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
137137
// convolutional layer
@@ -155,17 +155,19 @@ var model = keras.Model(inputs, outputs, name: "toy_resnet");
155155
model.summary();
156156
// compile keras model in tensorflow static graph
157157
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
158-
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
158+
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
159159
metrics: new[] { "acc" });
160160
// prepare dataset
161161
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
162+
// normalize the input
162163
x_train = x_train / 255.0f;
163-
y_train = np_utils.to_categorical(y_train, 10);
164164
// training
165165
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
166-
batch_size: 64,
167-
epochs: 10,
168-
validation_split: 0.2f);
166+
batch_size: 64,
167+
epochs: 10,
168+
validation_split: 0.2f);
169+
// save the model
170+
model.save("./toy_resnet_model");
169171
```
170172

171173
The F# example for linear regression is available [here](docs/Example-fsharp.md).

docs/README-CN.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ using static Tensorflow.KerasApi;
130130
using Tensorflow;
131131
using Tensorflow.NumPy;
132132

133-
var layers = new LayersApi();
133+
var layers = keras.layers;
134134
// input layer
135135
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
136136
// convolutional layer
@@ -154,17 +154,19 @@ var model = keras.Model(inputs, outputs, name: "toy_resnet");
154154
model.summary();
155155
// compile keras model in tensorflow static graph
156156
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
157-
loss: keras.losses.CategoricalCrossentropy(from_logits: true),
157+
loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
158158
metrics: new[] { "acc" });
159159
// prepare dataset
160160
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
161+
// normalize the input
161162
x_train = x_train / 255.0f;
162-
y_train = np_utils.to_categorical(y_train, 10);
163163
// training
164164
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
165-
batch_size: 64,
166-
epochs: 10,
167-
validation_split: 0.2f);
165+
batch_size: 64,
166+
epochs: 10,
167+
validation_split: 0.2f);
168+
// save the model
169+
model.save("./toy_resnet_model");
168170
```
169171

170172
此外,Tensorflow.NET也支持用F#搭建上述模型进行训练和推理。

0 commit comments

Comments
 (0)