
Commit c71745c

model compile overload.
1 parent d639ce3 commit c71745c

5 files changed: 84 additions & 31 deletions

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

Lines changed: 7 additions & 4 deletions
@@ -1,18 +1,21 @@
 using Tensorflow.Functions;
 using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
 using Tensorflow.Keras.Saving;
 using Tensorflow.NumPy;
 
 namespace Tensorflow.Keras.Engine;
 
 public interface IModel : ILayer
 {
-    void compile(IOptimizer optimizer = null,
-        ILossFunc loss = null,
-        string[] metrics = null);
+    void compile(IOptimizer optimizer, ILossFunc loss);
+
+    void compile(IOptimizer optimizer, ILossFunc loss, string[] metrics);
 
     void compile(string optimizer, string loss, string[] metrics);
 
+    void compile(IOptimizer optimizer, ILossFunc loss, IMetricFunc[] metrics);
+
     ICallback fit(NDArray x, NDArray y,
         int batch_size = -1,
         int epochs = 1,
@@ -55,7 +58,7 @@ void load_weights(string filepath,
         bool skip_mismatch = false,
         object options = null);
 
-    void evaluate(NDArray x, NDArray y,
+    Dictionary<string, float> evaluate(NDArray x, NDArray y,
         int batch_size = -1,
         int verbose = 1,
         int steps = -1,
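For context, a minimal usage sketch of the reworked interface. Only the members shown in this diff are relied upon; the `model`, `x_test`, and `y_test` values, the "mae" metric name, and the helper method are illustrative, not part of the commit.

    using System;
    using System.Collections.Generic;
    using Tensorflow.Keras.Engine;   // IModel, as declared in this diff
    using Tensorflow.NumPy;          // NDArray

    static void CompileAndEvaluate(IModel model, NDArray x_test, NDArray y_test)
    {
        // String-based overload; the three object-based overloads added above
        // work the same way with IOptimizer/ILossFunc/IMetricFunc instances.
        model.compile("rmsprop", "mse", new[] { "mae" });

        // evaluate now reports metrics back to the caller instead of returning void.
        Dictionary<string, float> results = model.evaluate(x_test, y_test, batch_size: 32);
        foreach (var kv in results)
            Console.WriteLine($"{kv.Key}: {kv.Value}");
    }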

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ public ILayer LayerNormalization(Axis? axis,
         IInitializer beta_initializer = null,
         IInitializer gamma_initializer = null);
 
-    public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
+    public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
     public ILayer LeakyReLU(float alpha = 0.3f);
 
     public ILayer LSTM(int units,
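A hedged sketch of how the new parameter might be used, assuming the layers API is reached through `tf.keras.layers` via the usual `using static Tensorflow.Binding;` entry point; the shape value is illustrative only.

    // The normalization layer can now be told its input shape up front,
    // e.g. when it is the first layer of a model.
    var norm = tf.keras.layers.Normalization(input_shape: new Shape(4), axis: -1);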

src/TensorFlowNET.Keras/Engine/Model.Compile.cs

Lines changed: 47 additions & 15 deletions
@@ -10,18 +10,17 @@ public partial class Model
         LossesContainer compiled_loss;
         MetricsContainer compiled_metrics;
 
-        public void compile(IOptimizer optimizer = null,
-            ILossFunc loss = null,
-            string[] metrics = null)
+        public void compile(IOptimizer optimizer,
+            ILossFunc loss)
         {
             this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
             {
             });
 
             this.loss = loss ?? new MeanSquaredError();
 
-            compiled_loss = new LossesContainer(loss, output_names: output_names);
-            compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
+            compiled_loss = new LossesContainer(this.loss, output_names: output_names);
+            compiled_metrics = new MetricsContainer(new string[0], output_names: output_names);
 
             int experimental_steps_per_execution = 1;
             _configure_steps_per_execution(experimental_steps_per_execution);
@@ -31,17 +30,17 @@ public void compile(IOptimizer optimizer = null,
             _is_compiled = true;
         }
 
-        public void compile(IOptimizer optimizer = null,
-            ILossFunc loss = null,
-            IMetricFunc[] metrics = null)
+        public void compile(IOptimizer optimizer,
+            ILossFunc loss,
+            string[] metrics)
         {
             this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
             {
             });
 
             this.loss = loss ?? new MeanSquaredError();
 
-            compiled_loss = new LossesContainer(loss, output_names: output_names);
+            compiled_loss = new LossesContainer(this.loss, output_names: output_names);
             compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
 
             int experimental_steps_per_execution = 1;
@@ -52,25 +51,58 @@ public void compile(IOptimizer optimizer = null,
             _is_compiled = true;
         }
 
-        public void compile(string optimizer, string loss, string[] metrics)
+        public void compile(string optimizer,
+            string loss,
+            string[] metrics)
         {
-            var _optimizer = optimizer switch
+            this.optimizer = optimizer switch
             {
                 "rmsprop" => new RMSprop(new RMSpropArgs
                 {
 
                 }),
-                _ => throw new NotImplementedException("")
+                _ => new RMSprop(new RMSpropArgs
+                {
+                })
             };
 
-            ILossFunc _loss = loss switch
+            this.loss = loss switch
             {
                 "mse" => new MeanSquaredError(),
                 "mae" => new MeanAbsoluteError(),
-                _ => throw new NotImplementedException("")
+                _ => new MeanSquaredError()
             };
 
-            compile(optimizer: _optimizer, loss: _loss, metrics: metrics);
+            compiled_loss = new LossesContainer(this.loss, output_names: output_names);
+            compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
+
+            int experimental_steps_per_execution = 1;
+            _configure_steps_per_execution(experimental_steps_per_execution);
+
+            // Initialize cache attrs.
+            _reset_compile_cache();
+            _is_compiled = true;
+        }
+
+        public void compile(IOptimizer optimizer,
+            ILossFunc loss,
+            IMetricFunc[] metrics)
+        {
+            this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
+            {
+            });
+
+            this.loss = loss ?? new MeanSquaredError();
+
+            compiled_loss = new LossesContainer(this.loss, output_names: output_names);
+            compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
+
+            int experimental_steps_per_execution = 1;
+            _configure_steps_per_execution(experimental_steps_per_execution);
+
+            // Initialize cache attrs.
+            _reset_compile_cache();
+            _is_compiled = true;
         }
     }
 }
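One behavioural consequence of the string-based overload worth noting: unrecognized optimizer or loss names no longer throw NotImplementedException but silently fall back to RMSprop and MeanSquaredError. A short sketch; the `model` variable and the "adam"/"huber" names are illustrative, not part of this commit.

    // Both calls compile successfully after this change; the second one silently
    // uses RMSprop + MeanSquaredError because "adam"/"huber" are not mapped yet.
    model.compile("rmsprop", "mse", new[] { "mae" });
    model.compile("adam", "huber", new[] { "mae" });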

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 27 additions & 10 deletions
@@ -26,7 +26,7 @@ public partial class Model
         /// <param name="workers"></param>
         /// <param name="use_multiprocessing"></param>
         /// <param name="return_dict"></param>
-        public void evaluate(NDArray x, NDArray y,
+        public Dictionary<string, float> evaluate(NDArray x, NDArray y,
             int batch_size = -1,
             int verbose = 1,
             int steps = -1,
@@ -63,12 +63,12 @@ public void evaluate(NDArray x, NDArray y,
             });
             callbacks.on_test_begin();
 
+            IEnumerable<(string, Tensor)> logs = null;
             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
                 reset_metrics();
-                //callbacks.on_epoch_begin(epoch);
+                callbacks.on_epoch_begin(epoch);
                 // data_handler.catch_stop_iteration();
-                IEnumerable<(string, Tensor)> logs = null;
 
                 foreach (var step in data_handler.steps())
                 {
@@ -78,12 +78,16 @@ public void evaluate(NDArray x, NDArray y,
                     callbacks.on_test_batch_end(end_step, logs);
                 }
             }
-            Console.WriteLine();
-            GC.Collect();
-            GC.WaitForPendingFinalizers();
+
+            var results = new Dictionary<string, float>();
+            foreach (var log in logs)
+            {
+                results[log.Item1] = (float)log.Item2;
+            }
+            return results;
         }
 
-        public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
+        public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
         {
             var data_handler = new DataHandler(new DataHandlerArgs
             {
@@ -92,21 +96,34 @@ public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
                 StepsPerExecution = _steps_per_execution
             });
 
+            var callbacks = new CallbackList(new CallbackParams
+            {
+                Model = this,
+                Verbose = verbose,
+                Steps = data_handler.Inferredsteps
+            });
+            callbacks.on_test_begin();
+
             IEnumerable<(string, Tensor)> logs = null;
             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
                 reset_metrics();
-                // callbacks.on_epoch_begin(epoch)
+                callbacks.on_epoch_begin(epoch);
                 // data_handler.catch_stop_iteration();
 
-
                 foreach (var step in data_handler.steps())
                 {
                     // callbacks.on_train_batch_begin(step)
                     logs = test_function(data_handler, iterator);
                 }
             }
-            return logs.Select(x => new KeyValuePair<string, float>(x.Item1, (float)x.Item2)).ToArray();
+
+            var results = new Dictionary<string, float>();
+            foreach (var log in logs)
+            {
+                results[log.Item1] = (float)log.Item2;
+            }
+            return results;
         }
 
         IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
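With both evaluate overloads now returning a Dictionary<string, float>, callers can read individual metrics by name. A hedged sketch: it assumes a `model` and an IDatasetV2 `test_ds` built elsewhere, and that the logged entries include a "loss" key, which is the usual Keras convention rather than something this diff guarantees.

    // Evaluate a dataset and read one metric back; the available key names
    // depend on the compiled loss and metrics.
    var results = model.evaluate(test_ds, verbose: 1);
    if (results.TryGetValue("loss", out var testLoss))
        Console.WriteLine($"test loss: {testLoss}");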

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 2 additions & 1 deletion
@@ -873,9 +873,10 @@ public ILayer CategoryEncoding(int num_tokens, string output_mode = "one_hot", b
                 CountWeights = count_weights
             });
 
-        public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
+        public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
             => new Normalization(new NormalizationArgs
             {
+                InputShape = input_shape,
                 Axis = axis,
                 Mean = mean,
                 Variance = variance,
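Forwarding input_shape into NormalizationArgs means the layer can declare its own input when placed at the start of a model. A minimal sketch, assuming the Sequential/add surface TensorFlow.NET exposes elsewhere (reached via `using static Tensorflow.KerasApi;`); none of these calls are part of this diff and the shapes are illustrative.

    // Hypothetical: build a tiny model whose first layer knows its input shape.
    var model = keras.Sequential();
    model.add(keras.layers.Normalization(input_shape: new Shape(3), axis: -1));
    model.add(keras.layers.Dense(1));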
