Skip to content

Commit 0e68fa1

Browse files
committed
follow comments
1 parent 00f15ba commit 0e68fa1

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

tests/base.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ def train_input_fn(features, labels, batch_size=32):
66
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
77
return dataset
88

9-
10-
def prepare_dataset():
11-
pass
12-
9+
def eval_input_fn(features, labels, batch_size=32):
10+
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
11+
dataset = dataset.batch(batch_size)
12+
return dataset
1313

1414
class BaseTestCases:
1515
class BaseTest(object):
@@ -18,8 +18,14 @@ def setUp(self):
1818

1919
def test_train_and_predict(self):
2020
self.setUp()
21-
train_input_fn(self.features, self.label)
2221

23-
print('Calling BaseTest:testCommon')
24-
value = 5
25-
assert(value == 5)
22+
self.model.compile(optimizer=self.model.default_optimizer(),
23+
loss=self.model.default_loss(),
24+
metrics=["accuracy"])
25+
self.model.fit(train_input_fn(self.features, self.label),
26+
epochs=self.model.default_training_epochs(),
27+
steps_per_epoch=100, verbose=0)
28+
loss, acc = self.model.evaluate(eval_input_fn(self.features, self.label))
29+
print(loss, acc)
30+
assert(loss < 10)
31+
assert(acc > 0.3)

tests/test_dnnclassifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def setUp(self):
1111
"c2": [float(x) for x in range(100)],
1212
"c3": [float(x) for x in range(100)],
1313
"c4": [float(x) for x in range(100)]}
14-
self.label = {"label": [0 for _ in range(50)] + [1 for _ in range(50)]}
14+
self.label = [0 for _ in range(50)] + [1 for _ in range(50)]
1515
feature_columns = [tf.feature_column.numeric_column(key) for key in
1616
self.features]
1717
self.model = sqlflow_models.DNNClassifier(feature_columns=feature_columns)

0 commit comments

Comments
 (0)