Skip to content

Commit 0ee9d42

Browse files
authored
Merge pull request #1187 from Wanglongzhi2001/master
feat: add the implementation of sample_weight in model.fit
2 parents 15763df + f5af07c commit 0ee9d42

File tree

13 files changed

+250
-100
lines changed

13 files changed

+250
-100
lines changed

src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Tensorflow.Keras.Engine;
22
using Tensorflow.Keras.Saving;
3+
using Tensorflow.NumPy;
34

45
namespace Tensorflow.Keras.ArgsDefinition
56
{
@@ -16,5 +17,7 @@ public class DataAdapterArgs: IKerasConfig
1617
public int Worker { get; set; }
1718
public bool UseMultiprocessing { get; set; }
1819
public IModel Model { get; set; }
20+
public Dictionary<int, float> ClassWeight = null;
21+
public NDArray SampleWeight = null;
1922
}
2023
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Tensorflow.Keras.Engine;
22
using Tensorflow.Keras.Saving;
3+
using Tensorflow.NumPy;
34

45
namespace Tensorflow.Keras.ArgsDefinition
56
{
@@ -18,5 +19,7 @@ public class DataHandlerArgs: IKerasConfig
1819
public bool UseMultiprocessing { get; set; } = false;
1920
public IModel Model { get; set; }
2021
public IVariableV1 StepsPerExecution { get; set; }
22+
public Dictionary<int, float> ClassWeight = null;
23+
public NDArray SampleWeight = null;
2124
}
2225
}

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Tensorflow.Keras.Metrics;
44
using Tensorflow.Keras.Saving;
55
using Tensorflow.NumPy;
6+
using Tensorflow.Util;
67

78
namespace Tensorflow.Keras.Engine;
89

