Commit 2584b68

Fix conv_net.load_weights #956
1 parent 290c162 commit 2584b68

File tree: 3 files changed, +183 -1 lines changed

src/TensorFlowNET.Keras/Engine/Layer.Layers.cs

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ public partial class Layer

     protected void StackLayers(params ILayer[] layers)
     {
-        _layers.AddRange(layers);
+        _self_tracked_trackables.AddRange(layers);
     }

     public virtual Shape ComputeOutputShape(Shape input_shape)
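Note: `load_weights` can only restore variables that the layer-tracking machinery actually sees, and this one-line change re-points `StackLayers` at `_self_tracked_trackables`, which appears to be the list the rest of `Layer` walks when collecting sub-layers and their weights; layers appended to the old `_layers` field were invisible to weight restoration. The field name mirrors upstream Python Keras, where attribute-assigned sub-layers land in the base `Layer`'s private `_self_tracked_trackables`. A minimal Python sketch of that tracking behavior, shown only as an analogy to the C# port (standard `tf.keras`; nothing here is part of this repo):

    import tensorflow as tf

    class Block(tf.keras.layers.Layer):
        def __init__(self):
            super().__init__()
            # Attribute assignment registers the sub-layer in the base Layer's
            # private _self_tracked_trackables list, so its variables are
            # discoverable by block.weights and by checkpoint restoration.
            self.dense = tf.keras.layers.Dense(8)

    block = Block()
    _ = block(tf.zeros((1, 4)))             # first call builds the tracked sub-layer
    print([w.name for w in block.weights])  # kernel and bias from the tracked Dense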
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+using System;
+using System.Linq;
+using Tensorflow.Graphs;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Optimizers;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Model
+    {
+        public override void build(Shape input_shape)
+        {
+            var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
+
+            graph.as_default();
+
+            var x = tf.placeholder(DType, input_shape);
+            Call(x, training: false);
+
+            graph.Exit();
+
+            base.build(input_shape);
+        }
+    }
+}
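This override looks like a direct port of what Python `tf.keras.Model.build` does for subclassed models: trace `call` once on a placeholder inside a temporary graph so that every layer creates its variables, after which `load_weights` has variables to assign into. A minimal Python sketch of that contract (standard `tf.keras`; the weights file is hypothetical):

    import tensorflow as tf

    class Tiny(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(3)

        def call(self, x, training=False):
            return self.dense(x)

    m = Tiny()
    print(len(m.weights))  # 0 -- a freshly constructed subclassed model has no variables
    m.build((None, 4))     # traces call() on a symbolic input, creating the variables
    print(len(m.weights))  # 2 -- kernel and bias now exist
    # m.load_weights('weights.h5')  # now possible: there are variables to assign into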

src/python/subclassing.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+from tensorflow.keras import Model, layers
+import numpy as np
+
+# MNIST dataset parameters.
+num_classes = 10  # total classes (0-9 digits).
+
+# Training parameters.
+learning_rate = 0.001
+training_steps = 100
+batch_size = 128
+display_step = 10
+
+# Network parameters.
+conv1_filters = 32  # number of filters for 1st conv layer.
+conv2_filters = 64  # number of filters for 2nd conv layer.
+fc1_units = 1024  # number of neurons for 1st fully-connected layer.
+
+# Prepare MNIST data.
+from tensorflow.keras.datasets import mnist
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+# Convert to float32.
+x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
+# Normalize image values from [0, 255] to [0, 1].
+x_train, x_test = x_train / 255., x_test / 255.
+
+# Use tf.data API to shuffle and batch data.
+train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)
+
+# Create TF Model.
+class ConvNet(Model):
+    # Set layers.
+    def __init__(self):
+        super(ConvNet, self).__init__()
+        # Convolution layer with 32 filters and a kernel size of 5.
+        self.conv1 = layers.Conv2D(conv1_filters, kernel_size=5, activation=tf.nn.relu)
+        # Max pooling (down-sampling) with kernel size of 2 and strides of 2.
+        self.maxpool1 = layers.MaxPool2D(2, strides=2)
+
+        # Convolution layer with 64 filters and a kernel size of 3.
+        self.conv2 = layers.Conv2D(conv2_filters, kernel_size=3, activation=tf.nn.relu)
+        # Max pooling (down-sampling) with kernel size of 2 and strides of 2.
+        self.maxpool2 = layers.MaxPool2D(2, strides=2)
+
+        # Flatten the data to a 1-D vector for the fully connected layer.
+        self.flatten = layers.Flatten()
+
+        # Fully connected layer.
+        self.fc1 = layers.Dense(fc1_units)
+        # Apply dropout (if is_training is False, dropout is not applied).
+        self.dropout = layers.Dropout(rate=0.5)
+
+        # Output layer, class prediction.
+        self.out = layers.Dense(num_classes)
+
+    # Set forward pass.
+    def call(self, x, is_training=False):
+        x = tf.reshape(x, [-1, 28, 28, 1])
+        x = self.conv1(x)
+        x = self.maxpool1(x)
+        x = self.conv2(x)
+        x = self.maxpool2(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.dropout(x, training=is_training)
+        x = self.out(x)
+        if not is_training:
+            # TF cross-entropy expects logits without softmax, so only
+            # apply softmax when not training.
+            x = tf.nn.softmax(x)
+        return x
+
+'''
+# Build neural network model.
+conv_net = ConvNet()
+
+# Cross-entropy loss.
+# Note that this will apply 'softmax' to the logits.
+def cross_entropy_loss(x, y):
+    # Convert labels to int64 for the TF cross-entropy function.
+    y = tf.cast(y, tf.int64)
+    # Apply softmax to logits and compute cross-entropy.
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
+    # Average loss across the batch.
+    return tf.reduce_mean(loss)
+
+# Accuracy metric.
+def accuracy(y_pred, y_true):
+    # Predicted class is the index of the highest score in the prediction vector (i.e. argmax).
+    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
+    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)
+
+# Adam optimizer.
+optimizer = tf.optimizers.Adam(learning_rate)
+
+# Optimization process.
+def run_optimization(x, y):
+    # Wrap computation inside a GradientTape for automatic differentiation.
+    with tf.GradientTape() as g:
+        # Forward pass.
+        pred = conv_net(x, is_training=True)
+        # Compute loss.
+        loss = cross_entropy_loss(pred, y)
+
+    # Variables to update, i.e. trainable variables.
+    trainable_variables = conv_net.trainable_variables
+
+    # Compute gradients.
+    gradients = g.gradient(loss, trainable_variables)
+
+    # Update W and b following gradients.
+    optimizer.apply_gradients(zip(gradients, trainable_variables))
+
+# Run training for the given number of steps.
+for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
+    # Run the optimization to update W and b values.
+    run_optimization(batch_x, batch_y)
+
+    if step % display_step == 0:
+        pred = conv_net(batch_x)
+        loss = cross_entropy_loss(pred, batch_y)
+        acc = accuracy(pred, batch_y)
+        print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))
+
+# Test model on validation set.
+pred = conv_net(x_test)
+print("Test Accuracy: %f" % accuracy(pred, y_test))
+
+conv_net.save_weights('weights.h5')
+'''
+
+conv_net = ConvNet()
+conv_net.build(x_test.shape)
+conv_net.load_weights('weights.h5')
+# Test model on validation set.
+pred = conv_net(x_test)
+# print("Test Accuracy: %f" % accuracy(pred, y_test))
+
+# Visualize predictions.
+import matplotlib.pyplot as plt
+
+# Predict 5 images from validation set.
+n_images = 5
+test_images = x_test[:n_images]
+predictions = conv_net(test_images)
+
+# Display image and model prediction.
+for i in range(n_images):
+    plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')
+    plt.show()
+    print("Model prediction: %i" % np.argmax(predictions.numpy()[i]))
