Skip to content

Commit cbf2d81

Browse files
committed
ICallback
1 parent 76a964f commit cbf2d81

32 files changed

+269
-216
lines changed

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
</PropertyGroup>
2121

2222
<ItemGroup>
23-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" />
23+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.0" />
2424
</ItemGroup>
2525

2626
<ItemGroup>
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
namespace Tensorflow.Keras.Engine;
2+
3+
public interface ICallback
4+
{
5+
Dictionary<string, List<float>> history { get; set; }
6+
void on_train_begin();
7+
void on_epoch_begin(int epoch);
8+
void on_train_batch_begin(long step);
9+
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
10+
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs);
11+
void on_predict_begin();
12+
void on_predict_batch_begin(long step);
13+
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
14+
void on_predict_end();
15+
}
Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,65 @@
1-
namespace Tensorflow.Keras.Engine
1+
using Tensorflow.Functions;
2+
using Tensorflow.Keras.Losses;
3+
using Tensorflow.Keras.Saving;
4+
using Tensorflow.NumPy;
5+
6+
namespace Tensorflow.Keras.Engine;
7+
8+
public interface IModel : ILayer
29
{
3-
public interface IModel
4-
{
5-
}
10+
void compile(IOptimizer optimizer = null,
11+
ILossFunc loss = null,
12+
string[] metrics = null);
13+
14+
void compile(string optimizer, string loss, string[] metrics);
15+
16+
ICallback fit(NDArray x, NDArray y,
17+
int batch_size = -1,
18+
int epochs = 1,
19+
int verbose = 1,
20+
float validation_split = 0f,
21+
bool shuffle = true,
22+
int initial_epoch = 0,
23+
int max_queue_size = 10,
24+
int workers = 1,
25+
bool use_multiprocessing = false);
26+
27+
void save(string filepath,
28+
bool overwrite = true,
29+
bool include_optimizer = true,
30+
string save_format = "tf",
31+
SaveOptions? options = null,
32+
ConcreteFunction? signatures = null,
33+
bool save_traces = true);
34+
35+
void save_weights(string filepath,
36+
bool overwrite = true,
37+
string save_format = null,
38+
object options = null);
39+
40+
void load_weights(string filepath,
41+
bool by_name = false,
42+
bool skip_mismatch = false,
43+
object options = null);
44+
45+
void evaluate(NDArray x, NDArray y,
46+
int batch_size = -1,
47+
int verbose = 1,
48+
int steps = -1,
49+
int max_queue_size = 10,
50+
int workers = 1,
51+
bool use_multiprocessing = false,
52+
bool return_dict = false);
53+
54+
Tensors predict(Tensor x,
55+
int batch_size = -1,
56+
int verbose = 0,
57+
int steps = -1,
58+
int max_queue_size = 10,
59+
int workers = 1,
60+
bool use_multiprocessing = false);
61+
62+
void summary(int line_length = -1, float[] positions = null);
63+
64+
IKerasConfig get_config();
665
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
namespace Tensorflow.Keras.Engine;
2+
3+
public interface IOptimizer
4+
{
5+
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
6+
Tensor[] clip_gradients(Tensor[] grads);
7+
void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
8+
string name = null,
9+
bool experimental_aggregate_gradients = true);
10+
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
11+
string name = null,
12+
bool experimental_aggregate_gradients = true);
13+
}

src/TensorFlowNET.Core/Keras/IKerasApi.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Engine;
45
using Tensorflow.Keras.Layers;
56
using Tensorflow.Keras.Losses;
67
using Tensorflow.Keras.Metrics;
@@ -13,5 +14,13 @@ public interface IKerasApi
1314
public ILossesApi losses { get; }
1415
public IMetricsApi metrics { get; }
1516
public IInitializersApi initializers { get; }
17+
18+
/// <summary>
19+
/// `Model` groups layers into an object with training and inference features.
20+
/// </summary>
21+
/// <param name="input"></param>
22+
/// <param name="output"></param>
23+
/// <returns></returns>
24+
public IModel Model(Tensors inputs, Tensors outputs, string name = null);
1625
}
1726
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using System.Collections.Generic;
2-
using Tensorflow.Keras.ArgsDefinition;
3-
using Tensorflow.Keras.Engine;
1+
using Tensorflow.Keras.Engine;
42
using Tensorflow.Keras.Saving;
53
using Tensorflow.Training;
64