@@ -22,8 +23,10 @@ ICallback fit(NDArray x, NDArray y,
2223
int verbose = 1,
2324
List<ICallback> callbacks = null,
2425
float validation_split = 0f,
25-
(NDArray val_x, NDArray val_y)? validation_data = null,
26+
ValidationDataPack validation_data = null,
2627
bool shuffle = true,
28+
Dictionary<int, float> class_weight = null,
29+
NDArray sample_weight = null,
2730
int initial_epoch = 0,
2831
int max_queue_size = 10,
2932
int workers = 1,
@@ -35,8 +38,10 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
3538
int verbose = 1,
3639
List<ICallback> callbacks = null,
3740
float validation_split = 0f,
38-
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
41+
ValidationDataPack validation_data = null,
3942
bool shuffle = true,
43+
Dictionary<int, float> class_weight = null,
44+
NDArray sample_weight = null,
4045
int initial_epoch = 0,
4146
int max_queue_size = 10,
4247
int workers = 1,
@@ -63,6 +68,8 @@ void load_weights(string filepath,
6368
Dictionary<string, float> evaluate(NDArray x, NDArray y,
6469
int batch_size = -1,
6570
int verbose = 1,
71+
NDArray sample_weight = null,
72+
6673
int steps = -1,
6774
int max_queue_size = 10,
6875
int workers = 1,

src/TensorFlowNET.Core/Util/Data.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using System.Collections.Generic;
using System.Linq;
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
    /// <summary>
    /// ValidationDataPack is used to pass validation data to the fit method.
    /// It can receive data which could be a tuple `(x_val, y_val)` or
    /// `(x_val, y_val, sample_weight_val)` of Numpy arrays.
    /// </summary>
    public class ValidationDataPack
    {
        public NDArray val_x;
        public NDArray val_y;
        // Null when the caller supplied no per-sample validation weights.
        public NDArray val_sample_weight = null;

        public ValidationDataPack((NDArray, NDArray) validation_data)
        {
            this.val_x = validation_data.Item1;
            this.val_y = validation_data.Item2;
        }

        public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
        {
            this.val_x = validation_data.Item1;
            this.val_y = validation_data.Item2;
            this.val_sample_weight = validation_data.Item3;
        }

        // NOTE(review): only the first array of a multi-input x is kept —
        // confirm that collapsing multi-input validation data is intentional.
        public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
        {
            this.val_x = validation_data.Item1.First();
            this.val_y = validation_data.Item2;
        }

        public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
        {
            this.val_x = validation_data.Item1.First();
            this.val_y = validation_data.Item2;
            this.val_sample_weight = validation_data.Item3;
        }

        // Implicit conversions let callers pass plain tuples wherever a
        // ValidationDataPack parameter is expected (e.g. model.fit).
        public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
            => new ValidationDataPack(validation_data);

        public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
            => new ValidationDataPack(validation_data);

        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
            => new ValidationDataPack(validation_data);

        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
            => new ValidationDataPack(validation_data);

        /// <summary>Deconstructs into (x, y), discarding any sample weights.</summary>
        public void Deconstruct(out NDArray val_x, out NDArray val_y)
        {
            val_x = this.val_x;
            val_y = this.val_y;
        }

        /// <summary>Deconstructs into (x, y, sample_weight); the weight is null when none was supplied.</summary>
        public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
        {
            val_x = this.val_x;
            val_y = this.val_y;
            val_sample_weight = this.val_sample_weight;
        }
    }
}

src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Util;
56

67
namespace Tensorflow.Keras.Engine.DataAdapters
78
{
@@ -34,9 +35,67 @@ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
3435
return (x, y);
3536
}
3637

38+
/// <summary>
/// Expands every rank-1 tensor in x, y and sample_weight to rank 2
/// (shape (n,) -> (n, 1)) so downstream loss/metric code can rely on a
/// trailing feature axis. sample_weight may be null (callers such as
/// test_step pass null when the dataset carries no per-sample weights).
/// </summary>
public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
    for (int i = 0; i < x.Length; i++)
    {
        if (x[i].shape.ndim == 1)
            x[i] = array_ops.expand_dims(x[i], axis: -1);
    }
    for (int i = 0; i < y.Length; i++)
    {
        if (y[i].shape.ndim == 1)
            y[i] = array_ops.expand_dims(y[i], axis: -1);
    }
    // BUG FIX: guard against a null sample_weight — the original iterated
    // unconditionally and threw NullReferenceException when no weights were given.
    if (sample_weight != null)
    {
        for (int i = 0; i < sample_weight.Length; i++)
        {
            if (sample_weight[i].shape.ndim == 1)
                sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
        }
    }
    return (x, y, sample_weight);
}
57+
3758
public virtual bool ShouldRecreateIterator()
3859
{
3960
return true;
4061
}
62+
63+
/// <summary>
/// Splits (x, y, sample_weight) into a leading training portion and a
/// ValidationDataPack holding the trailing `validation_split` fraction.
/// sample_weight may be null; the returned training weight is then null too.
/// </summary>
public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
    var (x, y, sample_weight) = x_y_sample_weight;

    // Number of leading samples kept for training.
    int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
    var train_x = x[new Slice(0, train_count)];
    var train_y = y[new Slice(0, train_count)];

    ValidationDataPack validation_data;
    if (sample_weight == null)
    {
        // No per-sample weights: the pack carries only (x, y).
        validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
    }
    else
    {
        validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
        sample_weight = sample_weight[new Slice(0, train_count)];
    }

    return ((train_x, train_y, sample_weight), validation_data);
}
84+
85+
/// <summary>
/// Multi-input variant: splits each input array in x, plus y and the optional
/// sample_weight, into a training portion and a ValidationDataPack with the
/// trailing `validation_split` fraction.
/// </summary>
public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
    var x = x_y_sample_weight.Item1;
    var y = x_y_sample_weight.Item2;
    var sample_weight = x_y_sample_weight.Item3;

    // Number of leading samples kept for training (based on y's first axis).
    int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
    // Lambda parameter renamed from `x` to avoid shadowing the outer local.
    var train_x = x.Select(arr => arr[new Slice(0, train_count)] as NDArray);
    var train_y = y[new Slice(0, train_count)];
    var val_x = x.Select(arr => arr[new Slice(train_count)] as NDArray);
    var val_y = y[new Slice(train_count)];

    ValidationDataPack validation_data;
    if (sample_weight != null)
    {
        // Capture the validation share of the weights before trimming to the
        // training share, mirroring the single-input overload.
        validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
        sample_weight = sample_weight[new Slice(0, train_count)];
    }
    else
    {
        // BUG FIX: the original sliced sample_weight unconditionally and threw
        // NullReferenceException when no sample weights were provided.
        validation_data = (val_x, val_y);
    }
    return ((train_x, train_y, sample_weight), validation_data);
}
41100
}
42101
}

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using Tensorflow.Keras.ArgsDefinition;
44
using static Tensorflow.Binding;
5+
using Tensorflow.Keras.Utils;
56

