Skip to content

Commit 9c161b1

Browse files
committed
_ListFetchMapper for multiple fetch in Operation and Tensor.
1 parent 86986d8 commit 9c161b1

File tree

6 files changed

+86
-24
lines changed

6 files changed

+86
-24
lines changed

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ namespace Tensorflow
1010
/// </summary>
1111
public class _ElementFetchMapper : _FetchMapper
1212
{
13-
private List<object> _unique_fetches = new List<object>();
1413
private Func<List<object>, object> _contraction_fn;
1514

1615
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
@@ -32,7 +31,7 @@ public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contract
3231
/// </summary>
3332
/// <param name="values"></param>
3433
/// <returns></returns>
35-
public NDArray build_results(List<object> values)
34+
public override NDArray build_results(List<object> values)
3635
{
3736
NDArray result = null;
3837

@@ -51,10 +50,5 @@ public NDArray build_results(List<object> values)
5150

5251
return result;
5352
}
54-
55-
public List<object> unique_fetches()
56-
{
57-
return _unique_fetches;
58-
}
5953
}
6054
}

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ namespace Tensorflow
1010
/// </summary>
1111
public class _FetchHandler
1212
{
13-
private _ElementFetchMapper _fetch_mapper;
13+
private _FetchMapper _fetch_mapper;
1414
private List<Tensor> _fetches = new List<Tensor>();
1515
private List<bool> _ops = new List<bool>();
1616
private List<Tensor> _final_fetches = new List<Tensor>();
1717
private List<object> _targets = new List<object>();
1818

1919
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
2020
{
21-
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
21+
_fetch_mapper = _FetchMapper.for_fetch(fetches);
2222
foreach(var fetch in _fetch_mapper.unique_fetches())
2323
{
2424
switch (fetch)
@@ -58,7 +58,18 @@ public NDArray build_results(BaseSession session, NDArray[] tensor_values)
5858
{
5959
var value = tensor_values[j];
6060
j += 1;
61-
full_values.Add(value);
61+
switch (value.dtype.Name)
62+
{
63+
case "Int32":
64+
full_values.Add(value.Data<int>(0));
65+
break;
66+
case "Single":
67+
full_values.Add(value.Data<float>(0));
68+
break;
69+
case "Double":
70+
full_values.Add(value.Data<double>(0));
71+
break;
72+
}
6273
}
6374
i += 1;
6475
}
Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45

56
namespace Tensorflow
67
{
78
public class _FetchMapper
89
{
9-
public _ElementFetchMapper for_fetch(object fetch)
10+
protected List<object> _unique_fetches = new List<object>();
11+
12+
public static _FetchMapper for_fetch(object fetch)
1013
{
11-
var fetches = new object[] { fetch };
14+
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };
15+
16+
if (fetch.GetType().IsArray)
17+
return new _ListFetchMapper(fetches);
18+
else
19+
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]);
20+
}
1221

13-
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) =>
14-
{
15-
return fetched_vals[0];
16-
});
22+
public virtual NDArray build_results(List<object> values)
23+
{
24+
return values.ToArray();
25+
}
26+
27+
public virtual List<object> unique_fetches()
28+
{
29+
return _unique_fetches;
1730
}
1831
}
1932
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public class _ListFetchMapper : _FetchMapper
9+
{
10+
private _FetchMapper[] _mappers;
11+
12+
public _ListFetchMapper(object[] fetches)
13+
{
14+
_mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray();
15+
_unique_fetches.AddRange(fetches);
16+
}
17+
}
18+
}

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,7 @@ private void PrepareData()
4040
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax
4141

4242
// Minimize error using cross entropy
43-
var log = tf.log(pred);
44-
var mul = y * log;
45-
var sum = tf.reduce_sum(mul, reduction_indices: 1);
46-
var neg = -sum;
47-
var cost = tf.reduce_mean(neg);
43+
var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1));
4844

4945
// Gradient Descent
5046
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
@@ -68,14 +64,23 @@ private void PrepareData()
6864
{
6965
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
7066
// Run optimization op (backprop) and cost op (to get loss value)
71-
var (_, c) = sess.run(optimizer,
67+
var result = sess.run(new object[] { optimizer, cost },
7268
new FeedItem(x, batch_xs),
7369
new FeedItem(y, batch_ys));
7470

71+
var c = (float)result[1];
7572
// Compute average loss
7673
avg_cost += c / total_batch;
7774
}
75+
76+
// Display logs per epoch step
77+
if ((epoch + 1) % display_step == 0)
78+
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
7879
}
80+
81+
print("Optimization Finished!");
82+
83+
// Test model
7984
});
8085
}
8186
}

test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,28 @@ public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
5252
// Finished epoch
5353
_epochs_completed += 1;
5454

55-
throw new NotImplementedException("next_batch");
55+
// Get the rest examples in this epoch
56+
var rest_num_examples = _num_examples - start;
57+
var images_rest_part = _images[np.arange(start, _num_examples)];
58+
var labels_rest_part = _labels[np.arange(start, _num_examples)];
59+
// Shuffle the data
60+
if (shuffle)
61+
{
62+
var perm = np.arange(_num_examples);
63+
np.random.shuffle(perm);
64+
_images = images[perm];
65+
_labels = labels[perm];
66+
}
67+
68+
start = 0;
69+
_index_in_epoch = batch_size - rest_num_examples;
70+
var end = _index_in_epoch;
71+
var images_new_part = _images[np.arange(start, end)];
72+
var labels_new_part = _labels[np.arange(start, end)];
73+
74+
/*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0),
75+
np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/
76+
return (images_new_part, labels_new_part);
5677
}
5778
else
5879
{

0 commit comments

Comments
 (0)