@@ -15,7 +13,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1513
List<ILayer> Layers { get; }
1614
List<INode> InboundNodes { get; }
1715
List<INode> OutboundNodes { get; }
18-
Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false);
16+
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false);
1917
List<IVariableV1> TrainableVariables { get; }
2018
List<IVariableV1> TrainableWeights { get; }
2119
List<IVariableV1> NonTrainableWeights { get; }

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>Tensorflow.Binding</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.10.0</TargetTensorFlow>
8-
<Version>0.100.4</Version>
8+
<Version>1.0.0</Version>
99
<LangVersion>10.0</LangVersion>
1010
<Nullable>enable</Nullable>
1111
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
2020
<Description>Google's TensorFlow full binding in .NET Standard.
2121
Building, training and infering deep learning models.
2222
https://tensorflownet.readthedocs.io</Description>
23-
<AssemblyVersion>0.100.4.0</AssemblyVersion>
23+
<AssemblyVersion>1.0.0.0</AssemblyVersion>
2424
<PackageReleaseNotes>
2525
tf.net 0.100.x and above are based on tensorflow native 2.10.0
2626

@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description>
3838
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
3939
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
4040
</PackageReleaseNotes>
41-
<FileVersion>0.100.4.0</FileVersion>
41+
<FileVersion>1.0.0.0</FileVersion>
4242
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4343
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
4444
<SignAssembly>true</SignAssembly>
Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,63 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Engine;
45

5-
namespace Tensorflow.Keras.Callbacks
6+
namespace Tensorflow.Keras.Callbacks;
7+
8+
public class CallbackList
69
{
7-
public class CallbackList
8-
{
9-
List<ICallback> callbacks = new List<ICallback>();
10-
public History History => callbacks[0] as History;
11-
12-
public CallbackList(CallbackParams parameters)
13-
{
14-
callbacks.Add(new History(parameters));
15-
callbacks.Add(new ProgbarLogger(parameters));
16-
}
17-
18-
public void on_train_begin()
19-
{
20-
callbacks.ForEach(x => x.on_train_begin());
21-
}
22-
23-
public void on_epoch_begin(int epoch)
24-
{
25-
callbacks.ForEach(x => x.on_epoch_begin(epoch));
26-
}
27-
28-
public void on_train_batch_begin(long step)
29-
{
30-
callbacks.ForEach(x => x.on_train_batch_begin(step));
31-
}
32-
33-
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
34-
{
35-
callbacks.ForEach(x => x.on_train_batch_end(end_step, logs));
36-
}
37-
38-
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
39-
{
40-
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
41-
}
42-
43-
public void on_predict_begin()
44-
{
45-
callbacks.ForEach(x => x.on_predict_begin());
46-
}
47-
48-
public void on_predict_batch_begin(long step)
49-
{
50-
callbacks.ForEach(x => x.on_predict_batch_begin(step));
51-
}
52-
53-
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
54-
{
55-
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs));
56-
}
57-
58-
public void on_predict_end()
59-
{
60-
callbacks.ForEach(x => x.on_predict_end());
61-
}
10+
List<ICallback> callbacks = new List<ICallback>();
11+
public History History => callbacks[0] as History;
12+
13+
public CallbackList(CallbackParams parameters)
14+
{
15+
callbacks.Add(new History(parameters));
16+
callbacks.Add(new ProgbarLogger(parameters));
17+
}
18+
19+
public void on_train_begin()
20+
{
21+
callbacks.ForEach(x => x.on_train_begin());
22+
}
23+
24+
public void on_epoch_begin(int epoch)
25+
{
26+
callbacks.ForEach(x => x.on_epoch_begin(epoch));
27+
}
28+
29+
public void on_train_batch_begin(long step)
30+
{
31+
callbacks.ForEach(x => x.on_train_batch_begin(step));
32+
}
33+
34+
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
35+
{
36+
callbacks.ForEach(x => x.on_train_batch_end(end_step, logs));
37+
}
38+
39+
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
40+
{
41+
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
42+
}
43+
44+
public void on_predict_begin()
45+
{
46+
callbacks.ForEach(x => x.on_predict_begin());
47+
}
48+
49+
public void on_predict_batch_begin(long step)
50+
{
51+
callbacks.ForEach(x => x.on_predict_batch_begin(step));
52+
}
53+
54+
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
55+
{
56+
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs));
57+
}
58+
59+
public void on_predict_end()
60+
{
61+
callbacks.ForEach(x => x.on_predict_end());
6262
}
6363
}

0 commit comments

Comments
 (0)