67
namespace Tensorflow.Keras.Engine.DataAdapters
78
{
@@ -28,6 +29,7 @@ public class DataHandler
2829
public DataHandler(DataHandlerArgs args)
2930
{
3031
this.args = args;
32+
3133
if (args.StepsPerExecution == null)
3234
{
3335
_steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ public DataHandler(DataHandlerArgs args)
4850
BatchSize = args.BatchSize,
4951
Steps = args.StepsPerEpoch,
5052
Epochs = args.Epochs - args.InitialEpoch,
53+
SampleWeight = args.SampleWeight,
5154
Shuffle = args.Shuffle,
5255
MaxQueueSize = args.MaxQueueSize,
5356
Worker = args.Workers,

src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ public interface IDataAdapter
1717
IDatasetV2 GetDataset();
1818
int GetSize();
1919
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
20+
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
21+
2022
bool ShouldRecreateIterator();
2123
}
2224
}

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
2020
public TensorLikeDataAdapter(DataAdapterArgs args)
2121
{
2222
this.args = args;
23-
_process_tensorlike();
23+
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
2424
num_samples = (int)args.X.shape[0];
2525
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
2626
_batch_size = batch_size;
@@ -37,6 +37,8 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
3737
inputs.AddRange(args.X);
3838
if (args.Y != null)
3939
inputs.AddRange(args.Y);
40+
if (sample_weight_tensor != null)
41+
inputs.Add(sample_weight_tensor);
4042
dataset = slice_inputs(indices_dataset, inputs);
4143
dataset.FirstInputTensorCount = args.X.Length;
4244
}
@@ -94,8 +96,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
9496

9597
public override bool ShouldRecreateIterator() => false;
9698

97-
void _process_tensorlike()
99+
Tensor _process_tensorlike(NDArray sample_weights)
98100
{
101+
return tf.convert_to_tensor(sample_weights);
99102
}
100103
}
101104
}

src/TensorFlowNET.Keras/Engine/LossesContainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ public LossesContainer(ILossFunc losses, string[] output_names = null)
2626
/// </summary>
2727
/// <param name="y_true"></param>
2828
/// <param name="y_pred"></param>
29-
public Tensor Call(Tensor y_true, Tensor y_pred)
29+
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
3030
{
3131
if (!_built)
3232
Build(y_pred);
33-
var loss_value = _losses.Call(y_true, y_pred);
33+
var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight);
3434
var loss_metric_value = loss_value;
3535
var batch_dim = array_ops.shape(y_true)[0];
3636

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public partial class Model
3030
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
3131
int batch_size = -1,
3232
int verbose = 1,
33+
NDArray sample_weight = null,
3334
int steps = -1,
3435
int max_queue_size = 10,
3536
int workers = 1,
@@ -51,6 +52,7 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
5152
StepsPerEpoch = steps,
5253
InitialEpoch = 0,
5354
Epochs = 1,
55+
SampleWeight = sample_weight,
5456
MaxQueueSize = max_queue_size,
5557
Workers = workers,
5658
UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
140142
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
141143
{
142144
var data = iterator.next();
143-
var outputs = test_step(data_handler, data[0], data[1]);
145+
var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
146+
test_step(data_handler, data[0], data[1], data[2]);
144147
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
145148
return outputs;
146149
}
@@ -149,17 +152,23 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
149152
{
150153
var data = iterator.next();
151154
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
152-
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
155+
var outputs = data.Length == 2 ?
156+
test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
157+
test_step(
158+
data_handler,
159+
new Tensors(data.Take(x_size).ToArray()),
160+
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
161+
new Tensors(data.Skip(2 * x_size).ToArray()));
153162
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
154163
return outputs;
155164
}
156165

157166

158-
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
167+
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
159168
{
160-
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
169+
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
161170
var y_pred = Apply(x, training: false);
162-
var loss = compiled_loss.Call(y, y_pred);
171+
var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
163172
compiled_metrics.update_state(y, y_pred);
164173
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
165174
}

0 commit comments

Comments
 (0)