From 726b742157eb2bdc4f5063e9a0f9093bfd03aa33 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 14 Jan 2023 15:45:02 +0800 Subject: [PATCH 01/15] Add check for dims of x and y in model.fit. --- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index e0b4af78c..40dd4ab6f 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -31,6 +31,11 @@ public void fit(NDArray x, NDArray y, int workers = 1, bool use_multiprocessing = false) { + if (x.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); + } int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); var train_x = x[new Slice(0, train_count)]; var train_y = y[new Slice(0, train_count)]; From bb8168b5ca9bc78a824d429eb7bd5f4ac9e4fa8d Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 21 Jan 2023 11:07:07 +0800 Subject: [PATCH 02/15] Init the serialization of keras pb model. 
--- src/TensorFlowNET.Core/APIs/tf.compat.cs | 22 ++ .../Checkpoint/CheckPointUtils.cs | 150 +++++++++ .../Checkpoint/CheckpointOptions.cs | 5 + .../Checkpoint/ObjectGraphView.cs | 63 ++++ .../Checkpoint/SaveUtilV1.cs | 229 ++++++++++++++ .../Checkpoint/TrackableSaver.cs | 109 +++++++ .../Checkpoint/TrackableView.cs | 75 +++++ .../Exceptions/AssertionError.cs | 14 + .../Framework/meta_graph.cs | 63 +++- .../Functions/ConcreteFunction.cs | 3 +- src/TensorFlowNET.Core/Functions/Function.cs | 11 +- .../ModelSaving/SaveOptions.cs | 8 +- .../Operations/resource_variable_ops.cs | 6 + .../Protobuf/SavedObjectGraph.cs | 10 +- .../Protobuf/TrackableObjectGraph.cs | 6 + .../Training/AutoTrackable.cs | 15 + src/TensorFlowNET.Core/Training/Optimizer.cs | 7 +- .../Training/Saving/SaveableObject.cs | 14 + .../Training/Saving/SavedModel/AssetInfo.cs | 11 + .../Saving/SavedModel/AugmentedGraphView.cs | 60 ++++ .../Training/Saving/SavedModel/Constants.cs | 33 ++ .../Saving/SavedModel/RevivedTypes.cs | 17 + .../Training/Saving/SavedModel/SaveType.cs | 9 + .../Saving/SavedModel/SaveableView.cs | 299 ++++++++++++++++++ .../Saving/SavedModel/TagConstants.cs | 10 + .../Training/Saving/SavedModel/builder.cs | 22 ++ .../Training/Saving/SavedModel/save.cs | 256 +++++++++++++++ .../SavedModel/signature_serialization.cs | 58 ++++ .../Training/Saving/SavedModel/utils.cs | 52 +++ .../Saving/saveable_object_util.py.cs | 19 +- src/TensorFlowNET.Core/Training/Trackable.cs | 79 ++++- .../Training/TrackableUtils.cs | 148 +++++++++ .../Variables/BaseResourceVariable.cs | 1 + src/TensorFlowNET.Core/ops.cs | 18 ++ .../Engine/Layer.Serialize.cs | 31 ++ src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 15 +- src/TensorFlowNET.Keras/Engine/Model.cs | 6 + .../Protobuf/SavedMetadata.cs | 12 + src/TensorFlowNET.Keras/Protobuf/Versions.cs | 7 + .../Saving/SavedModel/Constants.cs | 41 +++ .../Saving/SavedModel/KerasObjectWrapper.cs | 11 + 
.../Saving/SavedModel/Save.cs | 115 +++++++ .../Saving/SavedModel/SaveImpl.cs | 19 ++ .../Saving/SavedModel/base_serialization.cs | 40 +++ .../Saving/SavedModel/layer_serialization.cs | 62 ++++ .../Saving/SavedModel/utils.cs | 33 ++ test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 60 ++++ .../Tensorflow.Binding.UnitTest.csproj | 2 +- 49 files changed, 2347 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableView.cs create mode 100644 src/TensorFlowNET.Core/Exceptions/AssertionError.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs create mode 100644 src/TensorFlowNET.Core/Training/TrackableUtils.cs create mode 100644 src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs create mode 100644 
src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveTest.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.cs b/src/TensorFlowNET.Core/APIs/tf.compat.cs index 4d979eb55..5b2b5a107 100644 --- a/src/TensorFlowNET.Core/APIs/tf.compat.cs +++ b/src/TensorFlowNET.Core/APIs/tf.compat.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Text; + namespace Tensorflow { public partial class tensorflow @@ -23,6 +25,26 @@ public partial class tensorflow public class CompatApi { public CompatV1Api v1 { get; } = new CompatV1Api(); + + internal string as_text(string bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return bytes_or_text; + } + internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return encoding.GetString(bytes_or_text); + } + + internal string as_str(string bytes_or_text, Encoding? encoding = null) + { + return as_text(bytes_or_text, encoding); + } + internal string as_str(byte[] bytes_or_text, Encoding? 
encoding = null) + { + return as_text(bytes_or_text, encoding); + } } public bool executing_eagerly() diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs new file mode 100644 index 000000000..70d771559 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -0,0 +1,150 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint; + +public static class CheckPointUtils +{ + private static string _ESCAPE_CHAR = "."; + public static (List, Dictionary>, Dictionary, + IDictionary>, + Dictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); + return (trackable_objects, node_paths, node_ids, slot_variables, object_names); + } + + public static + IDictionary> + serialize_slot_variables(IEnumerable trackable_objects, + IDictionary node_ids, IDictionary object_names) + { + var non_slot_objects = trackable_objects.ToList(); + Dictionary> + slot_variables = new(); + foreach (var trackable in non_slot_objects) + { + if (trackable is not Optimizer) + { + continue; + } + + var optim = (Optimizer)trackable; + var slot_names = optim.get_slot_names(); + foreach (var slot_name in slot_names) + { + for (int original_variable_node_id = 0; + original_variable_node_id < non_slot_objects.Count; + original_variable_node_id++) + { + var original_variable = 
non_slot_objects[original_variable_node_id]; + IVariableV1 slot_variable; + if (original_variable is not IVariableV1) + { + slot_variable = null; + } + slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); + if(slot_variable is null) continue; + + // There're some problems about the inherits of `Variable` and `Trackable`. + throw new NotImplementedException(); + } + } + } + + return slot_variables; + } + + public static Trackable get_mapped_trackable(Trackable trackable, IDictionary? object_map) + { + if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) + { + return trackable; + } + else + { + return possible_res; + } + } + + public static string get_full_name(Trackable var) + { + // TODO: This state is not correct, the whole framework need to be updated in the future. + if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) + { + return ""; + } + // skip the check of attribute `_save_slice_info` . + + // TODO: Need to be revised!!! + return ((ResourceVariable)(object)var).Name; + } + + public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) + { + HashSet checkpointed_trackables = new(); + Dictionary> parents = new(); + for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + { + var object_proto = object_graph_proto.Nodes[i]; + // skip the process of registered saver. 
+ if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || + object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) + { + checkpointed_trackables.Add(i); + } + + foreach (var child_proto in object_proto.Children) + { + var child = child_proto.NodeId; + if (!parents.ContainsKey(child)) + { + parents[child] = new HashSet(); + } + + parents[child].Add(i); + } + } + + Queue to_visit = new(checkpointed_trackables.AsEnumerable()); + while (to_visit.Count > 0) + { + var trackable = to_visit.Dequeue(); + if (!parents.ContainsKey(trackable)) continue; + var current_parents = parents[trackable]; + foreach (var parent in current_parents) + { + checkpointed_trackables.Add(parent); + if (parents.ContainsKey(parent)) + { + to_visit.Enqueue(parent); + } + } + parents.Remove(trackable); + } + + // TODO: Complete it after supporting checkpoint. + // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + // { + // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); + // } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs new file mode 100644 index 000000000..d8297ea3f --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Checkpoint; + +public record class CheckpointOptions( + string experimental_io_device = null, + bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs new file mode 100644 index 000000000..2ad554485 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Serilog.Debugging; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint; + +public class 
ObjectGraphView: TrackableView, ICloneable +{ + protected IEnumerable? _attached_dependencies; + // TODO: attached_dependencies + public ObjectGraphView(Trackable root, IEnumerable? attached_dependencies = null): base(root) + { + _attached_dependencies = attached_dependencies; + } + + public object Clone() + { + // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ + return new ObjectGraphView(Root, _attached_dependencies); + } + + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + List res = base.children(obj, save_type) + .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + // Check the reference, not value. + if (obj == Root && _attached_dependencies is not null) + { + res.AddRange(_attached_dependencies); + } + + return res; + } + + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + } + + public IEnumerable? AttachedDependencies + { + get => _attached_dependencies; + } + + public virtual (List, Dictionary>) breadth_first_traversal() + { + return base._descendants_with_paths(); + } + + // TODO: complete the implementation + public void serialize_object_graph(object? saveables_cache = null) + { + throw new NotImplementedException(); + } + + // TODO: complete the implementation + public void frozen_saveable_objects(object? object_map = null, object? 
to_graph = null, object call_with_mapped_captures = null) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs new file mode 100644 index 000000000..7724c6b70 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -0,0 +1,229 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Exceptions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +public static class SaveUtilV1 +{ + public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + IDictionary? object_map = null) + { + // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, + // till now only internal registrations are allowed. So, we won't return a saver in this function. + // The implementation of this function should be updated if tensorflow update it. + Dictionary> checkpoint_factory_map = new(); + foreach (var pair in object_names) + { + var trackable = pair.Key; + var object_name = pair.Value; + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + + // skip the registration process. + + List current_list = new(); + foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) + { + // treat name as key_suffix. + var name = name_and_factory.Key; + var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); + + current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); + } + + checkpoint_factory_map[trackable] = current_list; + } + + return (checkpoint_factory_map, null); + } + + public static (List, object?) 
frozen_saveables_and_savers(ObjectGraphView graph_view, + IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, + object? saveables_cache = null) + { + + Graph target_context; + if (to_graph is not null) + { + using (to_graph.as_default()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + else + { + using (new ops.NullContextManager()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + } + + public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + IDictionary object_map, bool call_with_mapped_captures, object? 
saveables_cache = null) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); + var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( + trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, + saveables_cache); + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); + } + + private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList trackable_objects, + IDictionary node_ids, + IDictionary> + slot_variables) + { + TrackableObjectGraph object_graph_proto = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + var trackable = trackable_objects[i]; + Debug.Assert(node_ids[trackable] == i); + TrackableObjectGraph.Types.TrackableObject object_proto; + if (slot_variables.TryGetValue(trackable, out var slots)) + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); + } + else + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(); + } + object_graph_proto.Nodes.Add(object_proto); + foreach (var child in graph_view.list_children(trackable)) + { + object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { NodeId = node_ids[child.Refer], LocalName = child.Name }); + } + } + + return object_graph_proto; + } + + private static (List, object?, 
object?) add_attributes_to_object_graph(IList trackable_objects, + TrackableObjectGraph object_graph_proto, IDictionary node_ids, + IDictionary object_names, IDictionary object_map, + bool call_with_mapped_captures, object? saveables_cache = null) + { + int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + Debug.Assert(node_ids[trackable_objects[i]] == i); + } + + var (checkpoint_factory_map, unmmaped_registered_savers) = + get_checkpoint_factories_and_keys(object_names, object_map); + + // skip the process of registered savers + + var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, + object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); + return (named_saveable_objects, feed_additions, null); + } + + public static (List, object?) generate_saveable_objects( + IDictionary> checkpoint_factory_map, + TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + List named_saveable_objects = new(); + foreach (var pair in checkpoint_factory_map) + { + var trackable = pair.Key; + var factory_data_list = pair.Value; + bool fill_object_proto = object_graph_proto is not null && node_ids is not null; + TrackableObjectGraph.Types.TrackableObject object_proto = null!; + if (fill_object_proto) + { + object_proto = object_graph_proto.Nodes[node_ids[trackable]]; + } + + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + // skip cache + + foreach (var factory_data in factory_data_list) + { + var name = factory_data.name; + var key = factory_data.checkpoint_key; + var saveable_factory = factory_data.factory; + + // TODO: oneflow python has a process with callable `saveable_factory`. 
+ var maybe_saveable = saveable_factory; + IEnumerable savesbles; + if (maybe_saveable is MySaveableObject) + { + savesbles = new List() { (MySaveableObject)maybe_saveable }; + } + else if (maybe_saveable is Tensor) + { + savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + } + else + { + throw new TypeError("Unexpected type."); + } + + foreach (var saveable in savesbles) + { + if (!saveable.name.Contains(key)) + { + throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + + $"'{saveable.name}' for attribute '{name}'. Expected a name" + + $" containing '{key}'."); + } + } + + // skip the process of PythonState + + named_saveable_objects.AddRange(savesbles); + + if(!fill_object_proto) continue; + + // skip the process of TrackableSaveable + + object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); + } + } + + return (named_saveable_objects, null); + } +} + +public record class CheckpointFactoryData +( + object factory, + string name, + string checkpoint_key +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs new file mode 100644 index 000000000..7d101d5e5 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; + +namespace Tensorflow.Checkpoint; + +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private EagerTensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? 
_file_prefix_feed_tensor = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private void gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + throw new NotImplementedException(); + } + + private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + throw new NotImplementedException(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); + _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. 
+ file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs new file mode 100644 index 000000000..ed1f3ec47 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -0,0 +1,75 @@ +using System; +using Tensorflow.Train; +using System.Collections.Generic; +using System.IO; + +namespace Tensorflow.Checkpoint; + +public class TrackableView +{ + protected WeakReference _root_ref; + public TrackableView(Trackable obj) + { + _root_ref = new WeakReference(obj); + } + + public TrackableView(WeakReference obj) + { + _root_ref = obj; + } + + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + obj._maybe_initialize_trackable(); + // Note: in python the return type of `Trackable._trackable_children` is not fixed. + // Therefore it uses `convert_to_trackable` to have an extra process. 
+ return obj._trackable_children(save_type); + } + + public Trackable Root + { + get + { + if (_root_ref.TryGetTarget(out Trackable res)) + { + return res; + } + else + { + throw new InvalidDataException( + "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); + } + } + } + + /// + /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. + /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths + /// + protected (List, Dictionary>) _descendants_with_paths() + { + List bfs_sorted = new(); + Queue to_visit = new(); + Dictionary> node_paths = new(); + node_paths[this.Root] = new List(); + while (!to_visit.empty()) + { + var current_trackable = to_visit.Dequeue(); + bfs_sorted.Add(current_trackable); + var children_dict = this.children(current_trackable); + foreach (var name in children_dict.Keys) + { + var dependency = children_dict[name]; + if (!node_paths.ContainsKey(dependency)) + { + var list = new List(node_paths[current_trackable]); + list.Add(new TrackableReference(name, dependency)); + node_paths[dependency] = list; + to_visit.Enqueue(dependency); + } + } + } + + return (bfs_sorted, node_paths); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs new file mode 100644 index 000000000..84ec24cbf --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Exceptions; + +public class AssertionError : TensorflowException +{ + public AssertionError() : base() + { + + } + + public AssertionError(string message) : base(message) + { + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 6ce3bf3c5..cce13b55d 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ 
b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -304,7 +304,7 @@ private static void add_collection_def(MetaGraphDef meta_graph_def, } } - private static OpList stripped_op_list_for_graph(GraphDef graph_def) + public static OpList stripped_op_list_for_graph(GraphDef graph_def) { var used_ops = ops_used_by_graph_def(graph_def); @@ -345,5 +345,66 @@ private static string[] ops_used_by_graph_def(GraphDef graph_def) return used_ops.ToArray(); } + + private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) + { + foreach (var attr_def in op_def.Attr) + { + if (attr_def.Name == attr_name) + { + if (attr_def.DefaultValue is null) return false; + // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. + return true; + } + } + + return false; + } + + public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) + { + Dictionary op_name_to_function = new(); + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + op_name_to_function[function_def.Signature.Name] = function_def; + } + + Action _strip_node_default_valued_attrs = (node_def) => + { + if (op_name_to_function.ContainsKey(node_def.Op)) return; + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is null) return; + + HashSet attrs_to_strip = new(); + foreach (var attr in node_def.Attr) + { + if (is_default_attr_value(op_def, attr.Key, attr.Value)) + { + attrs_to_strip.Add(attr.Key); + } + } + + foreach (var attr in attrs_to_strip) + { + node_def.Attr.Remove(attr); + } + }; + + foreach (var node_def in meta_graph_def.GraphDef.Node) + { + _strip_node_default_valued_attrs(node_def); + } + + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + foreach (var function_node_def in function_def.NodeDef) + { + _strip_node_default_valued_attrs(function_node_def); + } + } + + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + } } } diff --git 
a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index c52d0b5f5..bac9cedbf 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -3,6 +3,7 @@ using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -10,7 +11,7 @@ namespace Tensorflow.Functions /// /// /// - public class ConcreteFunction + public class ConcreteFunction: Trackable { FuncGraph func_graph; ForwardBackwardCall forward_backward; diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index d57097ae9..056d15f4d 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -1,16 +1,23 @@ using System; +using Tensorflow.Train; namespace Tensorflow { - public class Function + public class Function: Trackable { #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - + + public string Name { get; set; } public Function() { } + + public Function(string name) + { + Name = name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index e25537d80..fce42850f 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving /// public class SaveOptions { - bool save_debug_info; + public bool save_debug_info = false; + public IList? namespace_white_list { get; set; } = null; + public IDictionary? function_aliases { get; set; } = null; + public string? experimental_io_device { get; set; } = null; + // TODO: experimental + public Object? 
experimental_variable_polict { get; set; } = null; + public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index ee751acf4..d5a32c10e 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.Train; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -38,6 +39,11 @@ public static bool is_resource_variable(IVariableV1 var) { return var is ResourceVariable; } + + public static bool is_resource_variable(Trackable var) + { + return var is BaseResourceVariable; + } /// /// Creates a variable handle with information to do shape inference. diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 9d3e854ac..f2597574b 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -156,7 +156,7 @@ public SavedObjectGraph Clone() { /// Nodes[0] is considered the root node. 
/// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Nodes { + public pbc::RepeatedField Nodes { get { return nodes_; } } @@ -286,6 +286,7 @@ public SavedObject() { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public SavedObject(SavedObject other) : this() { children_ = other.children_.Clone(); + dependencies_ = other.dependencies_.Clone(); slotVariables_ = other.slotVariables_.Clone(); saveableObjects_ = other.saveableObjects_.Clone(); switch (other.KindCase) { @@ -328,6 +329,7 @@ public SavedObject Clone() { private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// /// Objects which this object depends on: named edges in the dependency /// graph. @@ -338,6 +340,11 @@ public SavedObject Clone() { public pbc::RepeatedField Children { get { return children_; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dependencies { + get { return dependencies_; } + } /// Field number for the "slot_variables" field. 
public const int SlotVariablesFieldNumber = 3; @@ -617,6 +624,7 @@ public void MergeFrom(SavedObject other) { return; } children_.Add(other.children_); + dependencies_.Add(other.dependencies_); slotVariables_.Add(other.slotVariables_); saveableObjects_.Add(other.saveableObjects_); switch (other.KindCase) { diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 3aa747c20..934136671 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -198,6 +198,12 @@ public sealed partial class TrackableObject : pb::IMessage { public TrackableObject() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot) { + OnConstruction(); + slotVariables_ = slot; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d2198e37e..d8f6314bc 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -2,5 +2,20 @@ { public abstract class AutoTrackable : Trackable { + public void _delete_tracking(string name) + { + _maybe_initialize_trackable(); + if (_unconditional_dependency_names.ContainsKey(name)) + { + _unconditional_dependency_names.Remove(name); + for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) + { + if (_unconditional_checkpoint_dependencies[i].Name == name) + { + _unconditional_checkpoint_dependencies.RemoveAt(i); + } + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index f985c6566..e656fe96d 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -351,7 +351,7 @@ public virtual void _prepare() /// /// /// - protected IVariableV1 
get_slot(IVariableV1 var, string name) + internal IVariableV1 get_slot(IVariableV1 var, string name) { var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; if (named_slots == null) @@ -360,6 +360,11 @@ protected IVariableV1 get_slot(IVariableV1 var, string name) return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; } + internal IEnumerable get_slot_names() + { + return _slots.Keys; + } + private string _var_key(IVariableV1 var) { return $"{var.Op.graph.graph_key}.{var.Op.name}"; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index c86075f86..6239030ba 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -48,4 +48,18 @@ public virtual Operation restore(Tensor[] restored_tensors, Shape[] restored_sha validate_shape: restored_shapes == null && op.shape.IsFullyDefined); } } + + public class NoRestoreSaveable: MySaveableObject + { + public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? 
device = null) : base(tensor, + new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) + { + + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + return control_flow_ops.no_op(); + } + } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs new file mode 100644 index 000000000..24c8f2f05 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Tensorflow; + +public record class AssetInfo +( + List asset_defs, + Dictionary asset_initializers_by_resource, + Dictionary asset_filename_map, + Dictionary asset_index +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs new file mode 100644 index 000000000..6723206c0 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -0,0 +1,60 @@ +using System; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; + +namespace Tensorflow; + +public class AugmentedGraphView: ObjectGraphView +{ + // private object _children_cache; + // private object _serialization_cache; + private List _untraces_functions; + public AugmentedGraphView(Trackable root): base(root) + { + _untraces_functions = new(); + } + + public void set_signature(object signature_map, object wrapped_functions) + { + // TODO: cache + list_children(Root); + } + + public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + Dictionary children = new(); + foreach (var pair in base.list_children(obj, save_type)) + { + var name = pair.Name; + var child = pair.Refer; + children[name] = child; + } + + if (obj is Function && children.Count == 0) 
+ { + _untraces_functions.Add(((Function)obj).Name); + } + + return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + } + + public override (List, Dictionary>) breadth_first_traversal() + { + // TODO: implement it if needed. + return base.breadth_first_traversal(); + } + + public List<(string, Trackable)> list_dependencies(Trackable obj) + { + // TODO: deal with cache. + return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + } + + public Trackable get_child(Trackable obj, string name) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs new file mode 100644 index 000000000..cb7abadad --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -0,0 +1,33 @@ +namespace Tensorflow; + +public static class Constants +{ + public static readonly string ASSETS_DIRECTORY = "assets"; + public static readonly string ASSETS_KEY = "saved_model_assets"; + + public static readonly string DEBUG_DIRECTORY = "debug"; + + public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; + + public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; + + public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; + + public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + + public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; + + public static readonly string MAIN_OP_KEY = "saved_model_main_op"; + + public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; + public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; + + public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; + + public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; + + public static readonly string TRAIN_OP_SIGNATURE_KEY = 
"__saved_model_train_op"; + + public static readonly string VARIABLES_DIRECTORY = "variables"; + public static readonly string VARIABLES_FILENAME = "variables"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs new file mode 100644 index 000000000..fa9d6e504 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -0,0 +1,17 @@ +using Tensorflow.Train; + +namespace Tensorflow; + +public class RevivedTypes +{ + /// + /// Create a SavedUserObject from a trackable object. + /// + /// + /// + public static SavedUserObject? serialize(Trackable obj) + { + // TODO: complete the implementation. + return null; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs new file mode 100644 index 000000000..b973fd417 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tensorflow; + +public enum SaveType +{ + SAVEDMODEL, + CHECKPOINT +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs new file mode 100644 index 000000000..6a241f0e7 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class SaveableView +{ + private AugmentedGraphView _augmented_graph_view; + private SaveOptions 
_options; + private List _trackable_objects; + private List _nodes; + private Dictionary> _node_paths; + private Dictionary _node_ids; + private IDictionary> + _slot_variables; + private Dictionary _object_names; + private List _gradient_functions; // to be completed + private List _gradient_defs; // to be completed + private List _concrete_functions; + private Dictionary _captured_tensor_node_ids; + private Dictionary> _saveable_objects_map; + private Dictionary _obj_to_registered_saver; + + public AugmentedGraphView AugmentedGraphView + { + get => _augmented_graph_view; + } + + public Trackable Root + { + get => _nodes[0]; + } + public List Nodes + { + get => _nodes; + } + public Dictionary NodeIds + { + get => _node_ids; + } + public List GradientDefs + { + get => _gradient_defs; + } + public Dictionary> NodePaths + { + get => _node_paths; + } + public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) + { + _augmented_graph_view = augmented_graph_view; + _options = options; + + (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = + CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); + + // TODO: deal with untraced functions. + + initialize_save_and_restore_functions(); + initialize_nodes_and_concrete_functions(); + + _captured_tensor_node_ids = new(); + } + + private void initialize_save_and_restore_functions() + { + // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. + SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. 
+ _obj_to_registered_saver = new(); + _saveable_objects_map = new(); + } + + // Copies the traversal result into `_nodes` and collects every ConcreteFunction node found in it. + private void initialize_nodes_and_concrete_functions() + { + _nodes = _trackable_objects.ConvertAll(x => x); // shallow copy: a new list holding the same Trackable references + _gradient_functions = new(); + _gradient_defs = new(); + // Allocate the list before the loop below appends to it; it is not assigned anywhere else in this class. + _concrete_functions = new(); + + // TODO: deal with the condition that obj in `_saveable_objects_map`. + // foreach (var obj in _nodes) + // { + // + // } + + foreach (var obj in _nodes) + { + if (obj is ConcreteFunction concrete_function) + { + _concrete_functions.Add(concrete_function); + } + } + } + + public List get_concrete_resource_initializers() + { + // TODO: complete the implementation. + return new List(); + } + + public (Dictionary, Dictionary, AssetInfo) map_resources() + { + Debug.Assert(!tf.Context.executing_eagerly()); + + Dictionary object_map = new(); + Dictionary tensor_map = new(); + + AssetInfo assetInfo = new(new List(), new Dictionary(), + new Dictionary(), new Dictionary()); + + foreach (var node_id in dependency_sorted_node_ids()) + { + var obj = _nodes[node_id]; + var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); + // TODO: deal with Asset (if obj is Asset) + foreach (var tensor in tensors) + { + _captured_tensor_node_ids[tensor] = node_id; + } + } + + return (object_map, tensor_map, assetInfo); + } + + /// + /// Returns topologically sorted nodes, sorted by dependencies. + /// + public List dependency_sorted_node_ids() + { + Dictionary> dependency_map = new(); + foreach (var node in _nodes) + { + var node_id = _node_ids[node]; + List deps = new(); + // Register the list up front so every node, including one without dependencies, has an entry in the map handed to the sorter. + dependency_map[node_id] = deps; + + // TODO: deal with captured tensor. + + string node_path; + foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) + { + if (!_node_ids.ContainsKey(dep)) + { + node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + throw new ValueError( + $"Found an untracked dependency. Object {node_path} depends on {dep}, " + + $"but this dependency isn't listed as a child.
Please track this child by " + + $"overriding `_trackable_children` or use `._track_trackable`."); + } + deps.Add(_node_ids[dep]); + } + } + + try + { + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError err) + { + List pretty_printed_nodes = new(); + List pretty_printed_dependencies = new(); + + foreach (var pair in err.LeftOverDependencyMap) + { + var x = pair.Key; + var deps = pair.Value; + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); + pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); + pretty_printed_dependencies.Add( + $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); + } + + throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + + $"Saving cannot continue until this cycle is resolved." + + $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + + $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); + } + } + + /// + /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph + /// + /// + /// + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + { + SavedObjectGraph proto = new(); + fill_object_graph_proto(proto); + + // TODO: complete the process of concrete functions. 
+ + int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + var obj = _nodes[i]; + var obj_proto = proto.Nodes[i]; + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), + options); + } + + return proto; + } + + private static void write_object_proto(Trackable obj, SavedObject proto, + IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + { + // skip the process of type Asset + if (resource_variable_ops.is_resource_variable(obj)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is Function) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is ConcreteFunction) + { + // TODO: complete it. + throw new NotImplementedException(); + } + // skip the process of type `_CapturedTensor` and `CapturableResource`. + else + { + var registered_type_proto = RevivedTypes.serialize(obj); + if (registered_type_proto is null) + { + registered_type_proto = new SavedUserObject() + { + Identifier = obj.ObjectIdentifier, + Version = new VersionDef() + { + Producer = 1, + MinConsumer = 1, + BadConsumers = { } + } + }; + } + + proto.UserObject = new SavedUserObject(registered_type_proto); + } + + // TODO: try get the registered_name from `registration`. 
+ } + + public void fill_object_graph_proto(SavedObjectGraph proto) + { + for (int node_id = 0; node_id < _nodes.Count; node_id++) + { + var node = _nodes[node_id]; + Debug.Assert(_node_ids[node] == node_id); + SavedObject object_proto = new(); + if (_slot_variables.TryGetValue(node, out var value)) + { + object_proto.SlotVariables.AddRange(value); + } + // skip the check of type `_CapturedTensor` + foreach (var child in _augmented_graph_view.list_children(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[child.Refer]; + child_proto.LocalName = child.Name; + object_proto.Children.Add(child_proto); + } + + foreach (var pair in _augmented_graph_view.list_dependencies(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[pair.Item2]; + child_proto.LocalName = pair.Item1; + object_proto.Dependencies.Add(child_proto); + } + + if (_saveable_objects_map.ContainsKey(node)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if(_obj_to_registered_saver.ContainsKey(node)) + { + // TODO: complete it. + // We now skip it for the lack of `SavedObject.registered_saver` API. 
+ throw new NotImplementedException(); + } + + proto.Nodes.Add(object_proto); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs new file mode 100644 index 000000000..9a066eed7 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -0,0 +1,10 @@ +namespace Tensorflow; + +public static class TagConstants +{ + public static readonly string SERVING = "serve"; + public static readonly string TRAINING = "train"; + public static readonly string EVAL = "eval"; + public static readonly string GPU = "gpu"; + public static readonly string TPU = "tpu"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs new file mode 100644 index 000000000..bcd3ae05a --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BuilderUtils +{ + public static void copy_assets_to_destination_dir(IDictionary asset_filename_map, + string destination_dir, HashSet? saved_files = null) + { + if (saved_files is null) saved_files = new HashSet(); + + var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); + + // TODO: complete the implementation of this function. 
+ if (asset_filename_map is not null && asset_filename_map.Count > 0) + { + throw new NotImplementedException(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs new file mode 100644 index 000000000..692356054 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -0,0 +1,256 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.Protobuf; +using Tensorflow.Checkpoint; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + private static readonly IEnumerable byte_swappable = new List() + { + dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, + dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, + dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 + }.Select(x => (int)x); + + public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, + string export_dir, IDictionary? signatures, SaveOptions? 
options = null, bool experimental_skip_checkpoint = false) + { + if (options is null) + { + options = new SaveOptions(); + } + + var saved_model = new Tensorflow.SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + + var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = + _build_meta_graph(obj, signatures, options, meta_graph_def); + saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; + + if (!experimental_skip_checkpoint) + { + Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); + CheckpointOptions ckpt_options = new(options.experimental_io_device); + object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + } + BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); + + if (tf.Context.executing_eagerly()) + { + // tensorflow python has a check of `context.async_wait()` here. + } + + // TODO: deal with `pywrap_saved_model.Save(export_dir)`. + + // saved_model.pb is a binary protobuf file, so serialize to wire-format bytes once (ToString() yields the JSON text form). + var saved_model_serialized = saved_model.ToByteArray(); + + // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. + if (true) + { + var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), + tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); + // TODO: add c api and complete the fingerprint def. + var fingerprint_proto = ""; + File.WriteAllText(fingerprint_path, fingerprint_proto); + } + + var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); + File.WriteAllBytes(path, saved_model_serialized); + + if (options.save_debug_info) + { + throw new NotImplementedException(); + } + + ops.dismantle_graph(exported_graph); + + return (saved_nodes, node_paths); + } + + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, + Dictionary>) _build_meta_graph(Trackable obj, + IDictionary? signatures, SaveOptions options, MetaGraphDef?
meta_graph_def = null) + { + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } + + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } + + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is not null) + { + throw new NotImplementedException(); + } + + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) + { + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } + } + + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } + + private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, + IDictionary signatures, IEnumerable namespace_whitelist, + bool save_custom_gradients) + { + var resource_initializers = saveable_view.get_concrete_resource_initializers(); + var exported_graph = new Graph(); + + Dictionary object_map; + Dictionary tensor_map; + AssetInfo asset_info; + using (var g = exported_graph.as_default()) + { + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. 
+ if (save_custom_gradients) + { + // TODO: trace gradient functions. + } + + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers + } + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + // Record the init op under MAIN_OP_KEY, creating the node-list collection on first use. + // RepeatedField.Add mutates the proto; Enumerable.Append only builds a new sequence, which was being discarded. + if (!meta_graph_def.CollectionDef.TryGetValue(Tensorflow.Constants.MAIN_OP_KEY, out var main_op_collection)) + { + main_op_collection = new CollectionDef() { NodeList = new CollectionDef.Types.NodeList() }; + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = main_op_collection; + } + main_op_collection.NodeList.Value.Add(init_op.name); + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + } + + foreach (var obj in object_map.Values) + { + obj._maybe_initialize_trackable(); + } + + var (named_saveable_objects, registered_savers) = + SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); + + // TODO: complete the save of checkpoints with `MultiDeviceSaver`. + + saveable_view.dependency_sorted_node_ids(); + + var graph_def = exported_graph.as_graph_def(true); + graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); + verify_ops(graph_def, namespace_whitelist); + + meta_graph_def.GraphDef = new GraphDef(graph_def); + // Message-typed proto fields are null until assigned, so allocate MetaInfoDef before writing into it. + meta_graph_def.MetaInfoDef ??= new MetaGraphDef.Types.MetaInfoDef(); + meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); + meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; + // TODO: add git version. + meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + // Same for the nested op list: MergeFrom needs a non-null target. + meta_graph_def.MetaInfoDef.StrippedOpList ??= new OpList(); + meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); + meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); + + // TODO: deal with signatures here.
+ + meta_graph.strip_graph_default_valued_attrs(meta_graph_def); + + if (!BitConverter.IsLittleEndian) + { + swap_function_tensor_content(meta_graph_def); + } + + return (asset_info, exported_graph); + } + + private static void verify_ops(GraphDef graph_def, IEnumerable? namespace_whitelist) + { + return; + // if (namespace_whitelist is null || !namespace_whitelist.Any()) + // { + // return; + // } + + // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. + } + + public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) + { + var functions = meta_graph_def.GraphDef.Library.Function; + foreach (var function in functions) + { + var node_def = function.NodeDef; + foreach (var node in node_def) + { + if (node.Op == "Const") + { + var tensor = node.Attr["value"].Tensor; + byte_swap_tensor_content(tensor); + } + } + } + } + + public static void byte_swap_tensor_content(TensorProto tensor) + { + if (byte_swappable.Contains((int)tensor.Dtype)) + { + var tshape = tensor.TensorShape.Dim; + var tensor_bytes = tensor.TensorContent; + if (tensor_bytes is not null && !tensor_bytes.IsEmpty) + { + long tensor_size = 1; + foreach (var sz in tshape) + { + tensor_size *= sz.Size; + } + + var chunksize = tensor_bytes.Length / tensor_size; + List reversed_bytes = new(); + for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) + { + var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); + reversed_bytes.AddRange(current); + } + tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs new file mode 100644 index 000000000..21272941f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Linq; 
+using Tensorflow.Functions; +using Tensorflow.Train; + +namespace Tensorflow; + +// Trackable that exposes registered signature functions as its children when saving as a SavedModel. +public class SignatureMap: Trackable +{ + private Dictionary _signatures; + private Dictionary _concrete_signatures; + + public SignatureMap() + { + _signatures = new(); + // Both maps are written by `_add_signature` and read by `_trackable_children`, so both must be allocated here. + _concrete_signatures = new(); + } + + public void _add_signature(string name, ConcreteFunction concrete_function) + { + _concrete_signatures[name] = concrete_function; + } + + public void _add_signature(string name, Function concrete_function) + { + _signatures[name] = concrete_function; + } + + // Returns no children unless saving as SAVEDMODEL; on a name clash the concrete signature overwrites the plain one. + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if (save_type != SaveType.SAVEDMODEL) + { + return new Dictionary(); + } + + Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); + foreach (var pair in _concrete_signatures) + { + res[pair.Key] = pair.Value; + } + + return res; + } + + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + // TODO: assert the arg_keywords + signature_map._add_signature(name, func); + } + + return signature_map; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs new file mode 100644 index 000000000..723419f6f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -0,0 +1,52 @@ +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + /// + /// Return variables sub-directory, or create one if it doesn't exist.
+ /// + /// + public static string get_or_create_variables_dir(string export_dir) + { + var variables_dir = get_variables_dir(export_dir); + Directory.CreateDirectory(variables_dir); + return variables_dir; + } + + /// + /// Return variables sub-directory in the SavedModel. + /// + /// + /// + public static string get_variables_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); + } + + /// + /// Return assets sub-directory, or create one if it doesn't exist. + /// + /// + /// + public static string get_or_create_assets_dir(string export_dir) + { + var assets_destination_dir = get_assets_dir(export_dir); + Directory.CreateDirectory(assets_destination_dir); + return assets_destination_dir; + } + + /// + /// Return path to asset directory in the SavedModel. + /// + /// + /// + public static string get_assets_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 3a6647880..98cdb274a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -17,12 +17,17 @@ limitations under the License. using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow { - public class saveable_object_util + public static class saveable_object_util { + public class TrackableSaveable: MySaveableObject + { + + } /// /// Returns the variables and names that will be used for a Saver. 
/// @@ -121,5 +126,17 @@ public static Dictionary op_list_to_dict(IVariableV1[] op_list, return names_to_saveables; } + + public static IDictionary saveable_objects_from_trackable(Trackable obj) + { + // TODO: complete the implementation. + return obj.gather_saveables_for_checkpoint(); + } + + public static bool trackable_has_serialize_to_tensor(Trackable obj) + { + // TODO: implement it. + return false; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 79d6dca92..dce0be2ac 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -14,14 +14,38 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.ModelSaving; using static Tensorflow.Binding; namespace Tensorflow.Train { public abstract class Trackable { + /// + /// Corresponding to tensorflow/python/trackable/constants.py + /// + public static class Constants + { + public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; + public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; + public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; + } protected int _self_update_uid; + protected IDictionary _unconditional_dependency_names = + new Dictionary(); + + protected IList _unconditional_checkpoint_dependencies = new List(); + protected IDictionary _self_saveable_object_factories = + new Dictionary(); + public virtual string ObjectIdentifier + { + get => "_generic_user_object"; + } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -73,10 +97,63 @@ protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string n /// /// Initialize dependency management. 
/// <summary>
/// Copies the maps returned by <see cref="map_resources"/> into the maps supplied by the caller.
/// </summary>
/// <param name="object_map">Receives the entries of this object's object map; skipped when null.</param>
/// <param name="tensor_map">Receives the entries of this object's tensor map; skipped when null.</param>
/// <param name="options">Forwarded unchanged to <see cref="map_resources"/>.</param>
/// <returns>The keys of this object's own tensor map.</returns>
/// <remarks>
/// NOTE(review): the collection types in this signature carry no generic type arguments in the
/// reviewed text (they look lost in extraction) and are reproduced exactly as found.
/// </remarks>
public virtual List export_to_saved_model_graph(IDictionary? object_map = null,
    IDictionary? tensor_map = null, SaveOptions? options = null)
{
    var (self_object_map, self_tensor_map) = map_resources(options);

    // Both maps default to null, so a caller may omit them; merging is then simply skipped
    // instead of dereferencing a null reference.
    if (object_map is not null)
    {
        foreach (var pair in self_object_map)
        {
            object_map.Add(pair);
        }
    }

    if (tensor_map is not null)
    {
        foreach (var pair in self_tensor_map)
        {
            tensor_map.Add(pair);
        }
    }

    return self_tensor_map.Keys.ToList();
}
object_path, string local_name) + { + var key_suffix = escape_local_name(local_name); + if (local_name == SERIALIZE_TO_TENSORS_NAME) + { + key_suffix = ""; + } + + return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; + } + + /// + /// Topologically sorts the keys of a map so that dependencies appear first. + /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + /// + /// + /// + public static List order_by_dependency(IDictionary> dependency_map) + { + Dictionary> reverse_dependency_map = new(); + foreach (var pair in dependency_map) + { + foreach (var dep in pair.Value) + { + if (reverse_dependency_map.ContainsKey(dep)) + { + reverse_dependency_map[dep].Add(pair.Key); + } + else + { + reverse_dependency_map[dep] = new HashSet(); + reverse_dependency_map[dep].Add(pair.Key); + } + } + } + + // Validate that all values in the dependency map are also keys. + var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); + if (unknown_keys.Count() > 0) + { + throw new ValueError( + $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); + } + + // Generate the list sorted by objects without dependencies -> dependencies. + // The returned list will reverse this. 
+ List reversed_dependency_arr = new(); + + Queue to_visit = new(); + foreach (var x in dependency_map.Keys) + { + if (!reverse_dependency_map.ContainsKey(x)) + { + to_visit.Enqueue(x); + } + } + + while (to_visit.Count > 0) + { + var x = to_visit.Dequeue(); + reversed_dependency_arr.Add(x); + foreach (var dep in dependency_map[x].Distinct()) + { + var edges = reverse_dependency_map[dep]; + edges.Remove(x); + if (edges.Count == 0) + { + to_visit.Enqueue(dep); + if (!reverse_dependency_map.Remove(dep)) + { + throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); + } + } + } + } + + if (reverse_dependency_map.Count > 0) + { + Dictionary> leftover_dependency_map = new(); + foreach (var pair in reverse_dependency_map) + { + foreach (var x in pair.Value) + { + if (leftover_dependency_map.ContainsKey(x)) + { + leftover_dependency_map[x].Add(pair.Key); + } + else + { + leftover_dependency_map[x] = new List() { pair.Key }; + } + } + } + + throw new CyclicDependencyError(leftover_dependency_map); + } + + reversed_dependency_arr.Reverse(); + return reversed_dependency_arr; + } + + public static string pretty_print_node_path(IEnumerable paths) + { + if (paths.Count() == 0) + { + return "root object"; + } + else + { + return $"root.{string.Join(".", paths.Select(x => x.Name))}"; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index b270ec57d..0a050d0f1 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -2,6 +2,7 @@ using System; using Tensorflow.Eager; using Tensorflow.Variables; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 95e8db577..bf5ae7bee 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -566,5 +566,23 @@ 
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Keras.Engine;

/// <summary>
/// Serialization-related members of <see cref="Layer"/>.
/// </summary>
/// <remarks>
/// NOTE(review): the collection types in this block carry no generic type arguments in the
/// reviewed text (they look lost in extraction) and are reproduced exactly as found.
/// </remarks>
public abstract partial class Layer
{
    /// <summary>Creates a saver bound to this layer; a new instance is built on every access.</summary>
    public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);

    /// <summary>
    /// Identifier reported by the layer's saver.
    /// </summary>
    /// <remarks>
    /// Declared as an override so that a call made through a <see cref="Trackable"/> reference
    /// reaches this implementation rather than the base class default.
    /// </remarks>
    public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;

    /// <summary>Metadata string produced by the layer's saver.</summary>
    public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;

    /// <summary>
    /// Returns the children contributed by the saver (only for <c>SaveType.SAVEDMODEL</c>)
    /// combined with the children reported by the base class.
    /// </summary>
    /// <param name="save_type">Selects whether saver-provided children are included.</param>
    /// <param name="cache">Passed through to the saver and to the base implementation.</param>
    /// <returns>A new dictionary; for a name present in both sources the base class entry is kept.</returns>
    public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null)
    {
        IDictionary children;
        if (save_type == SaveType.SAVEDMODEL)
        {
            // TODO: deal with cache.
            children = TrackableSavedModelSaver.trackable_children(cache);
        }
        else
        {
            children = new Dictionary();
        }

        // Merge through the indexer: ToDictionary over a concatenation throws as soon as the
        // same name appears in both sources, whereas assignment lets the later entry win.
        var merged = children.ToDictionary(x => x.Key, x => x.Value);
        foreach (var pair in base._trackable_children(save_type, cache))
        {
            merged[pair.Key] = pair.Value;
        }

        return merged;
    }
}
signatures = null, + bool save_traces = true) { - saver.save(this, filepath); + if (save_format != "pb") + { + saver.save(this, filepath); + } + else + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 162d06c57..835f6041b 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -35,6 +35,12 @@ public partial class Model : Layer, IModel bool _base_model_initialized; bool stop_training; DataHandler data_handler; + + public OptimizerV2 Optimizer + { + get => optimizer; + set => optimizer = value; + } public Model(ModelArgs args) : base(args) diff --git a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs index 61cec6468..f29f2dec3 100644 --- a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs +++ b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs @@ -194,6 +194,18 @@ public SavedObject() { OnConstruction(); } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(int nodeId, string nodePath, + global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) + { + OnConstruction(); + nodeId_ = nodeId; + nodePath_ = nodePath; + identifier_ = identifier; + metadata_ = metadata; + version_ = version; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Keras/Protobuf/Versions.cs b/src/TensorFlowNET.Keras/Protobuf/Versions.cs index 40405a5a6..ff9a23c62 100644 --- a/src/TensorFlowNET.Keras/Protobuf/Versions.cs +++ b/src/TensorFlowNET.Keras/Protobuf/Versions.cs @@ -74,6 +74,13 @@ public sealed partial class VersionDef : pb::IMessage { public VersionDef() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef(int producer, 
int minConsumer) { + OnConstruction(); + producer_ = producer; + minConsumer_ = minConsumer; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs new file mode 100644 index 000000000..ea6853fde --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public static class Constants +{ + /// + /// Namespace used to store all attributes added during serialization. + /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an + /// object loaded from `tf.saved_model.load()`. + /// + public static readonly string KERAS_ATTR = "keras_api"; + /// + /// Keys for the serialization cache. + /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} + /// + public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; + /// + /// Name of Keras metadata file stored in the SavedModel. 
+ /// + public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; + + public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; + public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; + public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; + public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; + public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; + public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; + public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; + + public static readonly IList KERAS_OBJECT_IDENTIFIERS = new List() + { + INPUT_LAYER_IDENTIFIER, + LAYER_IDENTIFIER, + METRIC_IDENTIFIER, + MODEL_IDENTIFIER, + NETWORK_IDENTIFIER, + RNN_LAYER_IDENTIFIER, + SEQUENTIAL_IDENTIFIER + }; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs new file mode 100644 index 000000000..a5f315bb3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Saving.SavedModel; + +public class KerasObjectWrapper +{ + +} + +public class KerasObjectWrapper +{ + public T Item { get; set; } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs new file mode 100644 index 000000000..76453ca0d --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Google.Protobuf; +using ICSharpCode.SharpZipLib.Zip; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using 
/// <summary>
/// Entry points for exporting a Keras model together with its Keras metadata file.
/// </summary>
/// <remarks>
/// NOTE(review): the collection types in this block carry no generic type arguments in the
/// reviewed text (they look lost in extraction) and are reproduced exactly as found.
/// </remarks>
public partial class KerasSavedModelUtils
{
    /// <summary>
    /// Saves <paramref name="model"/> under <paramref name="filepath"/> and writes the Keras
    /// metadata file next to it.
    /// </summary>
    /// <param name="model">The model to export.</param>
    /// <param name="filepath">Target location; it is used as a directory when the metadata file is written.</param>
    /// <param name="overwrite">When false, an existing file or directory at the target aborts the save.</param>
    /// <param name="include_optimizer">When false, the optimizer is detached for the duration of the save.</param>
    /// <param name="signatures">Forwarded to <c>save_and_return_nodes</c>.</param>
    /// <param name="options">Forwarded to <c>save_and_return_nodes</c>.</param>
    /// <param name="save_traces">Value installed through <c>keras_option_scope</c> while saving.</param>
    /// <exception cref="IOException">The target exists and <paramref name="overwrite"/> is false.</exception>
    /// <exception cref="NotImplementedException">Serialization would have to be skipped for the model.</exception>
    public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures,
        SaveOptions? options, bool save_traces = true)
    {
        // filepath is treated as a directory further down, so look for either kind of entry.
        if (!overwrite && (File.Exists(filepath) || Directory.Exists(filepath)))
        {
            throw new IOException("The file already exists but is not allowed to overwrite it.");
        }

        if (save_traces && should_skip_serialization(model))
        {
            throw new NotImplementedException();
        }

        OptimizerV2? orig_optimizer = null;
        if (!include_optimizer)
        {
            orig_optimizer = model.Optimizer;
            model.Optimizer = null;
            model._delete_tracking("optimizer");
        }

        try
        {
            IList saved_nodes;
            IDictionary> node_paths;
            // skip two scopes of python
            using (KerasSavedModelUtils.keras_option_scope(save_traces))
            {
                (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
            }

            var metadata = generate_keras_metadata(saved_nodes, node_paths);

            // FileMode.Create truncates an existing file, so no stale bytes survive a shorter write.
            // The message is written directly to the stream in protobuf wire format; nothing is
            // left sitting in an unflushed text writer.
            using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.Create,
                       FileAccess.Write))
            {
                metadata.WriteTo(f);
            }
        }
        finally
        {
            // Give the optimizer back even when saving failed part-way.
            if (!include_optimizer)
            {
                model.Optimizer = orig_optimizer!;
            }
        }
    }

    /// <summary>
    /// Builds the metadata message describing every <see cref="Layer"/> among the saved nodes.
    /// </summary>
    /// <param name="saved_nodes">Saved objects; the index of a node becomes its node id.</param>
    /// <param name="node_paths">Path of references leading to each node, looked up by node.</param>
    /// <returns>A message with one entry per layer; nodes that are not layers are left out.</returns>
    public static SavedMetadata generate_keras_metadata(IList saved_nodes,
        IDictionary> node_paths)
    {
        var metadata = new SavedMetadata();
        for (int i = 0; i < saved_nodes.Count; i++)
        {
            var node = saved_nodes[i];
            if (node is not Layer layer)
            {
                continue;
            }

            var path = node_paths[node];
            // A missing or empty path means the node is the root itself.
            string node_path = path is null || !path.Any()
                ? "root"
                : $"root.{string.Join(".", path.Select(x => x.Name))}";

            ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
            {
                NodeId = i,
                NodePath = node_path,
                Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
                {
                    Producer = 2,
                    MinConsumer = 1
                },
                Identifier = layer.ObjectIdentifier,
                Metadata = layer.TrackingMetadata
            };

            // Without this the entry built above would be discarded and the result stay empty.
            metadata.Nodes.Add(saved_object);
        }

        return metadata;
    }
}
IDictionary trackable_children(IDictionary? serialization_cache) + { + if (!KerasSavedModelUtils.ShouldHaveTraces) + { + return new Dictionary(); + } + + var children = objects_to_serialize(serialization_cache); + + return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) + .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + .ToDictionary(x => x.Key, x => x.Value); + } + +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs new file mode 100644 index 000000000..ade8ae73e --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class LayerSavedModelSaver: SavedModelSaver +{ + private Layer _obj; + public LayerSavedModelSaver(Layer obj): base(obj) + { + _obj = obj; + } + public override string ObjectIdentifier + { + get => Constants.LAYER_IDENTIFIER; + } + + public override IDictionary objects_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override IDictionary functions_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override string TrackingMetadata + { + get + { + JObject metadata = new JObject(); + metadata["name"] = _obj.Name; + metadata["trainable"] = _obj.Trainable; + // metadata["expects_training_arg"] = _obj._expects_training_arg; + // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + // metadata["stateful"] = _obj.stateful; + // metadata["must_restore_from_config"] = _obj.must_restore_from_config; 
+ // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; + metadata["autocast"] = _obj.AutoCast; + + metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + { + // Handle conflicts by using values from obj2 + MergeArrayHandling = MergeArrayHandling.Merge + }); + // skip the check of `input_spec` and `build_input_shape` for the lack of members. + // skip the check of `activity_regularizer` for the type problem. + return metadata.ToString(); + } + } + + public static LayerConfig get_serialized(Layer obj) + { + return generic_utils.serialize_keras_object(obj); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs new file mode 100644 index 000000000..30e895827 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -0,0 +1,33 @@ +using System; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public partial class KerasSavedModelUtils +{ + public static bool ShouldHaveTraces { get; internal set; } + + public static SaveOptionsContext keras_option_scope(bool save_traces) + { + var res = new SaveOptionsContext(ShouldHaveTraces); + ShouldHaveTraces = save_traces; + return res; + } +} + +/// +/// Implementation of this class is different with that of python. +/// But it could be used with `using` the same as `with` of python. 
/// <remarks>
/// Remembers a value for <c>KerasSavedModelUtils.ShouldHaveTraces</c> and writes it back
/// when the context is disposed.
/// </remarks>
public class SaveOptionsContext: IDisposable
{
    /// <summary>The value written back on <see cref="Dispose"/>.</summary>
    public bool _old_value;

    /// <summary>Captures the value to reinstate later.</summary>
    /// <param name="old_value">Value assigned to <c>ShouldHaveTraces</c> on disposal.</param>
    public SaveOptionsContext(bool old_value)
    {
        // Keep the caller's value; a fixed constant here would make disposal ignore
        // whatever setting was active before the scope was opened.
        _old_value = old_value;
    }

    /// <summary>Assigns the captured value back to <c>KerasSavedModelUtils.ShouldHaveTraces</c>.</summary>
    public void Dispose()
    {
        KerasSavedModelUtils.ShouldHaveTraces = _old_value;
    }
}
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("", save_format:"pb"); + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 36ff4a3dd..56c212d0e 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -47,7 +47,7 @@ - + From c4114d5f1815a5281ff2a607dd51e43f17e8a23b Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Mon, 23 Jan 2023 16:25:38 +0800 Subject: [PATCH 03/15] Add more facilities to the saved model framework. --- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 253 ++++++++++++++++++ .../Checkpoint/SaveUtilV1.cs | 18 +- .../Checkpoint/SaveableCompat.cs | 16 ++ .../Checkpoint/TrackableSaver.cs | 109 -------- .../Checkpoint/TrackableView.cs | 8 +- .../Checkpoint/checkpoint.cs | 191 +++++++++++++ .../Checkpoint/functional_saver.cs | 36 +++ .../Protobuf/TrackableObjectGraph.cs | 10 + .../Training/AutoTrackable.cs | 51 +++- .../Training/Saving/SaveSpec.cs | 2 +- .../Training/Saving/SavedModel/save.cs | 52 ++-- .../Saving/saveable_object_util.py.cs | 50 ++++ src/TensorFlowNET.Core/Training/Trackable.cs | 42 ++- 13 files changed, 685 insertions(+), 153 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs delete mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/checkpoint.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/functional_saver.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs new file mode 100644 index 000000000..dc2a92fb0 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -0,0 +1,253 @@ +using System; +using System.Collections.Generic; +using 
/// <summary>
/// Gathers the trackables of <paramref name="graph_view"/>, fills the object graph proto for
/// them and collects the tensors and registered savers to be written.
/// </summary>
/// <param name="graph_view">The object graph to serialize.</param>
/// <param name="object_map">Forwarded to <c>gather_trackable_data</c>.</param>
/// <param name="call_with_mapped_captures">Forwarded to <c>get_and_write_tensors_to_serialize</c>.</param>
/// <param name="cache">Only null is supported at present.</param>
/// <returns>
/// The serialized tensors, the feed additions (always null here), the registered savers
/// and the filled object graph proto.
/// </returns>
/// <exception cref="NotImplementedException"><paramref name="cache"/> is not null.</exception>
/// <remarks>
/// NOTE(review): the collection types in this block carry no generic type arguments in the
/// reviewed text (they look lost in extraction) and are reproduced exactly as found.
/// </remarks>
public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph)
    serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
    var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
    var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

    var object_graph_proto = fill_object_graph_proto(trackable_data);

    var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
    var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

    Dictionary feed_additions;
    if (cache is null)
    {
        feed_additions = null;
        serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
            cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
    }
    else
    {
        // TODO: deal with cache.
        // The cached path does not exist yet; report that as "not implemented" rather than
        // with an arithmetic exception type.
        throw new NotImplementedException();
    }

    CheckPointUtils.add_checkpoint_values_check(object_graph_proto);

    return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
}
object_map) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach(var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + Dictionary node_ids = new(); + for(int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + List trackable_data = new(); + foreach(var trackable in trackable_objects) + { + pbc::RepeatedField children_proto = new(); + foreach(var child in graph_view.list_children(trackable)) + { + children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { + NodeId = node_ids[child.Refer], + LocalName = child.Name + }); + } + slot_variables.TryGetValue(trackable, out var slot_variable); + trackable_data.Add(new TrackableData( + trackable: trackable, + node_id: node_ids[trackable], + object_name: object_names[trackable], + children_proto: children_proto, + slot_variable_proto: slot_variable??new pbc.RepeatedField(), + object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map) + )); + } + return (trackable_data, node_ids); + } + + private static TrackableObjectGraph fill_object_graph_proto(IList trackable_data) + { + TrackableObjectGraph object_graph_proto = new(); + for(int i = 0; i < trackable_data.Count; i++) + { + var td = trackable_data[i]; + Debug.Assert(td.node_id == i); + object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); + } + return object_graph_proto; + } + + /// + /// Creates dictionary of tensors to checkpoint, and updates the proto. + /// + /// + /// + /// + /// + /// + private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + bool call_with_mapped_captures, object? 
cache, TrackableObjectGraph object_graph_proto) + { + Dictionary> serialized_tensors = new(); + foreach(var td in tensor_trackables) + { + // TODO: deal with cache. + var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; + var trackable = td.object_to_save; + IDictionary tensor_dict; + if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) + { + (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); + } + else + { + tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + } + if(trackable is not null) + { + serialized_tensors[trackable] = tensor_dict; + } + else + { + serialized_tensors[Trackable.None] = tensor_dict; + } + } + return serialized_tensors; + } + + private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + var trackable = trackable_data.object_to_save; + + // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. 
+ IDictionary ret_tensor_dict; + if (call_with_mapped_captures) + { + throw new NotImplementedException(); + } + else + { + ret_tensor_dict = trackable.serialize_to_tensors(); + } + + // TODO: revise the types and complete it + Dictionary tensor_dict = new(); + foreach(var pair in ret_tensor_dict) + { + var local_name = TrackableUtils.escape_local_name(pair.Key); + var maybe_tensor = pair.Value; + var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name); + + tensor_dict[checkpoint_key] = maybe_tensor; + + if(maybe_tensor is SaveSpec) + { + ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + } + + if(object_graph_proto is not null) + { + object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { + Name = local_name, + CheckpointKey = checkpoint_key, + FullName = CheckPointUtils.get_full_name(trackable) + }); + } + } + return tensor_dict; + } + + /// + /// Gets tensors to serialize from a Trackable with legacy SaveableObjects. 
+ /// + /// + /// + /// + /// + /// + private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + { + Dictionary object_names = new(); + object_names[trackable_data.trackable] = trackable_data.object_name; + Dictionary object_map = new(); + object_map[trackable_data.trackable] = trackable_data.object_to_save; + + var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map); + var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, + call_with_mapped_captures, saveables_cache: null); + var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects); + return (trackable, trackable.serialize_to_tensors()); + } + + private static IDictionary> get_and_write_registered_savers(IDictionary> registered_trackables, TrackableObjectGraph object_graph_proto) + { + Dictionary> registered_savers = new(); + foreach(var pair in registered_trackables) + { + foreach(var td in pair.Value) + { + if (registered_savers.ContainsKey(pair.Key)) + { + registered_savers[pair.Key] = new Dictionary(); + } + else + { + registered_savers[pair.Key][td.object_name] = td.object_to_save; + } + + var object_proto = object_graph_proto.Nodes[td.node_id]; + // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`. + } + } + return registered_savers; + } + + private static (IList, IList, IDictionary>) split_trackables(IEnumerable trackable_data) + { + List tensor_trackables = new(); + List py_state_trackables = new(); // skip the process of `PyState` for the lack of API. This is only a pleceholder. + Dictionary> registered_trackables = new(); + + foreach(var td in trackable_data) + { + // TODO: deal with registration. 
+ tensor_trackables.Add(td); + } + return (tensor_trackables, py_state_trackables, registered_trackables); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 7724c6b70..44fa5c5de 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -7,6 +7,7 @@ using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Google.Protobuf; namespace Tensorflow.Checkpoint; @@ -47,19 +48,16 @@ public static (List, object?) frozen_saveables_and_savers(Obje IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { - - Graph target_context; if (to_graph is not null) { - using (to_graph.as_default()) - { - var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + to_graph.as_default(); + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); - named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - return (named_saveable_objects, registered_savers); - } + // tensorflow python: `with ops.device("/cpu:0")` + var serialized = graph_proto.ToByteString().ToString(); + var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); } else { diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs new file mode 100644 index 000000000..fa441d799 --- 
/dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint +{ + internal static class SaveableCompat + { + public static string? get_saveable_name(Trackable cls_or_obj) + { + // TODO: implement it with Attribute. + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs deleted file mode 100644 index 7d101d5e5..000000000 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs +++ /dev/null @@ -1,109 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Contexts; -using Tensorflow.Eager; - -namespace Tensorflow.Checkpoint; - -public class TrackableSaver -{ - private ObjectGraphView _graph_view; - private EagerTensor _cached_save_operation; - private TrackableObjectGraph _last_save_object_graph; - private Tensor? _object_graph_feed_tensor = null; - private Tensor? _file_prefix_feed_tensor = null; - public TrackableSaver(ObjectGraphView graph_view) - { - _graph_view = graph_view; - - // TODO: cache when not executing eagerly. - // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, - // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` - - } - - private void gather_serialized_tensors(Tensor? object_graph_tensor = null) - { - throw new NotImplementedException(); - } - - private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) - { - throw new NotImplementedException(); - } - - // TODO: parameter write_done_callback - public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, - CheckpointOptions? 
options = null) - { - if (options is null) - { - options = new CheckpointOptions(); - } - - Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); - if (checkpoint_number is not null) - { - file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; - } - - Tensor file_prefix_tensor; - Tensor object_graph_tensor; - if (use_session) - { - if (_object_graph_feed_tensor is null) - { - // In python there is `with ops.device("/cpu:0")`. - _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); - _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); - } - - object_graph_tensor = _object_graph_feed_tensor; - file_prefix_tensor = _file_prefix_feed_tensor; - feed_dict[file_prefix_tensor] = file_prefix; - } - else - { - // In python there is `with ops.device("/cpu:0")`. - file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); - object_graph_tensor = null; - } - - var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); - - if (new_feed_additions is not null) - { - foreach (var pair in new_feed_additions) - { - feed_dict.Add(pair.Key, pair.Value); - } - } - if(!use_session) - { - session = null; - } - else if (session is null) - { - session = new Session(); // In python it uses `get_session`. - } - - if (session is not null) - { - var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); - return session.run((Tensor)save_path, s); - } - else if (use_session) - { - throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + - "in graph mode without a default session. 
Please use " + - "`with tf.Session():` to create a session."); - } - else - { - return save_path; - } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index ed1f3ec47..6d81d2c9a 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -21,9 +21,14 @@ public TrackableView(WeakReference obj) public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) { obj._maybe_initialize_trackable(); + Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - return obj._trackable_children(save_type); + foreach(var pair in obj._trackable_children(save_type)) + { + children[pair.Key] = pair.Value; + } + return children; } public Trackable Root @@ -50,6 +55,7 @@ public Trackable Root { List bfs_sorted = new(); Queue to_visit = new(); + to_visit.Enqueue(Root); Dictionary> node_paths = new(); node_paths[this.Root] = new List(); while (!to_visit.empty()) diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs new file mode 100644 index 000000000..791094899 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -0,0 +1,191 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Train; +using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +/// +/// Saves and restores a `Trackable` object and its dependencies. 
+/// +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private Tensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? _file_prefix_feed_tensor = null; + private Dictionary? _object_map = null; + private object? _cache = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); + + // TODO: cache. + + if(object_graph_tensor is null) + { + // tensorflow python: `with ops.device("/cpu:0"):` + object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + } + else + { + feed_additions[object_graph_tensor] = graph_proto.ToString(); + } + Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + if (serialized_tensors.ContainsKey(Trackable.None)) + { + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + } + return (serialized_tensors, feed_additions, registered_savers, graph_proto); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if 
(_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] { save_op })) + { + _cached_save_operation = array_ops.identity(file_prefix); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); + + Func<(Tensor, IDictionary)> run_save = () => + { + if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) + { + var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); + var save_op = saver.save(file_prefix, options); + + // tensorflow python: `with ops.device("/cpu:0"):` + using (ops.control_dependencies(new object[] {save_op} )) + { + _cached_save_operation = array_ops.identity(tf.constant(file_prefix)); + } + _last_save_object_graph = graph_proto; + } + return (_cached_save_operation, feed_additions); + }; + + if (options.experimental_enable_async_checkpoint) + { + throw new NotImplementedException(); + } + + return run_save(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? 
options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. + file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. 
Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs new file mode 100644 index 000000000..759cbd663 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -0,0 +1,36 @@ +using System; +using System.Buffers.Text; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.OptimizerOptions.Types; + +namespace Tensorflow.Checkpoint +{ + /// + /// Saves checkpoints directly from multiple devices. + /// Note that this is a low-level utility which stores Tensors in the keys + /// specified by `SaveableObject`s.Higher-level utilities for object-based + /// checkpointing are built on top of it. + /// + public class MultiDeviceSaver + { + public MultiDeviceSaver(IDictionary> serialized_tensors, + IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) + { + + } + + public Operation? save(string file_prefix, CheckpointOptions? options= null) + { + throw new NotImplementedException(); + } + + public Operation? save(Tensor file_prefix, CheckpointOptions? 
options = null) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 934136671..fb197eca2 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -205,6 +205,16 @@ public TrackableObject(pbc::RepeatedField slot, + pbc::RepeatedField children + ) + { + OnConstruction(); + slotVariables_ = slot; + children_ = children; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d8f6314bc..6f10fd2e5 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,4 +1,10 @@ -namespace Tensorflow.Train +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Train { public abstract class AutoTrackable : Trackable { @@ -17,5 +23,48 @@ public void _delete_tracking(string name) } } } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if(save_type != SaveType.SAVEDMODEL) + { + return base._trackable_children(save_type, cache); + } + + Dictionary functions = new(); + // TODO: process of logs. + var properties = this.GetType().GetProperties(); + foreach ( var property in properties ) + { + string name = property.Name; + object value = property.GetValue(this, null); + if(value is Function || value is ConcreteFunction) + { + functions[name] = (Trackable)value; + } + } + + // TODO: process the type `core_types.GenericFunction`. 
+ + Dictionary children = new(); + foreach(var pair in CheckpointDependencies) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) // or Generic function + { + continue; + } + if(functions.ContainsKey(name) && functions[name] != child) + { + throw new ValueError($"Can't save object because it has multiple children with the same " + + $"name. Object: {this}, attribute name: {name}, child 1: " + + $"{child}, child 2: {functions[name]}"); + } + children[name] = child; + } + + return children.Concat(functions).ToDictionary(x => x.Key, x => x.Value); + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs index 1ae912ce6..393a6a981 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -28,7 +28,7 @@ public class SaveSpec public string slice_spec => _slice_spec; private string _name; - public string name => _name; + public string name { get => _name; set => _name = value; } private TF_DataType _dtype; public TF_DataType dtype => _dtype; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index 692356054..cc8399528 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -134,35 +134,33 @@ private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_d Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - using (var g = exported_graph.as_default()) + exported_graph.as_default(); + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. + if (save_custom_gradients) { - (object_map, tensor_map, asset_info) = saveable_view.map_resources(); - // TODO: deal with signatures. - if (save_custom_gradients) - { - // TODO: trace gradient functions. 
- } + // TODO: trace gradient functions. + } - foreach (var resource_initializer_function in resource_initializers) - { - // List asset_dependencies = new(); - // TODO: deal with initializers - } - - // using(ops.control_dependencies(...)) - var init_op = control_flow_ops.no_op(); - if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); - } - else - { - meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); - } - // Lack `CopyFrom` API - // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers } - + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); @@ -180,11 +178,13 @@ private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_d verify_ops(graph_def, namespace_whitelist); meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef = new(); meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; // TODO: add git version. 
meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList = new(); meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 98cdb274a..622eed3a7 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -138,5 +138,55 @@ public static bool trackable_has_serialize_to_tensor(Trackable obj) // TODO: implement it. return false; } + + internal static string convert_to_string(string x) + { + return tf.compat.as_str(x); + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private Trackable _obj; + private IList _saveables; + public SaveableCompatibilityConverter(Trackable obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public Trackable Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary serialize_to_tensors() + { + return saveable_objects_to_tensor_dict(_saveables); + } + + /// + /// Converts a list of SaveableObjects to a tensor dictionary. 
+ /// + /// + public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + { + Dictionary tensor_dict = new(); + foreach (var saveable in saveables) + { + foreach(var spec in saveable.specs) + { + var name = saveable_object_util.convert_to_string(spec.name); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + if (!string.IsNullOrEmpty(slice_spec)) + { + throw new NotImplementedException(); + } + else + { + tensor_dict[name] = spec.tensor; + } + } + } + return tensor_dict; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index dce0be2ac..b98075d3f 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -34,18 +34,35 @@ public static class Constants public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; } protected int _self_update_uid; - protected IDictionary _unconditional_dependency_names = - new Dictionary(); + protected IDictionary _unconditional_dependency_names; - protected IList _unconditional_checkpoint_dependencies = new List(); + protected IList _unconditional_checkpoint_dependencies; protected IDictionary _self_saveable_object_factories = new Dictionary(); + + private static Trackable _none = new Function(); + /// + /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. + /// The `None` can be any object that inherits `Trackable`. + /// This Property is supposed to be used only internal. 
+ /// + public static Trackable None + { + get + { + return _none; + } + } public virtual string ObjectIdentifier { get => "_generic_user_object"; } - + public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } + public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } + public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } + public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -99,8 +116,9 @@ protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string n /// public void _maybe_initialize_trackable() { - // _self_unconditional_checkpoint_dependencies = [] _self_update_uid = -1; + _unconditional_checkpoint_dependencies = new List(); + _unconditional_dependency_names = new Dictionary(); } // TODO: cache @@ -153,6 +171,20 @@ public virtual IDictionary gather_saveables_for_checkp { return _self_saveable_object_factories; } + + /// + /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors` + /// if you are defining a custom resource or variable with custom ops. + /// Otherwise, please store the state of your trackable in `tf.Variable` objects + /// and add them to Trackable object hierarchy using `setattr` (for subclasses + /// of `AutoTrackable`) or overriding the `_trackable_children` method. + /// + /// + /// + public virtual IDictionary serialize_to_tensors() + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer); From ddd06ab9b6d1bc18229630d98d5f062658b768a9 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 24 Jan 2023 18:01:22 +0800 Subject: [PATCH 04/15] Add ListWrapper and ITrackable, and revise implmentations. 
--- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 3 +- .../Operations/NnOps/RNNCell.cs | 3 + src/TensorFlowNET.Core/Training/ITrackable.cs | 12 + src/TensorFlowNET.Core/Training/LayerUtils.cs | 9 + src/TensorFlowNET.Core/Training/Trackable.cs | 50 ++- .../Training/data_structures.cs | 364 ++++++++++++++++++ .../Variables/BaseResourceVariable.cs | 4 +- .../Variables/IVariableV1.cs | 1 + .../Variables/RefVariable.cs | 1 + src/TensorFlowNET.Keras/Engine/Functional.cs | 31 ++ src/TensorFlowNET.Keras/Engine/Model.cs | 11 + .../Saving/SavedModel/layer_serialization.cs | 12 + .../SavedModel/serialized_attributes.cs | 14 + .../Saving/SavedModel/utils.cs | 4 +- 14 files changed, 513 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Core/Training/ITrackable.cs create mode 100644 src/TensorFlowNET.Core/Training/LayerUtils.cs create mode 100644 src/TensorFlowNET.Core/Training/data_structures.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f77b4a86d..f1ca56325 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,10 +1,11 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer + public interface ILayer: ITrackable { string Name { get; } bool Trainable { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 04fdc7e57..734f26089 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -21,6 +21,7 @@ limitations under the License. 
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; using Tensorflow.Operations; +using Tensorflow.Train; using Tensorflow.Util; using static Tensorflow.Binding; @@ -147,5 +148,7 @@ public LayerArgs get_config() { throw new NotImplementedException(); } + + public Trackable GetTrackable() { throw new NotImplementedException(); } } } diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/ITrackable.cs new file mode 100644 index 000000000..e4ef2c8fc --- /dev/null +++ b/src/TensorFlowNET.Core/Training/ITrackable.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + public interface ITrackable + { + Trackable GetTrackable(); + } +} diff --git a/src/TensorFlowNET.Core/Training/LayerUtils.cs b/src/TensorFlowNET.Core/Training/LayerUtils.cs new file mode 100644 index 000000000..211419651 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/LayerUtils.cs @@ -0,0 +1,9 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Training +{ + +} diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index b98075d3f..2646fb8d5 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -18,11 +18,12 @@ limitations under the License. 
using System.Collections.Generic; using System.Linq; using Tensorflow.ModelSaving; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class Trackable + public abstract class Trackable: ITrackable { /// /// Corresponding to tensorflow/python/trackable/constants.py @@ -40,6 +41,7 @@ public static class Constants protected IDictionary _self_saveable_object_factories = new Dictionary(); + private bool _manual_tracking = true; private static Trackable _none = new Function(); /// @@ -54,6 +56,10 @@ public static Trackable None return _none; } } + public Trackable GetTrackable() + { + return this; + } public virtual string ObjectIdentifier { get => "_generic_user_object"; @@ -128,6 +134,48 @@ public virtual IDictionary _trackable_children(SaveType save_ return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); } + public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false) + { + _maybe_initialize_trackable(); + if (!_manual_tracking) return trackable; + var new_reference = new TrackableReference(name, trackable); + var current_object = _lookupup_dependency(name); + + if(current_object is null) + { + _unconditional_checkpoint_dependencies.Add(new_reference); + _handle_deferred_dependencies(name, trackable); + } + _unconditional_dependency_names[name] = trackable; + return trackable; + } + + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for + /// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are + /// considered fulfilled and so are deleted. + /// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations. 
+ /// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves, + /// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context + /// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs). + /// + /// + /// + public virtual void _handle_deferred_dependencies(string name, Trackable trackable) + { + //_maybe_initialize_trackable(); + //trackable._maybe_initialize_trackable(); + + // TODO: complete the implementation. + } + + public virtual Trackable? _lookupup_dependency(string name) + { + if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; + else return null; + } + public static Trackable convert_to_trackable(object obj, object? parent = null) { if (obj is Trackable) diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs new file mode 100644 index 000000000..4cb78181b --- /dev/null +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -0,0 +1,364 @@ +using Google.Protobuf; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO.Compression; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow.Functions; +using Tensorflow.Keras; +using Tensorflow.Operations.Activation; +using Tensorflow.Train; +using static Tensorflow.ApiDef.Types; + +namespace Tensorflow.Training +{ + public class NoDependency + { + public Trackable Value { get; set; } + public NoDependency(Trackable value) + { + Value = value; + } + } + + public abstract class TrackableDataStructure : Trackable + { + private bool _self_trainable; + private List _self_extra_variables; + + public TrackableDataStructure() + { + _self_trainable = true; + _self_extra_variables = new List(); + } + + public abstract 
IEnumerable Values { get; } + public bool Trainable { get => _self_trainable; set => _self_trainable = value; } + public IEnumerable Layers + { + get + { + List collected = new(); + foreach(var obj in Values) + { + if(obj is ILayer) + { + collected.Add((ILayer)obj); + } + else if(obj is TrackableDataStructure) + { + collected.AddRange((obj as TrackableDataStructure).Layers); + } + } + return collected; + } + } + public IEnumerable TrainableWeights + { + get + { + if (!_self_trainable) + { + return new List(); + } + List trainable_variables = new(); + foreach (var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + foreach(var v in _self_extra_variables) + { + if (v.Trainable) + { + trainable_variables.Add(v); + } + } + return trainable_variables; + } + } + public IEnumerable NonTrainableWeights + { + get + { + var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); + var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); + List non_trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. + if (obj is TrackableDataStructure) + { + non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables); + } + } + + if (!_self_trainable) + { + // Return order is all trainable vars, then all non-trainable vars. + List trainable_variables = new(); + foreach(var obj in Values) + { + // skip the process of `module.Module`. 
+ if (obj is TrackableDataStructure) + { + trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); + } + } + return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables); + } + else + { + return non_trainable_variables.concat(non_trainable_extra_variables); + } + } + } + public IEnumerable Weights => TrainableWeights.Concat(NonTrainableWeights); + public IEnumerable TrainableVariables => TrainableWeights; + public IEnumerable NonTrainableVariables => NonTrainableWeights; + public IEnumerable Variables => Weights; + + // TODO: `losses` property. + + /// + /// Add a dependency on `value`. + /// + /// + /// + protected virtual Trackable _track_value(Trackable value, string name) + { + value = sticky_attribute_assignment(this, name, value); + if(value is IVariableV1) + { + _self_extra_variables.Add(value as IVariableV1); + } + // skip the left process (need to be done in the future). + return value; + } + + protected static Trackable wrap_or_unwrap(NoDependency value) + { + return value.Value; + } + + protected static Trackable wrap_or_unwrap(Trackable value) + { + return value; + } + + protected static Trackable wrap_or_unwrap(IList value) + { + return new ListWrapper(value); + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) + { + value = wrap_or_unwrap(value); + trackable._track_trackable(value, name, true); + return value; + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) + { + var wrapped_value = wrap_or_unwrap(value); + trackable._track_trackable(wrapped_value, name, true); + return wrapped_value; + } + + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList value) + { + var wrapped_value = wrap_or_unwrap(value); + trackable._track_trackable(wrapped_value, name, true); + return wrapped_value; + } + } 
+ + public class ListWrapper : TrackableDataStructure, IList, ICloneable + { + private IList _storage; + private bool _non_append_mutation_value; + private bool _external_modification_value; + private IList _last_wrapped_list_snapshot; + /// + /// + /// + /// The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be + /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save. + public ListWrapper(IList wrapped_list) + { + _storage = wrapped_list; + _non_append_mutation_value = _external_modification_value = false; + _last_wrapped_list_snapshot = new List(_storage); + } + + protected bool NonAppendMuation { + get => _non_append_mutation_value; + set + { + // TODO: deal with `attribute_sentinel`. + _non_append_mutation_value = value; + } + } + + protected bool ExternalModification + { + get => _external_modification_value; + set + { + // TODO: deal with `attribute_sentinel`. + _external_modification_value = value; + } + } + + public override IEnumerable Values => this; + public bool IsReadOnly { get => _storage.IsReadOnly; } + + /// + /// Checks for any changes to the wrapped list not through the wrapper. + /// + private void check_external_modification() + { + if (_external_modification_value || _non_append_mutation_value) return; + if (!_storage.SequenceEqual(_last_wrapped_list_snapshot)) + { + _external_modification_value = true; + } + } + + private void update_snapshot() + { + // TODO: deal with `attribute_sentinel`. + if (_external_modification_value || _non_append_mutation_value) return; + _last_wrapped_list_snapshot = new List(_storage); + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? 
cache = null) + { + check_external_modification(); + if (_non_append_mutation_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" + + $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." + + $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored."); + } + if (_external_modification_value) + { + throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " + + $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " + + $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored."); + } + var children = base._trackable_children(save_type, cache); + + if(save_type == SaveType.SAVEDMODEL) + { + children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); + } + + return children; + } + + private bool has_mutation_or_trackable() + { + return _non_append_mutation_value; + } + + /// + /// Allows storage of non-trackable objects. 
+ /// + /// + /// + /// + protected override Trackable _track_value(Trackable value, string name) + { + try + { + base._track_value(value, name); + } + catch(ValueError ex) + { + value = sticky_attribute_assignment(this, name, value); + } + return value; + } + + public object Clone() + { + var res = new ListWrapper(_storage.Select(x => x).ToList()); + res.NonAppendMuation= _non_append_mutation_value; + res.ExternalModification = _external_modification_value; + return res; + } + + public Trackable this[int index] { + get => _storage[index]; + set + { + // skip the process of `Slice`, maybe support it in the future. + _non_append_mutation_value = true; + _storage[index] = _track_value(value, _name_element(index)); + + update_snapshot(); + } + } + + public int IndexOf(Trackable item) => _storage.IndexOf(item); + + public void Insert(int index, Trackable item) + { + check_external_modification(); + _non_append_mutation_value = true; + _storage.Insert(index, item); + update_snapshot(); + } + + public void RemoveAt(int index) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + _storage.RemoveAt(index); + update_snapshot(); + } + + public int Count { get => _storage.Count; } + + public void Add(Trackable item) + { + check_external_modification(); + _storage.Add(item); + update_snapshot(); + } + + public void Clear() => _storage.Clear(); + + public bool Contains(Trackable item) => _storage.Contains(item); + + public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex); + + public bool Remove(Trackable item) + { + check_external_modification(); + if (has_mutation_or_trackable()) + { + _non_append_mutation_value = true; + } + var res = _storage.Remove(item); + update_snapshot(); + return res; + } + + public IEnumerator GetEnumerator() => _storage.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); + + protected string _name_element(int 
index) => $"{index}"; + } +} diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 0a050d0f1..4526730fa 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -22,7 +22,7 @@ public class BaseResourceVariable : DisposableObject protected bool _in_graph_mode; protected bool _trainable; - public bool trainable => _trainable; + public bool Trainable => _trainable; protected Tensor _initial_value; @@ -166,7 +166,7 @@ IVariableV1 _lazy_read(Operation op, Tensor value) /// void variable_accessed(BaseResourceVariable variable) { - if (variable.trainable) + if (variable.Trainable) { foreach (var tape in tf.GetTapeSet()) tape.VariableAccessed(variable as ResourceVariable); diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index f4f716c3c..3eb78153a 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -46,6 +46,7 @@ public interface IVariableV1 Graph Graph { get; } TF_DataType dtype { get; } Shape shape { get; } + bool Trainable { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true); IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 67c12c427..38b5b7345 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -56,6 +56,7 @@ public partial class RefVariable : IVariableV1, IProtoBuf _variable.name; public Tensor eval() => _variable; + public bool Trainable => _trainable; public RefVariable(object initial_value = null, bool trainable = true, diff --git 
a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 09a31b948..61a8956a6 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -3,6 +3,7 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Utils; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -20,6 +21,30 @@ public partial class Functional : Model Dictionary tensor_usage_count; + /// + /// Dictionary of layer dependencies to be included in the checkpoint. + /// + public IDictionary LayerCheckpointDependencies + { + get + { + int weight_layer_index = 0; + Dictionary dependencies = new(); + for(int i = 0; i < Layers.Count; i++) + { + var layer = Layers[i]; + var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); + if(weights.Count > 0) + { + dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; + weight_layer_index++; + } + dependencies[$"layer-{i}"] = layer; + } + return dependencies; + } + } + public Functional(Tensors inputs, Tensors outputs, string name = null) : base(new ModelArgs { @@ -325,5 +350,11 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train return output_tensors; } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? 
cache = null) + { + return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) + .ToDictionary(x => x.Key, x => x.Value); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 835f6041b..41f7788ed 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -108,5 +109,15 @@ public override List TrainableVariables return variables; } } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + { + if(save_type == SaveType.SAVEDMODEL) + { + //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. + } + var children = base._trackable_children(save_type, cache); + return children; + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index ade8ae73e..f0ad74507 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -29,6 +29,18 @@ public override IDictionary functions_to_serialize(IDictionary throw new System.NotImplementedException(); } + /// + /// Generates or retrieves serialized attributes from cache. + /// + /// + protected void get_serialized_attributes(IDictionary serialization_cache) + { + // TODO: deal with cache. 
+ Layer a; + + + } + public override string TrackingMetadata { get diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs new file mode 100644 index 000000000..6a163fecb --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + /// + /// Class that tracks and validates all serialization attributes. + /// + public class SerializedAttributes + { + + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index 30e895827..a5d84d674 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -4,7 +4,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static bool ShouldHaveTraces { get; internal set; } + public static bool ShouldHaveTraces { get; internal set; } = true; public static SaveOptionsContext keras_option_scope(bool save_traces) { @@ -23,7 +23,7 @@ public class SaveOptionsContext: IDisposable public bool _old_value; public SaveOptionsContext(bool old_value) { - _old_value = true; + _old_value = old_value; } public void Dispose() From bdca3b5e3d92514a0b816f4a8a81b0864428ebf8 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 24 Jan 2023 22:03:53 +0800 Subject: [PATCH 05/15] Add serialized attributes. 
--- .../Training/AutoTrackable.cs | 2 +- .../SavedModel/serialized_attributes.cs | 267 +++++++++++++++++- 2 files changed, 267 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 6f10fd2e5..5dd9784f5 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -6,7 +6,7 @@ namespace Tensorflow.Train { - public abstract class AutoTrackable : Trackable + public class AutoTrackable : Trackable { public void _delete_tracking(string name) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 6a163fecb..ff3c78757 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -1,14 +1,279 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Metrics; +using Tensorflow.Train; namespace Tensorflow.Keras.Saving.SavedModel { + // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#, + // Using the name "Attributes" may be quite confusing. /// /// Class that tracks and validates all serialization attributes. 
/// - public class SerializedAttributes + public abstract class SerializedAttributes { + protected IDictionary _object_dict; + protected IDictionary _function_dict; + protected AutoTrackable _keras_trackable; + protected HashSet _all_functions; + protected HashSet _all_checkpointable_objects; + protected SerializedAttributes() + { + _object_dict= new Dictionary(); + _function_dict= new Dictionary(); + _keras_trackable= new AutoTrackable(); + _all_functions= new HashSet(); + _all_checkpointable_objects= new HashSet(); + } + + protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(checkpointable_objects); + _all_functions = new HashSet(functions); + } + + public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + /// + /// Returns functions to attach to the root object during serialization. + /// + public IDictionary FunctionsToSerialize + { + get + { + Dictionary functions = new(); + foreach(var pair in Functions) + { + if (_all_functions.Contains(pair.Key)) + { + // TODO: deal with `LayerCall`. + functions[pair.Key] = pair.Value; + } + } + return functions; + } + } + + /// + /// Returns objects to attach to the root object during serialization. + /// + public IDictionary ObjectsToSerialize + { + get + { + var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + objects[Constants.KERAS_ATTR] = _keras_trackable; + return objects; + } + } + + /// + /// Saves function dictionary, and validates dictionary values. 
+ /// + /// + public IDictionary set_and_validate_functions(IDictionary function_dict) + { + foreach(var key in _all_functions) + { + if (function_dict.ContainsKey(key)) + { + // TODO: deal with type `LayerCall`. + var fn = function_dict[key]; + if (fn is not null && (fn is not Function)) + { + throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key})."); + } + _function_dict[key] = fn; + + var tf_fn = fn; // TODO: deal with type `LayerCall`. + + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if(property.Name == key) + { + property.SetValue(_keras_trackable, tf_fn); + break; + } + } + } + else + { + throw new ValueError($"Function {key} missing from serialized function dict."); + } + } + return Functions; + } + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + public IDictionary set_and_validate_objects(IDictionary object_dict) + { + foreach(var key in _all_checkpointable_objects) + { + if (object_dict.ContainsKey(key)) + { + _object_dict[key] = object_dict[key]; + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if (property.Name == key) + { + property.SetValue(_keras_trackable, object_dict[key]); + break; + } + } + } + else + { + throw new ValueError($"Object {key} missing from serialized object dict."); + } + } + return CheckpointableObjects; + } + + /// + /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python). 
+ /// + /// + public static SerializedAttributes Create(Trackable obj) + { + if(obj is Model) + { + return new ModelAttributes(); + } + else if(obj is Metric) + { + return new MetricAttributes(); + } + else if(obj is RNN) + { + return new RNNAttributes(); + } + else if(obj is Layer) + { + return new LayerAttributes(); + } + else + { + throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}"); + } + } + + protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + return (checkpointable_objects ?? (new List()), functions ?? (new List())); + } + } + + // Note that the current implementation still has some potential risks. + // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras". + // However, currently it's just a normal class. + public class CommonEndPoints: SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if(checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if(functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // TODO: remove the `__call__`. + functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + ); + } + } + + public class LayerAttributes: CommonEndPoints + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + ); + } + } + + public class ModelAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + } + } + + public class MetricAttributes : SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables" }), + functions + ); + } + } + + public class RNNAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "states" }), + functions + ); + } } } From b92b08d6290477150c403711b98778e8cae55425 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Wed, 25 Jan 2023 10:14:15 +0800 Subject: [PATCH 06/15] Implement layer serializations. --- .../Checkpoint/TrackableView.cs | 2 +- src/TensorFlowNET.Core/DisposableObject.cs | 68 ++++++++ .../Saving/SavedModel/AugmentedGraphView.cs | 4 +- .../Training/data_structures.cs | 11 +- .../Variables/BaseResourceVariable.cs | 2 +- .../Variables/RefVariable.cs | 3 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 2 + .../Saving/SavedModel/SaveImpl.cs | 53 ++++++- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 39 ++++- .../SavedModel/serialized_attributes.cs | 145 +++++++++--------- .../Saving/SavedModel/utils.cs | 14 ++ 12 files changed, 257 insertions(+), 93 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 6d81d2c9a..69bf76fda 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -24,7 +24,7 @@ public virtual IDictionary children(Trackable obj, SaveType s Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. 
- foreach(var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 3c70739bd..7fac3d0f1 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Tensorflow.Train; namespace Tensorflow { @@ -90,4 +91,71 @@ public void Dispose() Dispose(false); } } + + public abstract class DisposableTrackableObject: Trackable, IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableTrackableObject() + { } + + protected DisposableTrackableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. + // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. 
+ // Therefore, you should call GC.SupressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableTrackableObject() + { + Dispose(false); + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 6723206c0..82da2ee94 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -23,10 +23,10 @@ public void set_signature(object signature_map, object wrapped_functions) list_children(Root); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) { Dictionary children = new(); - foreach (var pair in base.list_children(obj, save_type)) + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) { var name = pair.Name; var child = pair.Refer; diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index 4cb78181b..d4e9c401b 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -142,21 +142,26 @@ protected virtual Trackable _track_value(Trackable value, string name) return value; } - protected static Trackable wrap_or_unwrap(NoDependency value) + public static Trackable wrap_or_unwrap(NoDependency value) { return value.Value; } - protected static Trackable wrap_or_unwrap(Trackable value) + public static Trackable wrap_or_unwrap(Trackable value) { return value; } - protected static Trackable wrap_or_unwrap(IList value) + public static Trackable wrap_or_unwrap(IList value) { return new ListWrapper(value); } + public static 
Trackable wrap_or_unwrap(IEnumerable value) + { + return new ListWrapper(value.ToList()); + } + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) { value = wrap_or_unwrap(value); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 4526730fa..f217a052d 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -7,7 +7,7 @@ namespace Tensorflow { - public class BaseResourceVariable : DisposableObject + public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 38b5b7345..7b08f3ea4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -20,11 +20,12 @@ limitations under the License. 
using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; +using Tensorflow.Train; namespace Tensorflow { [Obsolete] - public partial class RefVariable : IVariableV1, IProtoBuf + public partial class RefVariable: Trackable, IVariableV1, IProtoBuf { protected string _name; public string UniqueId => _name; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e95e55d6d..b9b01dae5 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -288,6 +288,8 @@ public List weights } } + public List Variables => weights; + public virtual LayerArgs get_config() => args; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index ba0bcc663..7168e25b7 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -1,5 +1,8 @@ using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training; namespace Tensorflow.Keras.Saving.SavedModel; @@ -10,10 +13,54 @@ public static bool should_skip_serialization(object layer) return false; } - public static IDictionary wrap_layer_objects(Layer layer, object serialization_cache) + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) { - // TODO: process the loss + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. - return null; + // TODO: change the inherits of `Variable` and revise the implmentation. 
+ var variables = layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var trainable_variables = layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var non_trainable_variables = layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + + Dictionary res = new(); + res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); + res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); + res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. + /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. 
+ + return new Dictionary(); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 36111a18e..a399eaf13 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -17,10 +17,10 @@ public SavedModelSaver(Trackable obj) public abstract string ObjectIdentifier { get; } public abstract string TrackingMetadata { get; } - public abstract IDictionary objects_to_serialize( + public abstract IDictionary objects_to_serialize( IDictionary serialization_cache); - public abstract IDictionary functions_to_serialize( + public abstract IDictionary functions_to_serialize( IDictionary serialization_cache); public IDictionary trackable_children(IDictionary? serialization_cache) @@ -32,8 +32,7 @@ public IDictionary trackable_children(IDictionary x.Key, x => (Trackable)x.Value) - .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index f0ad74507..7a0ddd21b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,26 +19,51 @@ public override string ObjectIdentifier get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary 
functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } /// /// Generates or retrieves serialized attributes from cache. /// /// - protected void get_serialized_attributes(IDictionary serialization_cache) + protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) { // TODO: deal with cache. - Layer a; - + var serialized_attr = SerializedAttributes.Create(_obj); + + // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. + if (KerasSavedModelUtils.should_skip_serialization(_obj)) + { + return serialized_attr; + } + + var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); + + serialized_attr.set_and_validate_objects(object_dict); + serialized_attr.set_and_validate_functions(function_dict); + return serialized_attr; + } + + /// + /// Returns dictionary of serialized attributes. 
+ /// + /// + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + { + var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + + functions["_default_save_signature"] = null; + + return (objects, functions); } public override string TrackingMetadata diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index ff3c78757..804ea1a93 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -17,15 +17,15 @@ namespace Tensorflow.Keras.Saving.SavedModel public abstract class SerializedAttributes { protected IDictionary _object_dict; - protected IDictionary _function_dict; + protected IDictionary _function_dict; protected AutoTrackable _keras_trackable; protected HashSet _all_functions; protected HashSet _all_checkpointable_objects; - protected SerializedAttributes() + private SerializedAttributes() { _object_dict= new Dictionary(); - _function_dict= new Dictionary(); + _function_dict= new Dictionary(); _keras_trackable= new AutoTrackable(); _all_functions= new HashSet(); _all_checkpointable_objects= new HashSet(); @@ -34,25 +34,35 @@ protected SerializedAttributes() protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) { _object_dict = new Dictionary(); - _function_dict = new Dictionary(); + _function_dict = new Dictionary(); _keras_trackable = new AutoTrackable(); _all_checkpointable_objects = new HashSet(checkpointable_objects); _all_functions = new HashSet(functions); } - public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + protected SerializedAttributes((IEnumerable, IEnumerable) objects_and_functions) + { + 
_object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(objects_and_functions.Item1); + _all_functions = new HashSet(objects_and_functions.Item2); + } + + public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); /// /// Returns functions to attach to the root object during serialization. /// - public IDictionary FunctionsToSerialize + public IDictionary FunctionsToSerialize { get { - Dictionary functions = new(); + Dictionary functions = new(); foreach(var pair in Functions) { if (_all_functions.Contains(pair.Key)) @@ -82,7 +92,7 @@ public IDictionary ObjectsToSerialize /// Saves function dictionary, and validates dictionary values. /// /// - public IDictionary set_and_validate_functions(IDictionary function_dict) + public IDictionary set_and_validate_functions(IDictionary function_dict) { foreach(var key in _all_functions) { @@ -186,94 +196,87 @@ protected virtual (IEnumerable, IEnumerable) get_objects_and_fun // However, currently it's just a normal class. public class CommonEndPoints: SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) + base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), + functions.Concat(new string[] { })) { - if(checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if(functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), - // TODO: remove the `__call__`. - functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) - ); + + } + + public CommonEndPoints() : + //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, + // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + base(new string[] { "variables", "trainable_variables"}, + new string[] {}) + { + } } public class LayerAttributes: CommonEndPoints { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public LayerAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), + functions.Concat(new string[] { })) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), - functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) - ); + + } + + public LayerAttributes() : + //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, + // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(new string[] { "non_trainable_variables", "layers" }, + new string[] { }) + { + } } public class ModelAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public ModelAttributes(IEnumerable checkpointable_objects, IEnumerable functions): + base(checkpointable_objects, functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + + } + + public ModelAttributes(): base() + { + } } public class MetricAttributes : SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + public MetricAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects.Concat(new string[] { "variables" }), functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables" }), - functions - ); + + } + + public MetricAttributes() : + base(new string[] { "variables" }, new string[] {}) + { + } } public class RNNAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public RNNAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects, functions.Concat(new string[] {"states"})) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "states" }), - functions - ); + + } + + public RNNAttributes() : + base(new string[] { }, new string[] { "states" }) + { + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index a5d84d674..3054271ae 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving.SavedModel; @@ -12,6 +14,18 @@ public static SaveOptionsContext keras_option_scope(bool save_traces) ShouldHaveTraces = save_traces; return res; } + + public static IEnumerable list_all_layers(Layer layer) + { + if(layer is Model) + { + return (layer as Model).Layers; + } + else + { + return new List(layer._flatten_layers(false, false)); + } + } } /// From 83906b8f798d7faa99784da7d66489ca51dae4fd Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Mon, 30 Jan 2023 13:42:51 +0800 Subject: [PATCH 07/15] Add lacked implementations (mainly MultiDeviceSaver). 
--- .../Checkpoint/CheckpointOptions.cs | 2 +- .../Checkpoint/ObjectGraphView.cs | 9 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 23 +- .../Checkpoint/SaveUtilV1.cs | 27 +- .../Checkpoint/TrackableView.cs | 5 +- .../Checkpoint/checkpoint.cs | 9 +- .../Checkpoint/functional_saver.cs | 515 +++++++++++++++++- .../SavedModel/ISerializedAttributes.cs | 35 ++ .../Training/AutoTrackable.cs | 3 +- .../Saving/SavedModel/AugmentedGraphView.cs | 109 +++- .../Saving/SavedModel/SaveableView.cs | 6 +- .../Training/Saving/SavedModel/save.cs | 16 +- .../SavedModel/signature_serialization.cs | 99 +++- .../Saving/saveable_object_util.py.cs | 156 +++++- src/TensorFlowNET.Core/Training/Trackable.cs | 48 +- .../Training/TrackableUtils.cs | 28 +- .../Training/data_structures.cs | 3 +- .../Variables/BaseResourceVariable.cs | 3 + .../Variables/ResourceVariable.cs | 9 + src/TensorFlowNET.Keras/Engine/Functional.cs | 3 +- .../Engine/Layer.Serialize.cs | 7 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 24 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.cs | 3 +- .../Saving/SavedModel/Save.cs | 9 +- .../Saving/SavedModel/SaveImpl.cs | 4 +- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 28 +- .../SavedModel/serialized_attributes.cs | 2 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 4 +- 30 files changed, 1037 insertions(+), 161 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs index d8297ea3f..f14b5ce78 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -1,5 +1,5 @@ namespace Tensorflow.Checkpoint; public record class CheckpointOptions( - string experimental_io_device = null, + string? 
experimental_io_device = null, bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs index 2ad554485..cb01b539a 100644 --- a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Serilog.Debugging; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow.Checkpoint; @@ -21,9 +22,9 @@ public object Clone() return new ObjectGraphView(Root, _attached_dependencies); } - public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - List res = base.children(obj, save_type) + List res = base.children(obj, save_type, serialization_cache) .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); // Check the reference, not value. if (obj == Root && _attached_dependencies is not null) @@ -34,9 +35,9 @@ public virtual List list_children(Trackable obj, SaveType sa return res; } - public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); } public IEnumerable? 
AttachedDependencies diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index dc2a92fb0..e646f1f04 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -28,7 +28,7 @@ Trackable object_to_save ); public static class SaveUtil { - public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -117,16 +117,16 @@ private static TrackableObjectGraph fill_object_graph_proto(IList /// /// /// - private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) { - Dictionary> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach(var td in tensor_trackables) { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? 
""; var trackable = td.object_to_save; - IDictionary tensor_dict; + IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); @@ -147,12 +147,12 @@ private static IDictionary> get_and_write return serialized_tensors; } - private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { var trackable = trackable_data.object_to_save; // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. - IDictionary ret_tensor_dict; + IDictionary>> ret_tensor_dict; if (call_with_mapped_captures) { throw new NotImplementedException(); @@ -162,8 +162,8 @@ private static IDictionary get_tensors_from_trackable(TrackableD ret_tensor_dict = trackable.serialize_to_tensors(); } - // TODO: revise the types and complete it - Dictionary tensor_dict = new(); + // TODO: deal with the type `SaveSpce` (currently it will never be it). 
+ Dictionary>> tensor_dict = new(); foreach(var pair in ret_tensor_dict) { var local_name = TrackableUtils.escape_local_name(pair.Key); @@ -172,9 +172,10 @@ private static IDictionary get_tensors_from_trackable(TrackableD tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor is SaveSpec) + if(maybe_tensor.GetValueA() is SaveSpec) { - ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + throw new NotImplementedException(); + //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; } if(object_graph_proto is not null) @@ -198,7 +199,7 @@ private static IDictionary get_tensors_from_trackable(TrackableD /// /// /// - private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { Dictionary object_names = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 44fa5c5de..d8e251ece 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -174,25 +174,20 @@ public static (List, object?) generate_saveable_objects( { var name = factory_data.name; var key = factory_data.checkpoint_key; - var saveable_factory = factory_data.factory; - + var maybe_saveable = factory_data.factory; + // TODO: oneflow python has a process with callable `saveable_factory`. 
- var maybe_saveable = saveable_factory; - IEnumerable savesbles; - if (maybe_saveable is MySaveableObject) - { - savesbles = new List() { (MySaveableObject)maybe_saveable }; - } - else if (maybe_saveable is Tensor) + List saveables = new(); + if (maybe_saveable.DataType == typeof(MySaveableObject)) { - savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + saveables.Add(maybe_saveable.GetValueB()); } else { - throw new TypeError("Unexpected type."); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); } - foreach (var saveable in savesbles) + foreach (var saveable in saveables) { if (!saveable.name.Contains(key)) { @@ -204,11 +199,11 @@ public static (List, object?) generate_saveable_objects( // skip the process of PythonState - named_saveable_objects.AddRange(savesbles); + named_saveable_objects.AddRange(saveables); if(!fill_object_proto) continue; - - // skip the process of TrackableSaveable + + // skip the process of `TrackableSaveable` because of lack of APIs. object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); @@ -221,7 +216,7 @@ public static (List, object?) 
generate_saveable_objects( public record class CheckpointFactoryData ( - object factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 69bf76fda..f89dc10d7 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -2,6 +2,7 @@ using Tensorflow.Train; using System.Collections.Generic; using System.IO; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow.Checkpoint; @@ -18,13 +19,13 @@ public TrackableView(WeakReference obj) _root_ref = obj; } - public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { obj._maybe_initialize_trackable(); Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - foreach (var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type, cache)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 791094899..c9bee0db3 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public TrackableSaver(ObjectGraphView graph_view) } - private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? 
object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -125,7 +125,7 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s } Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; @@ -133,6 +133,7 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s Tensor file_prefix_tensor; Tensor object_graph_tensor; + string file_prefix_to_save; if (use_session) { if (_object_graph_feed_tensor is null) @@ -145,16 +146,18 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s object_graph_tensor = _object_graph_feed_tensor; file_prefix_tensor = _file_prefix_feed_tensor; feed_dict[file_prefix_tensor] = file_prefix; + file_prefix_to_save = ""; } else { // In python there is `with ops.device("/cpu:0")`. 
file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); object_graph_tensor = null; + file_prefix_to_save = file_prefix; } var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); if (new_feed_additions is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 759cbd663..c4a03985f 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -6,9 +6,254 @@ using static Tensorflow.ApiDef.Types; using static Tensorflow.CostGraphDef.Types; using static Tensorflow.OptimizerOptions.Types; +using static Tensorflow.Binding; +using System.Text.RegularExpressions; +using System.Linq; +using Tensorflow.Operations; +using Tensorflow.Training; +using Tensorflow.Graphs; namespace Tensorflow.Checkpoint { + /// + /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. + /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. 
+ /// + public interface IFunctionHolder + { + int ArgCount { get; } + object DynamicInvoke(params object[] args); + } + internal record class FunctionHolder(Func Func): IFunctionHolder + { + public int ArgCount => 0; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 1; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 2; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 3; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + public class Maybe + { + private TA? _valueA = default(TA); + private TB? _valueB = default(TB); + private Type _type; + private bool _assigned = false; + public Maybe(TA value) + { + _valueA = value; + _type= typeof(TA); + _assigned = true; + } + public Maybe(TB value) + { + _valueB = value; + _type = typeof(TB); + _assigned = true; + } + + public Type DataType => _type; + + public TA GetValueA() + { + if(!_assigned || DataType != typeof(TA)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueA; + } + public TB GetValueB() + { + if (!_assigned || DataType != typeof(TB)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueB; + } + public object GetValue() + { + if (!_assigned) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + if(DataType == typeof(TA) && _valueA is not null) + { + return _valueA; + } + else if(DataType == typeof(TB) && _valueB is not null) + { + return _valueB; + } + else if(DataType == typeof(TA)) + { + return 
_valueA; + } + else + { + return _valueB; + } + } + + public static implicit operator Maybe(TA a) + { + return new Maybe(a); + } + public static implicit operator Maybe(TB b) + { + return new Maybe(b); + } + } + internal class SingleDeviceSaver + { + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict; + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensors = new(); + List slice_specs = new(); + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + var tensor_value = spec.tensor; + if (tensor_value is not null) + { + tensor_names.Add(spec.name); + tensors.Add(tensor_value); + slice_specs.Add(spec.slice_spec); + } + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_names.Add(checkpoint_key); + tensors.Add(tensor); + slice_specs.Add(slice_spec); + } + } + } + // TODO: specify the device. + return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); + } + + public Operation? save(string file_prefix, CheckpointOptions? 
options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); + + public IDictionary> restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensor_dtypes = new(); + List slice_specs = new(); + + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + tensor_dtypes.Add(spec.dtype); + slice_specs.Add(spec.slice_spec); + tensor_names.Add(spec.name); + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_dtypes.Add(tensor.dtype); + slice_specs.Add(slice_spec); + tensor_names.Add(checkpoint_key); + } + } + } + + string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; + + // tf python has code `with ops.device(restore_device):` here. + tf.device(restore_device); // may be risky. + var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + + Dictionary> restored_tensor_dict = new(); + int idx = 0; + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice_spec in tensor_slices.Keys) + { + var restored_tensor = restored_tensors[idx++]; + if (!restored_tensor_dict.ContainsKey(checkpoint_key)) + { + restored_tensor_dict[checkpoint_key] = new Dictionary(); + } + restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; + } + } + return restored_tensor_dict; + } + + public IDictionary> restore(string file_prefix, CheckpointOptions? 
options = null) => restore(tf.constant(file_prefix)); + } /// /// Saves checkpoints directly from multiple devices. /// Note that this is a low-level utility which stores Tensors in the keys @@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint /// public class MultiDeviceSaver { - public MultiDeviceSaver(IDictionary> serialized_tensors, + private Dictionary _single_device_savers; + private IDictionary _registered_savers; + private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; + /// + /// + /// + /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. + /// + /// + public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { + _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); + _restore_fn_to_keys = new Dictionary>(); + Dictionary>> tensors_by_device= new(); + + foreach(var pair in serialized_tensors) + { + var obj = pair.Key; + var tensor_dict = pair.Value; + IFunctionHolder restore_fn; + if(obj is null) + { + restore_fn = new FunctionHolder(() => null); + } + else + { + restore_fn = null; + // TODO: implement obj._restore_from_tensors + } + + foreach(var item in tensor_dict) + { + var checkpoint_key = item.Key; + IDictionary spec_to_tensor; + if(item.Value.DataType != typeof(IDictionary)) + { + spec_to_tensor = new Dictionary(); + spec_to_tensor[""] = item.Value.GetValueA(); + } + else + { + spec_to_tensor = item.Value.GetValueB(); + } + + foreach(var spec in spec_to_tensor) + { + var slice_spec = spec.Key; + var tensor = spec.Value; + if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) + { + throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + + $"slice spec. This is invalid because one will overwrite the " + + $"other in the checkpoint. 
This indicates a bug in the Checkpoint key-generation."); + } + _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; + _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); + + // skip the process of device name because lack of API. + var host_device = tensor.Device; + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>()); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + internal_dict[checkpoint_key] = new Dictionary(); + } + internal_dict[checkpoint_key][slice_spec] = tensor; + } + } + } + + _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); + _registered_savers = new Dictionary(); + if(registered_savers is not null && registered_savers.Count > 0) + { + // TODO: complete the implementation. + throw new NotImplementedException(); + } } - public Operation? save(string file_prefix, CheckpointOptions? options= null) + public Operation save(string file_prefix, CheckpointOptions? options= null) { - throw new NotImplementedException(); + if(options is null) + { + options = new CheckpointOptions(); + } + + tf.device("CPU"); // may be risky. + // TODO: optimize the implementation with new APIs adding to `string_ops`. + string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; + var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + + Operation save_fn() + { + List saved_prefixes= new(); + foreach(var saver in _registered_savers) + { + // TODO: implementi it later. + throw new NotImplementedException(); + } + + int num_shards = _single_device_savers.Count; + List sharded_saves = new(); + var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); + string? 
last_device = null; + int shard = 0; + foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) + { + var device = pair.Key; + var saver = pair.Value; + last_device = device; + // skip the extra process of device name because of lack of API. + tf.device(device); + var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + saved_prefixes.Add(shard_prefix); + sharded_saves.Add(saver.save(shard_prefix, options)); + } + using (var controller = ops.control_dependencies(sharded_saves.ToArray())) + { + string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; + tf.device(merge_device); + return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + } + } + + if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return save_fn(); + } } - public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + + public IDictionary restore(string file_prefix, CheckpointOptions? 
options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + IDictionary restore_func() + { + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary restore_ops = new(); + + foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) + { + var device = single_saver.Key; + var saver = single_saver.Value; + tf.device(device); + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach(var pair in restored_tensor_dict) + { + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach(var item in slice_and_tensor) + { + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) + { + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(dict); + } + else + { + internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; + } + } + else + { + internal_dict[checkpoint_key] = new Maybe>(tensor); + } + restore_fn_input_count[restore_fn]--; + + if (restore_fn_input_count[restore_fn] == 0) + { + Dictionary>> restored_tensors = new(); + foreach(var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if(ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } + } + } + } + } + + foreach(var item in _registered_savers) + { + throw new NotImplementedException(); + } + return restore_ops; + } + + // TODO: complete the implementation. Currently skip it because of lack of API. 
+ bool has_custom_device_saver = false; + + if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return restore_func(); + } + } + + /// + /// Serializes to a SaverDef referencing the current graph. + /// + public SaverDef to_proto() + { + var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); + var save_tensor = _traced_save(filename_tensor); + var restore_op = _traced_restore(filename_tensor).op; + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + Version = SaverDef.Types.CheckpointFormatVersion.V2 + }; + } + + [AutoGraph] + private Tensor _traced_save(Tensor file_prefix) + { + var save_op = save(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[]{ save_op })) + { + return array_ops.identity(file_prefix); + } + } + + [AutoGraph] + private Tensor _traced_restore(Tensor file_prefix) + { + var restore_op = restore(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[] { restore_op })) + { + return array_ops.identity(file_prefix); + } + } + + private static Tensor registered_saver_filename(string filename, string saver_name) + { + return tf.constant($"{filename}-{saver_name}"); + } + private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - throw new NotImplementedException(); + return filename_tensor; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs new file mode 100644 index 000000000..ae8a1ab13 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs 
@@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public interface ISerializedAttributes + { + IDictionary Functions { get; } + + IDictionary CheckpointableObjects { get; } + + /// + /// Returns functions to attach to the root object during serialization. + /// + IDictionary FunctionsToSerialize { get; } + + /// + /// Returns objects to attach to the root object during serialization. + /// + IDictionary ObjectsToSerialize{get; } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + IDictionary set_and_validate_functions(IDictionary function_dict); + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + IDictionary set_and_validate_objects(IDictionary object_dict); + } +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 5dd9784f5..4d5a664ec 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -24,7 +25,7 @@ public void _delete_tracking(string name) } } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) { if(save_type != SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 82da2ee94..97162651a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -4,57 +4,130 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow; public class AugmentedGraphView: ObjectGraphView { - // private object _children_cache; - // private object _serialization_cache; + private Dictionary> _children_cache; + private Dictionary> _serialization_cache; private List _untraces_functions; + private Dictionary _wrapped_functions; public AugmentedGraphView(Trackable root): base(root) { - _untraces_functions = new(); + _children_cache= new Dictionary>(); + _serialization_cache = new Dictionary>(); + _untraces_functions = new List(); + _wrapped_functions = new Dictionary(); } - public void set_signature(object signature_map, object wrapped_functions) + public void set_signature(SignatureMap signature_map, IDictionary wrapped_functions) { - // TODO: cache list_children(Root); + var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; + if (!_children_cache.ContainsKey(Root)) + { + _children_cache[Root] = new Dictionary(); + } + _children_cache[Root][name] = signature_map; + _wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary>? 
serialization_cache = null) { - Dictionary children = new(); - foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) + if(serialization_cache is not null) + { + throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); + } + + if (!_children_cache.ContainsKey(obj)) + { + Dictionary children = new Dictionary(); + _children_cache[obj] = children; + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) + { + child = maybe_uncache_variable_captures((ConcreteFunction)child); + } + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + } + + List res = new(); + foreach(var pair in _children_cache[obj]) { - var name = pair.Name; - var child = pair.Refer; - children[name] = child; + res.Add(new TrackableReference(pair.Key, pair.Value)); } - if (obj is Function && children.Count == 0) + return res; + } + + private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) + { + if (_wrapped_functions.ContainsKey(concrete_function)) { - _untraces_functions.Add(((Function)obj).Name); + return _wrapped_functions[concrete_function]; } + // skip the process here because of lack of feature. + // In the future, we may add an attribute which could specify if the variable is supposed to be cached. + //foreach(var capture in concrete_function.CapturedInputs) + //{ - return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + //} + return concrete_function; } public override (List, Dictionary>) breadth_first_traversal() { - // TODO: implement it if needed. + Trackable get_merged_trackable(Trackable x) + { + // TODO: complete it with new definitions `Asset` and `TrackableConstant`. 
+ return x; + } + var trackable_objects = base.breadth_first_traversal(); + + foreach(var obj in _children_cache.Keys) + { + // skip the deletion of cache (maybe do it later). + foreach(var pair in _children_cache[obj]) + { + _children_cache[obj][pair.Key] = get_merged_trackable(pair.Value); + } + } + return base.breadth_first_traversal(); } public List<(string, Trackable)> list_dependencies(Trackable obj) { - // TODO: deal with cache. - return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + IDictionary children; + if (!_children_cache.ContainsKey(obj)) + { + children= new Dictionary(); + } + else + { + children= _children_cache[obj]; + } + List<(string, Trackable)> res = new(); + foreach(var pair in obj.deserialization_dependencies(children)) + { + res.Add((pair.Key, pair.Value)); + } + return res; } public Trackable get_child(Trackable obj, string name) { - throw new NotImplementedException(); + return _children_cache[obj][name]; } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6a241f0e7..6700e277d 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -141,16 +141,16 @@ public List dependency_sorted_node_ids() foreach (var node in _nodes) { var node_id = _node_ids[node]; - List deps = new(); + List deps = new List(); + dependency_map.Add(node_id, deps); // TODO: deal with captured tensor. - string node_path; foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) { if (!_node_ids.ContainsKey(dep)) { - node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); throw new ValueError( $"Found an untracked dependency. 
Object {node_path} depends on {dep}, " + $"but this dependency isn't listed as a child. Please track this child by " + diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index cc8399528..f3f273b81 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -24,7 +24,7 @@ public static partial class SavedModelUtils }.Select(x => (int)x); public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, - string export_dir, IDictionary? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) + string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) { if (options is null) { @@ -41,9 +41,9 @@ public static (IList, IDictionary, IDictionary, IDictionary, Dictionary>) _build_meta_graph(Trackable obj, - IDictionary? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) + ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? 
meta_graph_def = null) { if (ops.inside_function()) { @@ -95,9 +95,9 @@ private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, } AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is not null) + if (signatures is null) { - throw new NotImplementedException(); + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); } // TODO: process of aignatures and wrapped_functions @@ -125,7 +125,7 @@ private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, - IDictionary signatures, IEnumerable namespace_whitelist, + ConcreteFunction signatures, IEnumerable namespace_whitelist, bool save_custom_gradients) { var resource_initializers = saveable_view.get_concrete_resource_initializers(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 21272941f..0d34907f7 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -1,15 +1,84 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow; +public static class SignatureSerializationUtils +{ + internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; + internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; + internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + 
Debug.Assert(func is ConcreteFunction); + // TODO: assert the `func.structured_outputs` and arg_keywords. + signature_map._add_signature(name, (ConcreteFunction)func); + } + + return signature_map; + } + + public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) + { + var children = graph_view.list_children(graph_view.Root); + List possible_signatures = new(); + foreach (var item in children) + { + var name = item.Name; + var child = item.Refer; + if(child is not (Function or ConcreteFunction)) + { + continue; + } + if(name == DEFAULT_SIGNATURE_ATTR) + { + Debug.Assert(child is ConcreteFunction); + return (ConcreteFunction)child; + } + ConcreteFunction concrete = get_signature(child); + if(concrete is not null && valid_signature(concrete)) + { + possible_signatures.Add(concrete); + } + } + + if(possible_signatures.Count == 1) + { + var signature = get_signature(possible_signatures[0]); + if(signature is not null && valid_signature(signature)) + { + return signature; + } + } + return null; + } + + private static ConcreteFunction get_signature(Trackable function) + { + // TODO: implement it. + return null; + } + + private static bool valid_signature(ConcreteFunction concreate_function) + { + // TODO: implement it. + return false; + } +} + public class SignatureMap: Trackable { - private Dictionary _signatures; - private Dictionary _concrete_signatures; + private Dictionary _signatures; public SignatureMap() { @@ -18,7 +87,7 @@ public SignatureMap() public void _add_signature(string name, ConcreteFunction concrete_function) { - _concrete_signatures[name] = concrete_function; + _signatures[name] = concrete_function; } public void _add_signature(string name, Function concrete_function) @@ -26,33 +95,13 @@ public void _add_signature(string name, Function concrete_function) _signatures[name] = concrete_function; } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? 
cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) { if (save_type != SaveType.SAVEDMODEL) { return new Dictionary(); } - Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); - foreach (var pair in _concrete_signatures) - { - res[pair.Key] = pair.Value; - } - - return res; - } - - public static SignatureMap create_signature_map(IDictionary signatures) - { - var signature_map = new SignatureMap(); - foreach (var pair in signatures) - { - var name = pair.Key; - var func = pair.Value; - // TODO: assert the arg_keywords - signature_map._add_signature(name, func); - } - - return signature_map; + return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 622eed3a7..7066b3665 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -16,18 +16,38 @@ limitations under the License. using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; using Tensorflow.Train; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow { - public static class saveable_object_util + /// + /// A SaveableObject that defines `Trackable` checkpointing steps. + /// + public class TrackableSaveable : MySaveableObject { - public class TrackableSaveable: MySaveableObject + private string _prefix; + private IEnumerable _local_names; + private Trackable _trackable; + private bool _call_with_mapped_captures; + // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. 
+ public TrackableSaveable(Trackable obj, IEnumerable specs, string name, IEnumerable local_names, + string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) { - + _prefix = prefix; + _trackable = obj; + _local_names = local_names; + _call_with_mapped_captures = call_with_mapped_captures; } + + // TODO: complete this class. + } + public static class saveable_object_util + { /// /// Returns the variables and names that will be used for a Saver. /// @@ -57,7 +77,7 @@ private static void _add_saveable(List saveables, List seen_ops, T } /// - /// Create `SaveableObject`s from an operation. + /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// /// /// @@ -79,6 +99,74 @@ public static IEnumerable saveable_objects_for_op(Tensor op, s } } + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Trackable obj, string name) + { + // The `op` maybe `Variable` or `Trackable`. 
+ if (obj is BaseResourceVariable) + { + var variable = obj as BaseResourceVariable; + if (variable.InGraphMode) + { + yield return new ResourceVariableSaveable(variable.GraphElement, "", name); + } + else + { + Debug.Assert(variable is ResourceVariable); + yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + } + } + else + { + foreach(var pair in saveable_objects_from_trackable(obj)) + { + var attr = pair.Key; + var factory = pair.Value; + string full_name; + if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) + { + full_name = name; + } + else + { + full_name = name + "_" + attr; + } + if(factory.DataType == typeof(ResourceVariable)) + { + var variable = factory.GetValueA(); + foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + { + yield return op; + } + } + else + { + var variable = factory.GetValueB(); + foreach (var op in saveable_objects_for_op(variable, variable.name)) + { + yield return op; + } + } + } + } + } + + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(MySaveableObject obj, string name) + { + yield return obj; + } + public static Dictionary op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) { op_list = op_list.OrderBy(x => x.Name).ToArray(); @@ -127,16 +215,55 @@ public static Dictionary op_list_to_dict(IVariableV1[] op_list, return names_to_saveables; } - public static IDictionary saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { - // TODO: complete the implementation. - return obj.gather_saveables_for_checkpoint(); + // skip the process of type `PythonState` + + if (trackable_has_serialize_to_tensor(obj)) + { + var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; + // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. 
+ var tensor_dict = obj.serialize_to_tensors(); + + List specs = new(); + List local_names = new(); + string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; + foreach(var pair in tensor_dict) + { + var tensor_name = pair.Key; + var maybe_tensor = pair.Value; + local_names.Add(tensor_name); + string spec_name = name + TrackableUtils.escape_local_name(tensor_name); + + IDictionary internal_dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + internal_dict= new Dictionary(); + internal_dict[""] = maybe_tensor.GetValueA(); + } + else + { + internal_dict = maybe_tensor.GetValueB(); + } + + foreach(var item in internal_dict) + { + specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); + } + } + Dictionary> res = new(); + res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return res; + } + else + { + return obj.gather_saveables_for_checkpoint(); + } } public static bool trackable_has_serialize_to_tensor(Trackable obj) { - // TODO: implement it. - return false; + return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); } internal static string convert_to_string(string x) @@ -158,27 +285,28 @@ public SaveableCompatibilityConverter(Trackable obj, IList sav public Trackable Obj => _obj; public IList mySaveables=> _saveables; - public override IDictionary serialize_to_tensors() + public override IDictionary>> serialize_to_tensors() { - return saveable_objects_to_tensor_dict(_saveables); + return saveable_object_to_tensor_dict(_saveables); } /// /// Converts a list of SaveableObjects to a tensor dictionary. /// /// - public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) { - Dictionary tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { foreach(var spec in saveable.specs) { + // skip the check that if `spec` is callable. 
var name = saveable_object_util.convert_to_string(spec.name); var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - throw new NotImplementedException(); + tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; } else { diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 2646fb8d5..a677044a1 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -16,7 +16,10 @@ limitations under the License. using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; using Tensorflow.Training; using static Tensorflow.Binding; @@ -39,8 +42,8 @@ public static class Constants protected IList _unconditional_checkpoint_dependencies; - protected IDictionary _self_saveable_object_factories = - new Dictionary(); + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); private bool _manual_tracking = true; private static Trackable _none = new Function(); @@ -94,9 +97,13 @@ protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args // assign again. It will add this variable to our dependencies, and if there // is a non-trivial restoration queued, it will handle that. This also // handles slot variables. 
- if (!args.Overwrite || new_variable is RefVariable) - return _track_checkpointable(new_variable, name: args.Name, - overwrite: args.Overwrite); + if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) + { + var temp = new_variable as Trackable; + var res = _track_trackable(temp, args.Name, args.Overwrite); + Debug.Assert(res is IVariableV1); + return res as IVariableV1; + } else return new_variable; } @@ -122,13 +129,16 @@ protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string n /// public void _maybe_initialize_trackable() { + if(_unconditional_checkpoint_dependencies is not null) + { + return; + } _self_update_uid = -1; _unconditional_checkpoint_dependencies = new List(); _unconditional_dependency_names = new Dictionary(); } - // TODO: cache - public virtual IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? cache) { _maybe_initialize_trackable(); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); @@ -139,8 +149,8 @@ public virtual Trackable _track_trackable(Trackable trackable, string name, bool _maybe_initialize_trackable(); if (!_manual_tracking) return trackable; var new_reference = new TrackableReference(name, trackable); - var current_object = _lookupup_dependency(name); - + var current_object = _lookup_dependency(name); + if(current_object is null) { _unconditional_checkpoint_dependencies.Add(new_reference); @@ -170,7 +180,7 @@ public virtual void _handle_deferred_dependencies(string name, Trackable trackab // TODO: complete the implementation. } - public virtual Trackable? _lookupup_dependency(string name) + public virtual Trackable? 
_lookup_dependency(string name) { if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; else return null; @@ -199,8 +209,8 @@ public virtual (IDictionary, IDictionary) return (new Dictionary(), new Dictionary()); } - public virtual List export_to_saved_model_graph(IDictionary? object_map = null, - IDictionary? tensor_map = null, SaveOptions? options = null) + public virtual List export_to_saved_model_graph(IDictionary object_map, + IDictionary tensor_map, SaveOptions? options = null) { var (self_object_map, self_tensor_map) = map_resources(options); foreach (var pair in self_object_map) @@ -215,9 +225,17 @@ public virtual List export_to_saved_model_graph(IDictionary gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { - return _self_saveable_object_factories; + if (saveable_object_util.trackable_has_serialize_to_tensor(this)) + { + // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). 
+ throw new NotImplementedException(); + } + else + { + return _self_saveable_object_factories; + } } /// @@ -229,7 +247,7 @@ public virtual IDictionary gather_saveables_for_checkp /// /// /// - public virtual IDictionary serialize_to_tensors() + public virtual IDictionary>> serialize_to_tensors() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 990207028..390d95c75 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using Tensorflow.Exceptions; using Tensorflow.Train; @@ -22,7 +23,7 @@ public CyclicDependencyError(IDictionary> leftover_dependency_map private static string _ESCAPE_CHAR = "."; private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; - private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; public static string object_path_to_string(IEnumerable node_path_arr) { return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); @@ -145,4 +146,27 @@ public static string pretty_print_node_path(IEnumerable path return $"root.{string.Join(".", paths.Select(x => x.Name))}"; } } + + /// + /// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key. + /// + /// + /// + /// + public static string extract_local_name(string key, string? 
prefix = null) + { + if(prefix is null) + { + prefix = ""; + } + var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix; + try + { + return key.Substring(key.IndexOf(search_key) + search_key.Length); + } + catch(ArgumentOutOfRangeException) + { + return key; + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index d4e9c401b..6e3336c90 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -9,6 +9,7 @@ using System.Text; using Tensorflow.Functions; using Tensorflow.Keras; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using Tensorflow.Train; using static Tensorflow.ApiDef.Types; @@ -243,7 +244,7 @@ private void update_snapshot() _last_wrapped_list_snapshot = new List(_storage); } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) { check_external_modification(); if (_non_append_mutation_value) diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index f217a052d..756024db4 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -4,6 +4,8 @@ using Tensorflow.Variables; using Tensorflow.Train; using static Tensorflow.Binding; +using System.Collections.Generic; +using Tensorflow.ModelSaving; namespace Tensorflow { @@ -20,6 +22,7 @@ public class BaseResourceVariable : DisposableTrackableObject public string UniqueId => _unique_id; protected bool _in_graph_mode; + internal bool InGraphMode => _in_graph_mode; protected bool _trainable; public bool Trainable => _trainable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b31960c73..6093f8106 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -17,7 +17,9 @@ limitations under the License. 
using Google.Protobuf; using System; using System.Collections.Generic; +using Tensorflow.Checkpoint; using Tensorflow.NumPy; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow @@ -235,5 +237,12 @@ public NDArray eval(Session session = null) { return _graph_element.eval(session); } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 61a8956a6..7c8812adb 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; using Tensorflow.Train; using static Tensorflow.Binding; @@ -351,7 +352,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train return output_tensors; } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? 
cache = null) { return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index 1675fba1e..ffb6f71bc 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; @@ -9,16 +10,16 @@ public abstract partial class Layer { public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); - public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { IDictionary children; if (save_type == SaveType.SAVEDMODEL) { - // TODO: deal with cache. 
+ Debug.Assert(cache is not null); children = TrackableSavedModelSaver.trackable_children(cache); } else diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index b9b01dae5..a2f92ba8b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -88,9 +88,29 @@ public abstract partial class Layer : AutoTrackable, ILayer ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; - public Tensor[] input => inboundNodes[0].input_tensors; + public Tensor[] input + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].input_tensors; + } + return null; + } + } public Dictionary> NodesByDepth { get; set; } - public Shape OutputShape => inboundNodes[0].Outputs.shape; + public Shape OutputShape + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].Outputs.shape; + } + return null; + } + } protected List _self_tracked_trackables; public Layer(LayerArgs args) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59f74cd20..59b205e44 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -21,7 +21,7 @@ public void save(string filepath, bool include_optimizer = true, string save_format = "tf", SaveOptions? options = null, - IDictionary? signatures = null, + ConcreteFunction? 
signatures = null, bool save_traces = true) { if (save_format != "pb") diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 41f7788ed..dfe5b05f3 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -110,7 +111,7 @@ public override List TrainableVariables } } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { if(save_type == SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 76453ca0d..6a6e418cf 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures, + public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, SaveOptions? 
options, bool save_traces = true) { if (!overwrite && File.Exists(filepath)) @@ -54,12 +54,7 @@ public static void Save(Model model, string filepath, bool overwrite, bool inclu } var metadata = generate_keras_metadata(saved_nodes, node_paths); - using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, - FileAccess.Write)) - { - var writer = new StreamWriter(f); - writer.Write(metadata.ToString()); - } + File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); if (!include_optimizer) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index 7168e25b7..fc7eab3a3 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -19,7 +19,7 @@ public static bool should_skip_serialization(object layer) /// /// /// - public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) { // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. @@ -55,7 +55,7 @@ public static IDictionary wrap_layer_objects(Layer layer, IDi /// /// /// - public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) { // TODO: deal with type `RevivedLayer` and `Sequential`. 
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index a399eaf13..0235f87bd 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -18,12 +18,12 @@ public SavedModelSaver(Trackable obj) public abstract string TrackingMetadata { get; } public abstract IDictionary objects_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); public abstract IDictionary functions_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); - public IDictionary trackable_children(IDictionary? serialization_cache) + public IDictionary trackable_children(IDictionary> serialization_cache) { if (!KerasSavedModelUtils.ShouldHaveTraces) { @@ -31,7 +31,6 @@ public IDictionary trackable_children(IDictionary x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 7a0ddd21b..b092b5950 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,12 +19,12 @@ public override string ObjectIdentifier get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) { return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) { return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } @@ -33,11 +33,21 @@ public override IDictionary 
functions_to_serialize(IDictionar /// Generates or retrieves serialized attributes from cache. /// /// - protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) + protected ISerializedAttributes get_serialized_attributes(IDictionary> serialization_cache) { // TODO: deal with cache. + IDictionary keras_cache; + if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) + { + keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; + } + else + { + serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary(); + } + if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; - var serialized_attr = SerializedAttributes.Create(_obj); + var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. if (KerasSavedModelUtils.should_skip_serialization(_obj)) @@ -56,7 +66,7 @@ protected SerializedAttributes get_serialized_attributes(IDictionary /// - private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); @@ -75,7 +85,7 @@ public override string TrackingMetadata metadata["trainable"] = _obj.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _obj.BatchInputShape is null ? 
null : JToken.FromObject(_obj.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; @@ -92,8 +102,10 @@ public override string TrackingMetadata } } - public static LayerConfig get_serialized(Layer obj) + public static IDictionary get_serialized(Layer obj) { - return generic_utils.serialize_keras_object(obj); + // TODO: complete the implmentation (need to revise `get_config`). + return new Dictionary(); + //return generic_utils.serialize_keras_object(obj); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 804ea1a93..ac194c00f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel /// /// Class that tracks and validates all serialization attributes. 
/// - public abstract class SerializedAttributes + public abstract class SerializedAttributes: ISerializedAttributes { protected IDictionary _object_dict; protected IDictionary _function_dict; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 9d1b30886..0f34ff10d 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -50,11 +50,11 @@ public void Test() { TrainDir = "mnist", OneHot = false, - ValidationSize = 50000, + ValidationSize = 0, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - model.save("", save_format:"pb"); + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } } \ No newline at end of file From f2e41a17916b25ff6fd3baf20ed6fc0d651fb4c2 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 2 Feb 2023 17:34:50 +0800 Subject: [PATCH 08/15] Support autograph.to_graph under graph mode. --- src/TensorFlowNET.Core/Graphs/AutoGraph.cs | 46 +++++++++++++++------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 2af1a3720..ceeca8abf 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs { public class AutoGraph { - public Func to_graph(Func func) + public Func to_graph(Func func, TF_DataType dtype = TF_DataType.TF_INT32) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input = tf.placeholder(tf.int32); + var input = tf.placeholder(dtype); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -26,25 +27,33 @@ public Func to_graph(Func func) 
return (Tensor input) => { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - func_name, - new[] { input }, - null, - 1); - return result[0]; + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { input }, + null, + 1); + return result[0]; + } + using (var s = tf.Session(input.graph)) + { + var output = func(input); + return output; + } }; } - public Func to_graph(Func func) + public Func to_graph(Func func, params TF_DataType[] dtypes) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input1 = tf.placeholder(tf.int32); - var input2 = tf.placeholder(tf.int32); + var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); + var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); var output = func(input1, input2); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -56,13 +65,22 @@ public Func to_graph(Func func) return (Tensor a, Tensor b) => { - var result = tf.Runner.TFE_Execute(tf.Context, + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { a, b }, null, 1); - return result[0]; + return result[0]; + } + using (var s = tf.Session(a.graph)) + { + Debug.Assert(a.graph == b.graph); + var output = func(a, b); + return output; + } }; } } From a479e53f3aad18a6272eeddb8f3243b10f3beffb Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 2 Feb 2023 19:17:15 +0800 Subject: [PATCH 09/15] Add more implementations to the pb model save. 
--- .../Checkpoint/CheckPointUtils.cs | 10 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 7 +- .../Checkpoint/SaveUtilV1.cs | 18 +-- .../Checkpoint/checkpoint.cs | 23 ++-- .../Checkpoint/functional_saver.cs | 60 ++++++---- src/TensorFlowNET.Core/Eager/execute.cs | 31 +++++ .../Framework/meta_graph.cs | 23 ++++ .../ModelSaving/SaveOptions.cs | 38 +++++- src/TensorFlowNET.Core/Operations/gen_ops.cs | 59 +++++++++- src/TensorFlowNET.Core/Operations/io_ops.cs | 32 +++++ .../Operations/resource_variable_ops.cs | 54 +++++++++ .../Saving/ResourceVariableSaveable.cs | 28 +++++ .../Training/Saving/SaveableObject.cs | 24 +++- .../Saving/SavedModel/SaveableView.cs | 14 +-- .../Training/Saving/SavedModel/save.cs | 83 +++++++------ .../Saving/SavedModel/save_context.cs | 53 +++++++++ .../Training/Saving/SavedModel/utils.cs | 5 + .../Saving/saveable_object_util.py.cs | 109 +++++++++++++----- src/TensorFlowNET.Core/Training/Trackable.cs | 13 ++- .../Variables/BaseResourceVariable.cs | 64 ++++++++++ .../Variables/ResourceVariable.cs | 8 +- .../Variables/UninitializedVariable.cs | 70 +++++++++++ .../Engine/Functional.GetConfig.cs | 2 +- .../Engine/Layer.Serialize.cs | 2 +- .../Layers/Core/InputLayer.cs | 3 + .../Saving/SavedModel/Save.cs | 4 +- .../Saving/SavedModel/SaveImpl.cs | 24 ++-- .../Saving/SavedModel/base_serialization.cs | 2 +- .../Saving/SavedModel/layer_serialization.cs | 63 ++++++++-- .../Utils/generic_utils.cs | 9 +- 30 files changed, 775 insertions(+), 160 deletions(-) create mode 100644 src/TensorFlowNET.Core/Eager/execute.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs create mode 100644 src/TensorFlowNET.Core/Variables/UninitializedVariable.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 70d771559..cd37703b6 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -1,5 
+1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using Tensorflow.Train; @@ -85,17 +86,18 @@ public static Trackable get_mapped_trackable(Trackable trackable, IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -39,7 +39,7 @@ public static (IDictionary feed_additions; + Dictionary feed_additions; if(cache is null) { feed_additions = null; @@ -125,7 +125,7 @@ private static IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { @@ -134,6 +134,7 @@ private static IDictionary>, object return (checkpoint_factory_map, null); } - public static (List, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, + public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? 
saveables_cache = null) { if (to_graph is not null) { - to_graph.as_default(); + var g = to_graph.as_default(); var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var serialized = graph_proto.ToByteString().ToString(); - var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + tf.device("/cpu:0"); + var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + g.Exit(); return (named_saveable_objects, registered_savers); } else @@ -65,7 +65,7 @@ public static (List, object?) frozen_saveables_and_savers(Obje { var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` + tf.device("/cpu:0"); var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); @@ -73,7 +73,7 @@ public static (List, object?) frozen_saveables_and_savers(Obje } } - public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + public static (List, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); @@ -129,7 +129,7 @@ private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView grap return object_graph_proto; } - private static (List, object?, object?) 
add_attributes_to_object_graph(IList trackable_objects, + private static (List, object?, IDictionary>?) add_attributes_to_object_graph(IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -216,7 +216,7 @@ public static (List, object?) generate_saveable_objects( public record class CheckpointFactoryData ( - Maybe factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index c9bee0db3..0c2862dac 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public TrackableSaver(ObjectGraphView graph_view) } - private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? 
object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -42,26 +42,27 @@ public TrackableSaver(ObjectGraphView graph_view) if(object_graph_tensor is null) { - // tensorflow python: `with ops.device("/cpu:0"):` - object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + tf.device("/cpu:0"); + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); } else { - feed_additions[object_graph_tensor] = graph_proto.ToString(); + feed_additions[object_graph_tensor] = graph_proto.ToByteArray(); } Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - if (serialized_tensors.ContainsKey(Trackable.None)) + if (!serialized_tensors.ContainsKey(Trackable.None)) { - serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + serialized_tensors[Trackable.None] = new Dictionary>>(); } + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; return (serialized_tensors, feed_additions, registered_savers, graph_proto); } - private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -86,11 +87,11 @@ public TrackableSaver(ObjectGraphView graph_view) return run_save(); } - private (Tensor, 
IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -124,7 +125,7 @@ public Tensor save(string file_prefix, int? checkpoint_number = null, Session? s options = new CheckpointOptions(); } - Dictionary feed_dict = new(); + Dictionary feed_dict = new(); bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index c4a03985f..90bbccf07 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -12,6 +12,8 @@ using Tensorflow.Operations; using Tensorflow.Training; using Tensorflow.Graphs; +using System.Xml.Linq; +using System.Diagnostics; namespace Tensorflow.Checkpoint { @@ -31,6 +33,10 @@ public object DynamicInvoke(params object[] args) { return Func.DynamicInvoke(args); } + public TR Invoke() + { + return Func.Invoke(); + } } internal record class FunctionHolder(Func Func) : IFunctionHolder { @@ -164,7 +170,6 @@ public SingleDeviceSaver(IDictionary> tens { var slice_spec = slice.Key; var maybe_tensor = slice.Value; - // TODO: deal with other types. Currently only `SaveSpec` is allowed. 
if(maybe_tensor.DataType == typeof(SaveSpec)) { var spec = maybe_tensor.GetValueB(); @@ -284,14 +289,16 @@ public MultiDeviceSaver(IDictionary(() => null); } else { - restore_fn = null; - // TODO: implement obj._restore_from_tensors + restore_fn = new FunctionHolder>>, IDictionary>(x => + { + return obj._restore_from_tensors(x); + }); } foreach(var item in tensor_dict) @@ -343,7 +350,7 @@ public MultiDeviceSaver(IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); Operation save_fn() @@ -385,7 +392,7 @@ Operation save_fn() { string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; tf.device(merge_device); - return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); } } @@ -400,9 +407,9 @@ Operation save_fn() } } - public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); - public IDictionary restore(string file_prefix, CheckpointOptions? options = null) + public IDictionary restore(Tensor file_prefix, CheckpointOptions? 
options = null) { if(options is null) { @@ -496,8 +503,10 @@ IDictionary restore_func() public SaverDef to_proto() { var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); - var save_tensor = _traced_save(filename_tensor); - var restore_op = _traced_restore(filename_tensor).op; + var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); + var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); + var save_tensor = traced_save_func(filename_tensor); + var restore_op = traced_restore_func(filename_tensor).op; return new SaverDef() { FilenameTensorName = filename_tensor.name, @@ -507,10 +516,9 @@ public SaverDef to_proto() }; } - [AutoGraph] private Tensor _traced_save(Tensor file_prefix) { - var save_op = save(file_prefix.StringData()[0]); + var save_op = save(file_prefix); tf.device("cpu:0"); using (ops.control_dependencies(new object[]{ save_op })) { @@ -518,24 +526,34 @@ private Tensor _traced_save(Tensor file_prefix) } } - [AutoGraph] private Tensor _traced_restore(Tensor file_prefix) { - var restore_op = restore(file_prefix.StringData()[0]); + var restore_op = restore(file_prefix); tf.device("cpu:0"); - using (ops.control_dependencies(new object[] { restore_op })) + using (ops.control_dependencies(restore_op.Values.ToArray())) { return array_ops.identity(file_prefix); } } - private static Tensor registered_saver_filename(string filename, string saver_name) + public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? 
registered_savers = null, bool call_with_mapped_captures = false) + { + Dictionary>>> serialized_tensors = new(); + foreach (var saveable in saveables) + { + var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); + serialized_tensors[trackable] = trackable.serialize_to_tensors(); + } + return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); + } + + private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) { - return tf.constant($"{filename}-{saver_name}"); + return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); } private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - return filename_tensor; + return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); } } } diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs new file mode 100644 index 000000000..cb3ea4d3c --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + internal class execute + { + public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) + { + var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); + var types = v.Select(t => t.dtype.as_datatype_enum()); + return (types.ToArray(), v.ToArray()); + } + public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + string device_name = ctx.DeviceName; + + ctx.ensure_initialized(); + var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); 
+ + return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index cce13b55d..c3616fafd 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -406,5 +406,28 @@ public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; } + + /// + /// Extract the Op name from a Tensor name. + /// + /// + /// + public static string op_name(string tensor_name) + { + if (string.IsNullOrEmpty(tensor_name)) + { + throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); + } + + if (tensor_name.StartsWith("^")) + { + tensor_name = tensor_name.Substring(1); + } + if (tensor_name.Contains(":")) + { + return tensor_name.Split(':')[0]; + } + return tensor_name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index fce42850f..45ebd884f 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -14,11 +14,47 @@ public class SaveOptions public IDictionary? function_aliases { get; set; } = null; public string? experimental_io_device { get; set; } = null; // TODO: experimental - public Object? 
experimental_variable_polict { get; set; } = null; + public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; } } + + public class VariablePolicy + { + public string Policy { get; } + private VariablePolicy(string policy) + { + Policy = policy; + } + public static VariablePolicy None = new(null); + public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); + public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); + + public bool save_variable_devices() + { + return this != VariablePolicy.None; + } + + /// + /// Tries to convert `obj` to a VariablePolicy instance. + /// + /// + /// + public static VariablePolicy from_obj(object obj) + { + if (obj is null) return VariablePolicy.None; + if (obj is VariablePolicy) return (VariablePolicy)obj; + var key = obj.ToString().ToLower(); + return key switch + { + null => VariablePolicy.None, + "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") + }; + } + } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 11cb6de8e..956be96b5 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -1,6 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Xml.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow.Operations @@ -17182,17 +17185,47 @@ public static Tensor merge_summary(Tensor[] inputs, string name = "MergeSummary" /// path in the input checkpoint_prefixes. 
This is useful when those paths are non /// user-facing temporary locations. /// - public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints") - { + public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); + result = null; + return null; + //try + //{ + // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); + // result = null; + // return null; + //} + //catch (System.Exception) + //{ + // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, + // allow_missing_files: allow_missing_files, name: name, ctx: ctx); + //} + } var dict = new Dictionary(); dict["checkpoint_prefixes"] = checkpoint_prefixes; dict["destination_prefix"] = destination_prefix; - if (delete_old_dirs.HasValue) - dict["delete_old_dirs"] = delete_old_dirs.Value; + dict["delete_old_dirs"] = delete_old_dirs; var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); return op; } + //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) + //{ + // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); + // destination_prefix = 
ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); + // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; + // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; + // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); + // result = null; + // return null; + //} + /// /// Transforms a spectrogram into a form that's useful for speech recognition. /// @@ -24259,6 +24292,12 @@ public static (Tensor output_false, Tensor output_true) ref_switch(Tensor data, /// public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); + return result[0]; + } var dict = new Dictionary(); dict["input"] = input; dict["pattern"] = pattern; @@ -29744,6 +29783,12 @@ public static Tensor[] shape_n(Tensor[] input, TF_DataType? 
out_type = null, str /// public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); + return result[0]; + } var dict = new Dictionary(); dict["basename"] = basename; dict["shard"] = shard; @@ -34668,6 +34713,12 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, /// public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); + return result[0]; + } var dict = new Dictionary(); dict["inputs"] = inputs; if (separator != null) diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 4f276e36c..35c5877f3 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,7 +14,9 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -23,11 +25,41 @@ public class io_ops { public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute( + new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); + result = null; + return null; + } + catch (System.Exception) + { + return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); + } + } var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); return _op; } + public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) + { + DataType[] attr_dtypes; + (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); + var attrs = new object[] { "dtypes", attr_dtypes }; + + var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); + result = null; + return null; + } + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); diff --git 
a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index d5a32c10e..1b1fa0037 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,7 +17,9 @@ limitations under the License. using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.ModelSaving; using Tensorflow.Train; +using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -177,5 +179,57 @@ private static HandleData get_eager_safe_handle_data(Tensor handle) return HandleData.Parser.ParseFrom(handle.BufferToArray()); } } + + /// + /// Copies an existing variable to a new graph, with no initializer. + /// + /// + public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) + { + var new_variable = new UninitializedVariable( + trainable: variable.Trainable, + shape: variable.shape, + dtype: variable.dtype, + name: variable.SharedName, + aggregation: variable.Aggregation, + extra_handle_data: null); + new_variable._maybe_initialize_trackable(); + return new_variable; + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// + /// + /// + /// + /// + public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) + { + // lack of API: `proto.Variable.SetInParent()`. 
+ if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) + { + throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + + $"unexpected suffix in the name (expected ':0') which won't be restored."); + } + if(proto.Variable is null) + { + proto.Variable = new SavedVariable(); + } + proto.Variable.Name = meta_graph.op_name(resource_variable.Name); + proto.Variable.Trainable = resource_variable.Trainable; + proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); + // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. + proto.Variable.Aggregation = resource_variable.Aggregation; + proto.Variable.Shape = resource_variable.shape.as_proto(); + + if (options.experimental_variable_policy.save_variable_devices()) + { + if (!string.IsNullOrEmpty(resource_variable.Device)) + { + proto.Variable.Device = resource_variable.Device; + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 167c635a8..2d23a325f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public class ResourceVariableSaveable : MySaveableObject @@ -35,6 +37,32 @@ public ResourceVariableSaveable(Tensor var, string slice_spec, string name) this.name = name; } + public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + + Tensor _read_variable_closure(BaseResourceVariable v) + { + tf.device(v.Device); + if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + tf.device("/device:CPU:0"); + return array_ops.identity(x); + } + + this.handle_op = var.Handle; + var tensor = _read_variable_closure(var); + + var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + _op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 6239030ba..43d36dba3 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -14,11 +14,31 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using Tensorflow.Checkpoint; + namespace Tensorflow { public class MySaveableObject { - public Tensor op; + protected Maybe _op; + public Tensor op + { + get + { + if(_op.DataType == typeof(Tensor)) + { + return _op.GetValueA(); + } + else + { + throw new TypeError("The _op is not a tensor."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; @@ -35,7 +55,7 @@ public MySaveableObject(Tensor var, string slice_spec, string name) public MySaveableObject(Tensor op, SaveSpec[] specs, string name) { - this.op = op; + this._op = op; this.specs = specs; this.name = name; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6700e277d..6132e0254 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -10,6 +10,7 @@ using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -75,7 +76,7 @@ public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options private void initialize_save_and_restore_functions() { // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. - SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. 
_obj_to_registered_saver = new(); _saveable_objects_map = new(); @@ -191,7 +192,7 @@ public List dependency_sorted_node_ids() /// /// /// - public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index) { SavedObjectGraph proto = new(); fill_object_graph_proto(proto); @@ -203,21 +204,20 @@ public SavedObjectGraph serialize_object_graph(IDictionary asset { var obj = _nodes[i]; var obj_proto = proto.Nodes[i]; - write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), - options); + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); } return proto; } private static void write_object_proto(Trackable obj, SavedObject proto, - IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + IDictionary asset_file_def_index, Func> list_children_fn) { // skip the process of type Asset if (resource_variable_ops.is_resource_variable(obj)) { - // TODO: complete it. - throw new NotImplementedException(); + var options = SaveContext.get_save_options(); + (obj as BaseResourceVariable).write_object_proto(proto, options); } else if (obj is Function) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index f3f273b81..d82d49d8f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -10,6 +10,7 @@ using Tensorflow.Train; using Tensorflow.Exceptions; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -43,7 +44,7 @@ public static (IList, IDictionary, IDictionary, Dictionary>) _build_meta_graph(Trackable obj, ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? 
meta_graph_def = null) { - if (ops.inside_function()) + using (SaveContext.save_context(options)) { - throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + - "Move the call to the outer eagerly-executed context."); - } + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } - if (meta_graph_def is null) - { - meta_graph_def = new MetaGraphDef(); - } + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } - AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is null) - { - signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); - } - - // TODO: process of aignatures and wrapped_functions + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is null) + { + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); + } - SaveableView saveable_view = new SaveableView(augmented_graph_view, options); - TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); - var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, - options.namespace_white_list, options.experimental_custom_gradients); - if (options.function_aliases is not null) - { - var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; - foreach (var pair in options.function_aliases) + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) { - var 
alias = pair.Key; - var func = pair.Value; - // TODO: complete it. - throw new NotImplementedException(); + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } } - } - var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); - meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); - return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, @@ -134,7 +139,7 @@ private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_d Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - exported_graph.as_default(); + var g = exported_graph.as_default(); (object_map, tensor_map, asset_info) = saveable_view.map_resources(); // TODO: deal with signatures. if (save_custom_gradients) @@ -161,15 +166,23 @@ private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_d // Lack `CopyFrom` API // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + g.Exit(); + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); } + // TODO: add the implementation of `call_with_mapped_functions`. var (named_saveable_objects, registered_savers) = SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); - - // TODO: complete the save of checkpoints with `MultiDeviceSaver`. 
+ var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); + + var eg = exported_graph.as_default(); + var saver_def = saver.to_proto(); + meta_graph_def.SaverDef = saver_def; + eg.Exit(); + saveable_view.dependency_sorted_node_ids(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs new file mode 100644 index 000000000..4cfe0b69b --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.ModelSaving; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A context for building a graph of SavedModel. + /// + public static class SaveContext + { + // TODO: make it thead safe. + private static bool _in_save_context = false; + private static SaveOptions _save_options = null; + + public static bool in_save_context() => _in_save_context; + public static SaveOptions get_save_options() + { + if (!in_save_context()) + { + throw new ValueError("Not in a SaveContext."); + } + return _save_options; + } + public static SaveContextHandler save_context(SaveOptions options) + { + return new SaveContextHandler(options); + } + + public class SaveContextHandler: IDisposable + { + private bool _old_in_save_context; + private SaveOptions _old_save_options; + public SaveContextHandler(SaveOptions options) + { + if (SaveContext.in_save_context()) + { + throw new ValueError("Already in a SaveContext."); + } + _old_in_save_context = SaveContext._in_save_context; + SaveContext._in_save_context = true; + _old_save_options = SaveContext._save_options; + SaveContext._save_options = options; + } + public void Dispose() + { + SaveContext._in_save_context = _old_in_save_context; + SaveContext._save_options = _old_save_options; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs index 723419f6f..2deff0275 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -28,6 +28,11 @@ public static string get_variables_dir(string export_dir) return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); } + public static string get_variables_path(string export_dir) + { + return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); + } + /// /// Return assets sub-directory, or create one if it doesn't exist. /// diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 7066b3665..582e2431e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -19,6 +19,7 @@ limitations under the License. 
using System.Diagnostics; using System.Linq; using Tensorflow.Checkpoint; +using Tensorflow.Operations.Activation; using Tensorflow.Train; using Tensorflow.Training; using static Tensorflow.Binding; @@ -117,8 +118,7 @@ public static IEnumerable saveable_objects_for_op(Trackable ob } else { - Debug.Assert(variable is ResourceVariable); - yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + yield return new ResourceVariableSaveable(variable, "", name); } } else @@ -215,7 +215,7 @@ public static Dictionary op_list_to_dict(IVariableV1[] op_list, return names_to_saveables; } - public static IDictionary> saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` @@ -251,7 +251,7 @@ public static IDictionary> sav specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); } } - Dictionary> res = new(); + Dictionary> res = new(); res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); return res; } @@ -270,25 +270,6 @@ internal static string convert_to_string(string x) { return tf.compat.as_str(x); } - } - - public class SaveableCompatibilityConverter: Trackable - { - private Trackable _obj; - private IList _saveables; - public SaveableCompatibilityConverter(Trackable obj, IList saveables) - { - _obj= obj; - _saveables= saveables; - } - - public Trackable Obj => _obj; - public IList mySaveables=> _saveables; - - public override IDictionary>> serialize_to_tensors() - { - return saveable_object_to_tensor_dict(_saveables); - } /// /// Converts a list of SaveableObjects to a tensor dictionary. @@ -299,11 +280,11 @@ public static Dictionary>> sav Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { - foreach(var spec in saveable.specs) + foreach (var spec in saveable.specs) { // skip the check that if `spec` is callable. 
- var name = saveable_object_util.convert_to_string(spec.name); - var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + var name = convert_to_string(spec.name); + var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; @@ -316,5 +297,81 @@ public static Dictionary>> sav } return tensor_dict; } + + /// + /// Generates `Trackable._restore_from_tensors` from SaveableObjects. + /// + /// + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + { + return (restored_tensors) => + { + Dictionary restored_ops = new(); + + foreach(var saveable in saveables) + { + List saveable_restored_tensors = new(); + foreach(var spec in saveable.specs) + { + var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + + var maybe_tensor = restored_tensors[name]; + IDictionary dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + dict = new Dictionary(); + dict[""] = maybe_tensor.GetValueA(); + } + else + { + dict = maybe_tensor.GetValueB(); + } + saveable_restored_tensors.Add(dict[slice_spec]); + } + restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); + } + return restored_ops; + }; + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private object _obj; + private IList _saveables; + public SaveableCompatibilityConverter(object obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public object Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary>> serialize_to_tensors() + { + return saveable_object_util.saveable_object_to_tensor_dict(_saveables); + } + + /// + /// Returns the restore ops defined in the Saveables. 
+ /// + /// + /// + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + List expected_keys = new(); + foreach(var saveable in _saveables) + { + expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); + } + if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) + { + throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + + $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); + } + return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index a677044a1..434d51b63 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -42,11 +42,11 @@ public static class Constants protected IList _unconditional_checkpoint_dependencies; - protected IDictionary> _self_saveable_object_factories = - new Dictionary>(); + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); private bool _manual_tracking = true; - private static Trackable _none = new Function(); + private static Trackable _none = new AutoTrackable(); /// /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. /// The `None` can be any object that inherits `Trackable`. 
@@ -225,7 +225,7 @@ public virtual List export_to_saved_model_graph(IDictionary> gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { @@ -251,6 +251,11 @@ public virtual IDictionary>> s { throw new NotImplementedException(); } + + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 756024db4..4005d5640 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -6,6 +6,8 @@ using static Tensorflow.Binding; using System.Collections.Generic; using Tensorflow.ModelSaving; +using System.Diagnostics; +using Tensorflow.Checkpoint; namespace Tensorflow { @@ -13,6 +15,7 @@ public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; + public virtual string SharedName => _name; protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; @@ -50,6 +53,7 @@ public class BaseResourceVariable : DisposableTrackableObject public Graph Graph => handle.graph; public string Device => handle.Device; EagerResourceDeleter eager_resource_deleter; + public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; public BaseResourceVariable() { @@ -77,6 +81,11 @@ public void __init__(bool trainable = true, _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); } + else if(handle is null) + { + // TODO: fix this dangerous change. + _handle = IntPtr.Zero; + } else { _handle = handle.Handle == null ? 
IntPtr.Zero : handle.Handle.DangerousGetHandle(); @@ -247,5 +256,60 @@ public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = else return value(); } + + public override (IDictionary, IDictionary) map_resources(SaveOptions save_options) + { + BaseResourceVariable new_variable; + if (save_options.experimental_variable_policy.save_variable_devices()) + { + tf.device(this.Device); + Debug.Assert(this is ResourceVariable); + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + else + { + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + Dictionary obj_map = new(); + Dictionary resource_map = new(); + obj_map[this] = new_variable; + resource_map[this.handle] = new_variable.handle; + return (obj_map, resource_map); + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// Subclasses of ResourceVariables could choose to override this method to + /// customize extra information to provide when saving a SavedModel. + /// + /// + /// + public virtual void write_object_proto(SavedObject proto, SaveOptions options) + { + resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); + } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } + + public Tensor is_initialized(string name = null) + { + return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); + } + + public Tensor read_value_no_copy() + { + Tensor value = null; + tf_with(ops.name_scope("Read"), _ => + { + // TODO: `no_copy = true`.
+ value = _read_variable_op(); + }); + return array_ops.identity(value); + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 6093f8106..1645d7130 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -41,6 +41,7 @@ public ResourceVariable(object initial_value = null, VariableAggregation aggregation = VariableAggregation.None, Shape shape = null) { + Aggregation = aggregation; if (variable_def != null) { if (initial_value != null) @@ -237,12 +238,5 @@ public NDArray eval(Session session = null) { return _graph_element.eval(session); } - - public override IDictionary> gather_saveables_for_checkpoint() - { - var res = new Dictionary>(); - res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; - return res; - } } } diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs new file mode 100644 index 000000000..6c0349950 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Variables +{ + /// + /// A variable with no initializer. + /// + public sealed class UninitializedVariable: BaseResourceVariable + { + // TODO: complete the arg list. 
+ public UninitializedVariable( + bool trainable = true, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null, + Tensor extra_handle_data = null) + { + string unique_id = ""; + string handle_name = ""; + tf_with(ops.init_scope(), (x) => + { + _in_graph_mode = !tf.Context.executing_eagerly(); + tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => + { + handle_name = ops.name_from_scope_name(name); + string? shared_name; + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } + else + { + unique_id = $"{handle_name}-{ops.uid()}"; + shared_name = null; + } + var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); + // skip the assignment of `handle._parent_trackable` because of lack of API. + // skip the assignment of `handle._name` and `handle._unique_id` because of accessibility.
+ + if (_in_graph_mode) + { + tf_with(ops.name_scope("Read"), _ => + { + tf.device(handle.Device); + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + // _maybe_set_handle_data(dtype, handle, value) + _graph_element = value; + }); + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); + } + else + { + _graph_element = null; + } + }); + }); + _shape = shape; + _dtype = dtype; + base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 23c40fbff..a221444b7 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -55,7 +55,7 @@ ModelConfig get_network_config() } } - var layer_config = generic_utils.serialize_keras_object(layer); + var layer_config = generic_utils.serialize_layer_to_config(layer); layer_config.Name = layer.Name; layer_config.InboundNodes = filtered_inbound_nodes; layer_configs.Add(layer_config); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index ffb6f71bc..fc405d872 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine; public abstract partial class Layer { - public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 6b064716f..03b4b742a 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -18,6 +18,7 @@ 
limitations under the License. using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -105,5 +106,7 @@ public static InputLayer from_config(LayerArgs args) { return new InputLayer(args as InputLayerArgs); } + + public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 6a6e418cf..4ff8f02f0 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -55,6 +55,7 @@ public static void Save(Model model, string filepath, bool overwrite, bool inclu var metadata = generate_keras_metadata(saved_nodes, node_paths); File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); + //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString()); if (!include_optimizer) { @@ -100,7 +101,8 @@ public static SavedMetadata generate_keras_metadata(IList saved_nodes Identifier = layer.ObjectIdentifier, Metadata = layer.TrackingMetadata }; - + + metadata.Nodes.Add(saved_object); } return metadata; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index fc7eab3a3..f7e1bf45c 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -24,26 +24,26 @@ public static IDictionary wrap_layer_objects(Layer layer, IDi // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. // TODO: change the inherits of `Variable` and revise the implmentation. 
- var variables = layer.Variables.Select(x => + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var trainable_variables = layer.TrainableVariables.Select(x => + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var non_trainable_variables = layer.non_trainable_variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); Dictionary res = new(); - res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); - res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); - res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); return res; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 0235f87bd..60c4ee5b8 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs 
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public abstract class SavedModelSaver { - private Trackable _obj; + protected Trackable _obj; public SavedModelSaver(Trackable obj) { _obj = obj; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index b092b5950..655127af9 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -2,6 +2,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; using Tensorflow.Train; @@ -9,10 +10,11 @@ namespace Tensorflow.Keras.Saving.SavedModel; public class LayerSavedModelSaver: SavedModelSaver { - private Layer _obj; + private Layer _layer; public LayerSavedModelSaver(Layer obj): base(obj) { _obj = obj; + _layer = obj; } public override string ObjectIdentifier { @@ -68,8 +70,8 @@ protected ISerializedAttributes get_serialized_attributes(IDictionary private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { - var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); - var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); functions["_default_save_signature"] = null; @@ -81,17 +83,18 @@ public override string TrackingMetadata get { JObject metadata = new JObject(); - metadata["name"] = _obj.Name; - metadata["trainable"] = _obj.Trainable; + metadata["name"] = _layer.Name; + metadata["trainable"] = _layer.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = 
policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; - metadata["autocast"] = _obj.AutoCast; - - metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + metadata["autocast"] = _layer.AutoCast; + + var temp = JObject.FromObject(get_serialized(_layer)); + metadata.Merge(temp, new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge @@ -108,4 +111,46 @@ public static IDictionary get_serialized(Layer obj) return new Dictionary(); //return generic_utils.serialize_keras_object(obj); } +} + +public class InputLayerSavedModelSaver: SavedModelSaver +{ + public InputLayerSavedModelSaver(Layer obj) : base(obj) + { + + } + public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override string TrackingMetadata + { + get + { + if(_obj is not Layer) + { + throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); + } + var layer = (Layer)_obj; + var info = new + { + class_name = layer.GetType().Name, + name = layer.Name, + dtype = layer.DType, + //sparse = layer.sparse, + //ragged = layer.ragged, + batch_input_shape = layer.BatchInputShape, + config = layer.get_config() + }; + return JsonConvert.SerializeObject(info); + } + } } \ No newline at end of file diff --git 
a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index c2839cdc7..68903eb2c 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -15,6 +15,8 @@ limitations under the License. ******************************************************************************/ using System; +using System.Collections; +using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.Saving; @@ -22,7 +24,12 @@ namespace Tensorflow.Keras.Utils { public class generic_utils { - public static LayerConfig serialize_keras_object(ILayer instance) + /// + /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. + /// + /// + /// + public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); return new LayerConfig From 2ab0bdbc8690b0048eaeb5ab5069e042b1a88d25 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 19:08:50 +0800 Subject: [PATCH 10/15] Add more implementations to the keras part of pb model save. 
--- .../ArgsDefinition/Activation/SoftmaxArgs.cs | 17 ++- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 19 +++ .../Keras/ArgsDefinition/Core/DenseArgs.cs | 42 +++++- .../ArgsDefinition/Core/InputLayerArgs.cs | 17 ++- .../Keras/ArgsDefinition/DataAdapterArgs.cs | 3 +- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 3 +- .../Keras/ArgsDefinition/LayerArgs.cs | 31 +++-- .../Keras/ArgsDefinition/NodeArgs.cs | 6 +- .../Keras/ArgsDefinition/OptimizerV2Args.cs | 6 +- .../ArgsDefinition/Reshaping/FlattenArgs.cs | 7 +- .../CustomizedActivationJsonConverter.cs | 50 +++++++ .../Common/CustomizedAxisJsonConverter.cs | 48 +++++++ .../CustomizedNodeConfigJsonConverter.cs | 73 ++++++++++ .../Common/CustomizedShapeJsonConverter.cs | 67 ++++++++++ .../Keras/Engine/InputSpec.cs | 31 ++++- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 5 +- .../Keras/Saving/IKerasConfig.cs | 15 +++ .../Keras/Saving/LayerConfig.cs | 9 +- .../Keras/Saving/ModelConfig.cs | 9 +- .../Keras/Saving/NodeConfig.cs | 7 +- .../Keras/Saving/TensorShapeConfig.cs | 21 +++ src/TensorFlowNET.Core/NumPy/Axis.cs | 11 +- src/TensorFlowNET.Core/Numpy/Shape.cs | 3 + .../Operations/Initializers/Constant.cs | 10 ++ .../Operations/Initializers/GlorotUniform.cs | 10 +- .../Operations/Initializers/IInitializer.cs | 7 + .../Operations/Initializers/Ones.cs | 7 + .../Operations/Initializers/Orthogonal.cs | 5 + .../Operations/Initializers/RandomNormal.cs | 12 ++ .../Operations/Initializers/RandomUniform.cs | 12 ++ .../Initializers/TruncatedNormal.cs | 11 ++ .../Initializers/VarianceScaling.cs | 13 ++ .../Operations/Initializers/Zeros.cs | 5 + .../Operations/NnOps/RNNCell.cs | 5 +- .../Tensorflow.Binding.csproj | 1 + src/TensorFlowNET.Core/Tensors/dtypes.cs | 18 +++ .../{ITrackable.cs => IWithTrackable.cs} | 2 +- src/TensorFlowNET.Core/Training/Trackable.cs | 2 +- .../Engine/Functional.GetConfig.cs | 31 +++-- src/TensorFlowNET.Keras/Engine/Functional.cs | 18 +++ src/TensorFlowNET.Keras/Engine/Layer.cs | 6 +- 
src/TensorFlowNET.Keras/Engine/Model.Save.cs | 6 +- .../Layers/Activation/ELU.cs | 1 + .../Layers/Activation/Exponential.cs | 1 + .../Layers/Activation/SELU.cs | 9 +- .../Layers/Attention/Attention.cs | 3 +- .../Layers/Attention/BaseDenseAttention.cs | 3 +- .../Layers/Convolution/Conv2DTranspose.cs | 1 + .../Layers/Convolution/Convolutional.cs | 1 + src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 1 + .../Layers/Core/Embedding.cs | 1 + .../Layers/Cropping/Cropping1D.cs | 1 + .../Layers/Cropping/Cropping2D.cs | 3 +- .../Layers/Cropping/Cropping3D.cs | 3 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 4 +- .../Layers/Merging/Concatenate.cs | 1 + .../Layers/Merging/Merge.cs | 1 + .../Normalization/BatchNormalization.cs | 1 + .../Normalization/LayerNormalization.cs | 1 + .../Layers/Reshaping/Permute.cs | 1 + .../Layers/Rnn/SimpleRNN.cs | 1 + .../Layers/Rnn/StackedRNNCells.cs | 3 +- .../Saving/SavedModel/Save.cs | 2 +- .../Saving/SavedModel/layer_serialization.cs | 33 +++-- .../Saving/TensorShapeConfig.cs | 15 --- .../Saving/serialization.cs | 125 ++++++++++++++++++ .../Utils/base_layer_utils.cs | 2 +- .../Utils/generic_utils.cs | 14 +- .../Layers/ModelSaveTest.cs | 5 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 40 ++++-- 70 files changed, 849 insertions(+), 109 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs rename src/TensorFlowNET.Core/Training/{ITrackable.cs => IWithTrackable.cs} (82%) delete mode 100644 
src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs create mode 100644 src/TensorFlowNET.Keras/Saving/serialization.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs index ca35d75d5..a37973bc6 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -1,9 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class SoftmaxArgs : LayerArgs { - public Axis axis { get; set; } = -1; - } + public class SoftmaxArgs : LayerArgs + { + [JsonProperty("axis")] + public Axis axis { get; set; } = -1; + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs new file mode 100644 index 000000000..66b34a1ae --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -0,0 +1,19 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AutoSerializeLayerArgs: LayerArgs + { + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape 
BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs index e9b3c2fd9..8f4facbd4 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -1,13 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; +using System.Xml.Linq; +using Tensorflow.Operations.Initializers; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { + // TODO: `activity_regularizer` public class DenseArgs : LayerArgs { /// /// Positive integer, dimensionality of the output space. /// + [JsonProperty("units")] public int Units { get; set; } /// @@ -15,39 +20,74 @@ public class DenseArgs : LayerArgs /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + /// /// Whether the layer uses a bias vector. /// + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. 
/// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. /// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs index 723109c27..be43e0a62 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -1,9 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Tensorflow.Keras.Common; + +namespace Tensorflow.Keras.ArgsDefinition { public class InputLayerArgs : LayerArgs { + [JsonIgnore] public Tensor InputTensor { get; set; } - public bool Sparse { get; set; } + [JsonProperty("sparse")] + public virtual bool Sparse { get; set; } + [JsonProperty("ragged")] public bool Ragged { get; set; } + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape 
BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index f3cca438f..8ce1ec655 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataAdapterArgs + public class DataAdapterArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index b6e6849bc..fd603a85e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataHandlerArgs + public class DataHandlerArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index 4df4fb2b4..febf14176 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -1,51 +1,54 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class LayerArgs + [JsonObject(MemberSerialization.OptIn)] + public class LayerArgs: IKerasConfig { /// /// Indicates whether the layer's weights are updated during training /// and whether the layer's updates are run during training. 
/// - public bool Trainable { get; set; } = true; - - public string Name { get; set; } + public virtual bool Trainable { get; set; } = true; + public virtual string Name { get; set; } /// /// Only applicable to input layers. /// - public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; + public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; /// /// Whether the `call` method can be used to build a TF graph without issues. /// This attribute has no effect if the model is created using the Functional /// API. Instead, `model.dynamic` is determined based on the internal layers. /// - public bool Dynamic { get; set; } = false; + public virtual bool Dynamic { get; set; } = false; /// /// Only applicable to input layers. /// - public Shape InputShape { get; set; } + public virtual Shape InputShape { get; set; } /// /// Only applicable to input layers. /// - public Shape BatchInputShape { get; set; } + public virtual Shape BatchInputShape { get; set; } - public int BatchSize { get; set; } = -1; + public virtual int BatchSize { get; set; } = -1; /// /// Initial weight values. /// - public float[] Weights { get; set; } + public virtual float[] Weights { get; set; } /// /// Regularizer function applied to the output of the layer(its "activation"). 
/// - public IRegularizer ActivityRegularizer { get; set; } + public virtual IRegularizer ActivityRegularizer { get; set; } - public bool Autocast { get; set; } + public virtual bool Autocast { get; set; } - public bool IsFromConfig { get; set; } + public virtual bool IsFromConfig { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 0d9e26ac4..ad55ff612 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class NodeArgs + public class NodeArgs: IKerasConfig { public ILayer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs index e2a0e43c8..6256fd329 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class OptimizerV2Args + public class OptimizerV2Args: IKerasConfig { public string Name { get; set; } public float LearningRate { get; set; } = 0.001f; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs index c2b48cc2f..91ffc2058 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs @@ -1,7 +1,10 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class FlattenArgs : LayerArgs + public 
class FlattenArgs : AutoSerializeLayerArgs { + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs new file mode 100644 index 000000000..1bc13caf3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs @@ -0,0 +1,50 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedActivationJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Activation); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(""); + token.WriteTo(writer); + } + else if (value is not Activation) + { + throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Activation)!.Method.Name); // Activation is a delegate: GetType().Name is always "Activation"; Method.Name matches DenseArgs.ActivationName + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + throw new NotImplementedException(); + //var dims = serializer.Deserialize(reader, typeof(string)); + //if (dims is null) + //{ + // throw new ValueError("Cannot deserialize 'null' to `Activation`."); + //} + //return new Shape((long[])(dims!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs new file mode 100644 index 000000000..4e190605c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -0,0 +1,48 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedAxisJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Axis); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(new int[] { }); + token.WriteTo(writer); + } + else if (value is not Axis) + { + throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Axis)!.axis); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + var axis = serializer.Deserialize(reader, typeof(int[])); // WriteJson emits int[]; a boxed long[] cannot be cast to int[] + if (axis is null) + { + throw new ValueError("Cannot deserialize 'null' to `Axis`."); + } + return new Axis((int[])(axis!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs new file mode 100644 index 000000000..1ad19fc89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedNodeConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(NodeConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JValue.CreateNull(); // JToken.FromObject(null) throws ArgumentNullException + token.WriteTo(writer); + } + else if (value is not NodeConfig) + { + throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var config = value as NodeConfig; + var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + var values = serializer.Deserialize(reader, typeof(object[])) as object[]; + if (values is null) + { + throw new ValueError("Cannot deserialize 'null' to `NodeConfig`."); + } + if(values.Length != 3) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + if (values[0] is not string) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); + } + if (values[1] is not (int or long)) // Json.NET boxes integers read into object[] as Int64 + { + throw new TypeError($"The second value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); + } + if (values[2] is not (int or long)) // Json.NET boxes integers read into object[] as Int64 + { + throw new TypeError($"The third value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); + } + return new NodeConfig() + { + Name = values[0] as string, + NodeIndex = Convert.ToInt32(values[1]), + TensorIndex = Convert.ToInt32(values[2]) + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs new file mode 100644 index 000000000..300cb2f28 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -0,0 +1,67 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedShapeJsonConverter: JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Shape); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? 
value, JsonSerializer serializer) + { + if(value is null) + { + var token = JValue.CreateNull(); // JToken.FromObject(null) throws ArgumentNullException + token.WriteTo(writer); + } + else if(value is not Shape) + { + throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var shape = (value as Shape)!; + long?[] dims = new long?[shape.ndim]; + for(int i = 0; i < dims.Length; i++) + { + if (shape.dims[i] == -1) + { + dims[i] = null; + } + else + { + dims[i] = shape.dims[i]; + } + } + var token = JToken.FromObject(dims); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + if(dims is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + long[] convertedDims = new long[dims.Length]; + for(int i = 0; i < dims.Length; i++) + { + convertedDims[i] = dims[i] ?? (-1); + } + return new Shape(convertedDims); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 7280594b7..6743935c8 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -16,23 +16,27 @@ limitations under the License. using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { /// /// Specifies the ndim, dtype and shape of every input to a layer. /// - public class InputSpec + public class InputSpec: IKerasConfigable { public int? ndim; + public int? max_ndim; public int? min_ndim; Dictionary axes; Shape shape; + TF_DataType dtype; public int[] AllAxisDim; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, int? min_ndim = null, + int? 
max_ndim = null, Dictionary axes = null, Shape shape = null) { @@ -41,7 +45,9 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, axes = new Dictionary(); this.axes = axes; this.min_ndim = min_ndim; + this.max_ndim = max_ndim; this.shape = shape; + this.dtype = dtype; if (ndim == null && shape != null) this.ndim = shape.ndim; @@ -49,7 +55,30 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, AllAxisDim = axes.Select(x => x.Value).ToArray(); } + public IKerasConfig get_config() + { + return new Config() + { + DType = dtype == TF_DataType.DtInvalid ? null : dtype, + Shape = shape, + Ndim = ndim, + MinNdim = min_ndim, + MaxNdim = max_ndim, + Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) + }; + } + public override string ToString() => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; + + public class Config: IKerasConfig + { + public TF_DataType? DType { get; set; } + public Shape Shape { get; set; } + public int? Ndim { get; set; } + public int? MinNdim { get;set; } + public int? 
MaxNdim { get;set; } + public IDictionary Axes { get; set; } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f1ca56325..ebf3358d7 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,11 +1,12 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer: ITrackable + public interface ILayer: IWithTrackable, IKerasConfigable { string Name { get; } bool Trainable { get; } @@ -19,8 +20,8 @@ public interface ILayer: ITrackable List NonTrainableWeights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } + TensorShapeConfig BuildInputShape { get; } TF_DataType DType { get; } int count_params(); - LayerArgs get_config(); } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs new file mode 100644 index 000000000..1217e1e52 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public interface IKerasConfig + { + } + + public interface IKerasConfigable + { + IKerasConfig get_config(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs index b8b8cab40..4ce290c83 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -6,11 +7,15 @@ namespace Tensorflow.Keras.Saving { - public class LayerConfig + public class LayerConfig: IKerasConfig { + [JsonProperty("name")] 
public string Name { get; set; } + [JsonProperty("class_name")] public string ClassName { get; set; } + [JsonProperty("config")] public LayerArgs Config { get; set; } + [JsonProperty("inbound_nodes")] public List InboundNodes { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index abfb235be..cac19180f 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,15 +1,20 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class ModelConfig + public class ModelConfig : IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("layers")] public List Layers { get; set; } + [JsonProperty("input_layers")] public List InputLayers { get; set; } + [JsonProperty("output_layers")] public List OutputLayers { get; set; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 3132248ef..20e2fef59 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -1,10 +1,13 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow.Keras.Saving { - public class NodeConfig + [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] + public class NodeConfig : IKerasConfig { public string Name { get; set; } public int NodeIndex { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs new file mode 100644 index 000000000..7abcfde26 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs @@ -0,0 +1,21 @@ +using 
Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Saving +{ + public class TensorShapeConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } = "TensorShape"; + [JsonProperty("items")] + public long?[] Items { get; set; } + + public static implicit operator Shape(TensorShapeConfig shape) + => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + + public static implicit operator TensorShapeConfig(Shape shape) + => shape is null ? null : new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? null : x).ToArray() }; // null Shape maps to null config, mirroring the reverse operator + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 6c7189df1..709ca9b27 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -14,20 +14,29 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow { - public record Axis(params int[] axis) + [JsonConverter(typeof(CustomizedAxisJsonConverter))] + public class Axis { + public int[] axis { get; set; } public int size => axis == null ? -1 : axis.Length; public bool IsScalar { get; init; } public int this[int index] => axis[index]; + public Axis(params int[] axis) + { + this.axis = axis; + } + public static implicit operator int[]?(Axis axis) => axis?.axis; diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index bc79fefca..ecf735869 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -14,14 +14,17 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; using Tensorflow.NumPy; namespace Tensorflow { + [JsonConverter(typeof(CustomizedShapeJsonConverter))] public class Shape { public int ndim => _dims == null ? -1 : _dims.Length; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index fdcb5aff0..e7e9955c0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Constant : IInitializer @@ -22,11 +24,19 @@ public class Constant : IInitializer T value; bool _verify_shape; + private readonly Dictionary _config; + + public string ClassName => "Constant"; + public IDictionary Config => _config; + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) { this.value = value; this.dtype = dtype; _verify_shape = verify_shape; + + _config = new Dictionary(); + _config["value"] = this.value; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index d97d88308..def1cb7a0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -14,10 +14,17 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class GlorotUniform : VarianceScaling { + private readonly Dictionary _config; + + public override string ClassName => "GlorotUniform"; + public override IDictionary Config => _config; + public GlorotUniform(float scale = 1.0f, string mode = "FAN_AVG", bool uniform = true, @@ -28,7 +35,8 @@ public GlorotUniform(float scale = 1.0f, seed: seed, dtype: dtype) { - + _config = new Dictionary(); + _config["seed"] = _seed; } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 50d4d5037..9748b1004 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -14,10 +14,17 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; +using System.Collections.Generic; + namespace Tensorflow { public interface IInitializer { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } Tensor Apply(InitializerArgs args); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 02d3c93b2..3077a1e0e 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -14,12 +14,19 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Ones : IInitializer { private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "Ones"; + public IDictionary Config => new Dictionary(); + public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) { this.dtype = dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 254a7ee7b..cdc1c3edf 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -1,9 +1,14 @@ using System; +using System.Collections.Generic; namespace Tensorflow.Operations.Initializers { public class Orthogonal : IInitializer { + private readonly Dictionary _config; + + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); public Tensor Apply(InitializerArgs args) { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 029b311bb..21fa7e2b2 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomNormal : IInitializer @@ -23,6 +25,11 @@ public class RandomNormal : IInitializer private int? 
seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomNormal"; + public IDictionary Config => _config; + public RandomNormal(float mean = 0.0f, float stddev = 0.05f, int? seed = null, @@ -32,6 +39,11 @@ public RandomNormal(float mean = 0.0f, this.stddev = stddev; this.seed = seed; this.dtype = dtype; + + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index a49d59212..87404708c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomUniform : IInitializer @@ -23,12 +25,22 @@ public class RandomUniform : IInitializer private float maxval; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomUniform"; + public IDictionary Config => _config; + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? 
seed = null) { this.dtype = dtype; this.minval = minval; this.maxval = maxval; this.seed = seed; + + _config = new Dictionary(); + _config["minval"] = this.minval; + _config["maxval"] = this.maxval; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 048c11e7a..c1c3e9996 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class TruncatedNormal : IInitializer @@ -23,6 +25,11 @@ public class TruncatedNormal : IInitializer private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "TruncatedNormal"; + public IDictionary Config => _config; + public TruncatedNormal(float mean = 0.0f, float stddev = 1.0f, int? seed = null, @@ -32,6 +39,10 @@ public TruncatedNormal(float mean = 0.0f, this.stddev = stddev; this.seed = seed; this.dtype = dtype; + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index d313f4c9a..f104e8e83 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -15,7 +15,9 @@ limitations under the License. 
******************************************************************************/ using System; +using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; namespace Tensorflow.Operations.Initializers { @@ -30,6 +32,11 @@ public class VarianceScaling : IInitializer protected int? _seed; protected TF_DataType _dtype; protected bool _uniform; + private readonly Dictionary _config; + + public virtual string ClassName => "VarianceScaling"; + + public virtual IDictionary Config => _config; public VarianceScaling(float factor = 2.0f, string mode = "FAN_IN", @@ -50,6 +57,12 @@ public VarianceScaling(float factor = 2.0f, _seed = seed; _dtype = dtype; _uniform = uniform; + + _config = new(); + _config["scale"] = _scale; + _config["mode"] = _mode; + _config["distribution"] = _distribution; + _config["seed"] = _seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index 5d045292f..c4ed25a17 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Zeros : IInitializer @@ -21,6 +23,9 @@ public class Zeros : IInitializer Shape shape; TF_DataType dtype; + public string ClassName => "Zeros"; + public IDictionary Config => new Dictionary(); + public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { this.shape = shape; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 734f26089..c29ed47be 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -20,6 +20,7 @@ limitations under the License. using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; @@ -76,6 +77,8 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell public Shape BatchInputShape => throw new NotImplementedException(); + public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); + public TF_DataType DType => throw new NotImplementedException(); protected bool built = false; public bool Built => built; @@ -144,7 +147,7 @@ public int count_params() throw new NotImplementedException(); } - public LayerArgs get_config() + public IKerasConfig get_config() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 0ebe61d0d..7068ed477 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io + diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 372ac6762..deeb9e4b5 100644 
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -202,6 +202,24 @@ public static string as_numpy_name(this TF_DataType type) _ => type.ToString() }; + public static string as_python_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "str", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => "uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + public static int get_datatype_size(this TF_DataType type) => type.as_base_dtype() switch { diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs similarity index 82% rename from src/TensorFlowNET.Core/Training/ITrackable.cs rename to src/TensorFlowNET.Core/Training/IWithTrackable.cs index e4ef2c8fc..87eda8795 100644 --- a/src/TensorFlowNET.Core/Training/ITrackable.cs +++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs @@ -5,7 +5,7 @@ namespace Tensorflow.Training { - public interface ITrackable + public interface IWithTrackable { Trackable GetTrackable(); } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 434d51b63..132571f2a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -26,7 +26,7 @@ limitations under the License. 
namespace Tensorflow.Train { - public abstract class Trackable: ITrackable + public abstract class Trackable: IWithTrackable { /// /// Corresponding to tensorflow/python/trackable/constants.py diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index a221444b7..3aeb3200d 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine { public partial class Functional { - public ModelConfig get_config() + public override IKerasConfig get_config() { return get_network_config(); } @@ -25,7 +25,7 @@ ModelConfig get_network_config() { Name = name }; - + var node_conversion_map = new Dictionary(); foreach (var layer in _self_tracked_trackables) { @@ -42,23 +42,26 @@ ModelConfig get_network_config() } var layer_configs = new List(); - foreach (var layer in _self_tracked_trackables) + using (SharedObjectSavingScope.Enter()) { - var filtered_inbound_nodes = new List(); - foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + foreach (var layer in _self_tracked_trackables) { - var node_key = _make_node_key(layer.Name, original_node_index); - if (NetworkNodes.Contains(node_key) && !node.is_input) + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) { - var node_data = node.serialize(_make_node_key, node_conversion_map); - filtered_inbound_nodes.append(node_data); + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + filtered_inbound_nodes.append(node_data); + } } - } - var layer_config = generic_utils.serialize_layer_to_config(layer); - layer_config.Name = layer.Name; - layer_config.InboundNodes = filtered_inbound_nodes; - layer_configs.Add(layer_config); + 
var layer_config = generic_utils.serialize_layer_to_config(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } } config.Layers = layer_configs; diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 7c8812adb..44eaef534 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -70,6 +70,7 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs) this.inputs = inputs; this.outputs = outputs; built = true; + _buildInputShape = inputs.shape; if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); @@ -357,5 +358,22 @@ public override IDictionary _trackable_children(SaveType save return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); } + + protected override void _init_set_name(string name, bool zero_based = true) + { + if (string.IsNullOrEmpty(name)) + { + string class_name = GetType().Name; + if (this.GetType() == typeof(Functional)) + { + class_name = "Model"; + } + this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); + } + else + { + this.name = name; + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index a2f92ba8b..31b37d681 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -61,6 +61,7 @@ public abstract partial class Layer : AutoTrackable, ILayer /// Provides information about which inputs are compatible with the layer. 
/// protected InputSpec inputSpec; + public InputSpec InputSpec => inputSpec; bool dynamic = true; public bool SupportsMasking { get; set; } protected List _trainable_weights; @@ -79,6 +80,8 @@ public abstract partial class Layer : AutoTrackable, ILayer protected bool computePreviousMask; protected List updates; public Shape BatchInputShape => args.BatchInputShape; + protected TensorShapeConfig _buildInputShape = null; + public TensorShapeConfig BuildInputShape => _buildInputShape; List inboundNodes; public List InboundNodes => inboundNodes; @@ -223,6 +226,7 @@ protected void MaybeBuild(Tensors inputs) public virtual void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } @@ -310,7 +314,7 @@ public List weights public List Variables => weights; - public virtual LayerArgs get_config() + public virtual IKerasConfig get_config() => args; } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59b205e44..85da920ef 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; @@ -30,7 +31,10 @@ public void save(string filepath, } else { - KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + using (SharedObjectSavingScope.Enter()) + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 6e790a26f..45f64720f 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,6 +25,7 @@ public override void build(Shape input_shape) { throw 
new ValueError("Alpha must be a number greater than 0."); } + _buildInputShape = input_shape; built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index aba175de9..2fd2caee1 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,6 +14,7 @@ public Exponential(LayerArgs args) : base(args) } public override void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index b12d7deec..1ef8d0e58 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -16,10 +16,11 @@ public SELU ( LayerArgs args ) : base(args) { // SELU has no arguments } public override void build(Shape input_shape) { - if ( alpha < 0f ) { - throw new ValueError("Alpha must be a number greater than 0."); - } - built = true; + if ( alpha < 0f ) { + throw new ValueError("Alpha must be a number greater than 0."); + } + _buildInputShape = input_shape; + built = true; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index 6f6dd7e85..c51316308 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -146,7 +147,7 @@ public override Tensor _calculate_scores(Tensor query, Tensor key) return scores; } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; //var config = new Dictionary { // { // "use_scale", diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 3f618b5db..1348e19cf 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; /// /// Base class for attention layers that can be used in sequence DNN/CNN models. 
@@ -252,6 +253,6 @@ public static Tensor _merge_masks(Tensor x, Tensor y) return tf.logical_and(x, y); } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; } } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index e0a337caa..b8286be67 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -49,6 +49,7 @@ public override void build(Shape input_shape) initializer: bias_initializer, trainable: true); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 912a429b7..933aa9cf1 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -98,6 +98,7 @@ public override void build(Shape input_shape) name: tf_op_name); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = false) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index e4c227456..ca8007d09 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -43,6 +43,7 @@ public Dense(DenseArgs args) : public override void build(Shape input_shape) { + _buildInputShape = input_shape; var last_dim = input_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 79f4e5ce9..606f387bb 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -62,6 +62,7 @@ public override void build(Shape input_shape) name: "embeddings"); tf.Context.graph_mode(); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs index 45f5bf0f6..44b338c25 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs @@ -22,6 +22,7 @@ public override void build(Shape input_shape) throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); } built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs index 6cb03e1e0..1f33ee3af 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs @@ -13,7 +13,8 @@ public Cropping2D ( Cropping2DArgs args ) : base(args) { this.args = args; } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs index 2d6751bf9..838a50434 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs @@ -12,7 +12,8 @@ public Cropping3D ( Cropping3DArgs args ) : base(args) { } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 50c66be70..c1ec0ddc7 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -300,7 +300,8 @@ public ILayer Dense(int units) => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName("linear") + Activation = GetActivationByName("linear"), + ActivationName = "linear" }); /// @@ -321,6 +322,7 @@ public ILayer Dense(int units, { Units = units, Activation = GetActivationByName(activation), + ActivationName = activation, InputShape = input_shape }); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index 5f8217604..da7e857a2 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -37,6 +37,7 @@ public override void build(Shape input_shape) }).ToArray(); shape_set.Add(shape); }*/ + _buildInputShape = input_shape; } protected override Tensors _merge_function(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 0363d58f4..3cd43af92 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -17,6 +17,7 @@ public Merge(MergeArgs args) : base(args) public override void build(Shape input_shape) { // output_shape = input_shape.dims[1^]; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index dac92f812..c0b16c812 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -118,6 +118,7 @@ public override void build(Shape input_shape) throw new NotImplementedException("build when renorm is true"); built = true; + _buildInputShape = input_shape; } public override Shape ComputeOutputShape(Shape input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index 5eebd7350..e19b9c30e 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -81,6 +81,7 @@ public override void build(Shape input_shape) _fused = _fused_can_be_used(ndims); built = true; + _buildInputShape = input_shape; } bool _fused_can_be_used(int ndims) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 868506b6b..8e7a19a9a 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -24,6 +24,7 @@ public override void build(Shape input_shape) permute = new int[input_shape.rank]; dims.CopyTo(permute, 1); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index c8366ff48..38abe2a79 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -18,6 +18,7 @@ public SimpleRNN(SimpleRNNArgs args) : base(args) public override void build(Shape input_shape) { var input_dim = input_shape[-1]; + _buildInputShape = input_shape; kernel = add_weight("kernel", (input_shape[-1], args.Units), initializer: args.KernelInitializer diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index eead274a1..20962df1f 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Rnn { @@ -136,7 +137,7 @@ public void build() // self.built = True } - public override LayerArgs get_config() + public override IKerasConfig get_config() { throw new NotImplementedException(); //def get_config(self): diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 4ff8f02f0..9d1c9609a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -79,7 +79,7 @@ public static SavedMetadata generate_keras_metadata(IList saved_nodes var path = node_paths[node]; string node_path; - if (path is null) + if (path is null || path.Count() == 0) { node_path = "root"; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 655127af9..8675ea65b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ 
b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Newtonsoft.Json; using Newtonsoft.Json.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; @@ -85,31 +86,38 @@ public override string TrackingMetadata JObject metadata = new JObject(); metadata["name"] = _layer.Name; metadata["trainable"] = _layer.Trainable; - // metadata["expects_training_arg"] = _obj._expects_training_arg; - // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + // TODO: implement `expects_training_arg`. + metadata["expects_training_arg"] = false; + metadata["dtype"] = _layer.DType.as_python_name(); metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; metadata["autocast"] = _layer.AutoCast; - var temp = JObject.FromObject(get_serialized(_layer)); - metadata.Merge(temp, new JsonMergeSettings + if(_layer.InputSpec is not null) + { + metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); + } + + metadata.Merge(get_serialized(_layer), new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge }); // skip the check of `input_spec` and `build_input_shape` for the lack of members. // skip the check of `activity_regularizer` for the type problem. + if(_layer.BuildInputShape is not null) + { + metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); + } return metadata.ToString(); } } - public static IDictionary get_serialized(Layer obj) + public static JObject get_serialized(Layer obj) { - // TODO: complete the implmentation (need to revise `get_config`). 
- return new Dictionary(); - //return generic_utils.serialize_keras_object(obj); + return generic_utils.serialize_keras_object(obj); } } @@ -135,18 +143,19 @@ public override string TrackingMetadata { get { - if(_obj is not Layer) + if(_obj is not InputLayer) { throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); } - var layer = (Layer)_obj; + var layer = (InputLayer)_obj; + var config = (layer.get_config() as InputLayerArgs)!; var info = new { class_name = layer.GetType().Name, name = layer.Name, dtype = layer.DType, - //sparse = layer.sparse, - //ragged = layer.ragged, + sparse = config.Sparse, + ragged = config.Ragged, batch_input_shape = layer.BatchInputShape, config = layer.get_config() }; diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs deleted file mode 100644 index 4c2ecc0d8..000000000 --- a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Tensorflow.Keras.Saving -{ - public class TensorShapeConfig - { - public string ClassName { get; set; } - public int?[] Items { get; set; } - - public static implicit operator Shape(TensorShapeConfig shape) - => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); - } -} diff --git a/src/TensorFlowNET.Keras/Saving/serialization.cs b/src/TensorFlowNET.Keras/Saving/serialization.cs new file mode 100644 index 000000000..d5e46d11c --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/serialization.cs @@ -0,0 +1,125 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving +{ + // TODO: make it thread safe. 
+ public class SharedObjectSavingScope: IDisposable + { + private class WeakReferenceEqualityComparer: IEqualityComparer> + { + public bool Equals(WeakReference x, WeakReference y) + { + if(!x.TryGetTarget(out var tx)) + { + return false; + } + if(!y.TryGetTarget(out var ty)) + { + return false; + } + return tx.Equals(ty); + } + public int GetHashCode(WeakReference obj) + { + if (!obj.TryGetTarget(out var w)) + { + return 0; + } + return w.GetHashCode(); + } + } + private static SharedObjectSavingScope? _instance = null; + private readonly Dictionary, int> _shared_object_ids= new Dictionary, int>(); + private int _currentId = 0; + /// + /// record how many times the scope is nested. + /// + private int _nestedDepth = 0; + private SharedObjectSavingScope() + { + + } + + public static SharedObjectSavingScope Enter() + { + if(_instance is not null) + { + _instance._nestedDepth++; + return _instance; + } + else + { + _instance = new SharedObjectSavingScope(); + _instance._nestedDepth++; + return _instance; + } + } + + public static SharedObjectSavingScope GetScope() + { + return _instance; + } + + public int GetId(object? obj) + { + if(obj is null) + { + return _currentId++; + } + var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference(obj))); + if (maybe_key is not null) + { + return _shared_object_ids[maybe_key]; + } + _shared_object_ids[new WeakReference(obj)] = _currentId; + return _currentId++; // return the id that was just stored, then advance the counter + } + + public void Dispose() + { + _nestedDepth--; + if(_nestedDepth== 0) + { + _instance = null; + } + } + } + + public static class serialize_utils + { + public static readonly string SHARED_OBJECT_KEY = "shared_object_id"; + /// + /// Returns the serialization of the class with the given config. + /// + /// + /// + /// + /// + /// + public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? 
shared_object_id = null) + { + JObject res = new JObject(); + res["class_name"] = class_name; + res["config"] = config; + + if(shared_object_id is not null) + { + res[SHARED_OBJECT_KEY] = shared_object_id!; + } + + var scope = SharedObjectSavingScope.GetScope(); + if(scope is not null && obj is not null) + { + res[SHARED_OBJECT_KEY] = scope.GetId(obj); + } + + return res; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 1e6ce4091..d845f3ca9 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -53,7 +53,7 @@ public static IVariableV1 make_variable(VariableArgs args) } /// - /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. + /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (corresponding to `backend.unique_object_name` of python.) /// /// /// diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 68903eb2c..730a33e3e 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -14,10 +14,14 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Utils @@ -32,13 +36,21 @@ public class generic_utils public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); + Debug.Assert(config is LayerArgs); return new LayerConfig { - Config = config, + Config = config as LayerArgs, ClassName = instance.GetType().Name }; } + public static JObject serialize_keras_object(IKerasConfigable instance) + { + var config = JToken.FromObject(instance.get_config()); + // TODO: change the class_name to registered name, instead of system class name. + return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs index 0a1098af7..67e8ff797 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs @@ -1,6 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.Keras.Engine; +using System.Diagnostics; using static Tensorflow.KerasApi; +using Tensorflow.Keras.Saving; namespace TensorFlowNET.Keras.UnitTest { @@ -15,7 +17,8 @@ public void GetAndFromConfig() { var model = GetFunctionalModel(); var config = model.get_config(); - var new_model = keras.models.from_config(config); + Debug.Assert(config is ModelConfig); + var new_model = keras.models.from_config(config as ModelConfig); Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); } diff --git 
a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 0f34ff10d..90d0a48a5 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -15,17 +15,14 @@ using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; +using Tensorflow.Operations; namespace TensorFlowNET.Keras.UnitTest; -// class MNISTLoader -// { -// public MNISTLoader() -// { -// var mnist = new MnistModelLoader() -// -// } -// } +public static class AutoGraphExtension +{ + +} [TestClass] public class SaveTest @@ -42,6 +39,8 @@ public void Test() model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); + var g = ops.get_default_graph(); + var data_loader = new MnistModelLoader(); var num_epochs = 1; var batch_size = 50; @@ -50,11 +49,34 @@ public void Test() { TrainDir = "mnist", OneHot = false, - ValidationSize = 0, + ValidationSize = 50000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } + + [TestMethod] + public void Temp() + { + var graph = new Graph(); + var g = graph.as_default(); + //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); + var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); + var wrapped_func = tf.autograph.to_graph(func); + var res = wrapped_func(input_tensor); + g.Exit(); + } + + private Tensor func(Tensor tensor) + { + return gen_ops.neg(tensor); + //return array_ops.identity(tensor); + //tf.device("cpu:0"); + //using (ops.control_dependencies(new object[] { res.op })) + //{ + // return array_ops.identity(tensor); + //} + } } \ No newline at end of file From 59c17705c33bd3fb1cf2e9e92b9893d1ed6b1f28 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 19:48:49 +0800 Subject: [PATCH 11/15] Refine some 
code after merge. --- .../Operations/Initializers/Orthogonal.cs | 6 +++--- test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 3790a864d..492047c9f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -32,10 +32,10 @@ public Orthogonal(float gain = 1.0f, int? seed = null) _seed = seed; } - private readonly Dictionary _config; + private readonly Dictionary _config; - public string ClassName => "Orthogonal"; - public IDictionary Config => throw new NotImplementedException(); + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); public Tensor Apply(InitializerArgs args) { return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? 
TF_DataType.TF_FLOAT : args.DType); diff --git a/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs index c811b5643..6950e65fc 100644 --- a/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs @@ -6,7 +6,7 @@ using TensorFlowNET.Keras.UnitTest; using static Tensorflow.Binding; -namespace Tensorflow.Keras.UnitTest; +namespace TensorFlowNET.Keras.UnitTest; [TestClass] public class InitializerTest : EagerModeTestBase @@ -15,6 +15,6 @@ public class InitializerTest : EagerModeTestBase public void Orthogonal() { var initializer = tf.keras.initializers.Orthogonal(); - var values = initializer.Apply(new InitializerArgs((2, 2))); + var values = initializer.Apply(new Tensorflow.InitializerArgs((2, 2))); } } From 6c07778243fb0bc8ab6d209e33a87703db10bee1 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 20:37:25 +0800 Subject: [PATCH 12/15] Add two simple sequential test case of pb model save. --- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 2 +- .../SequentialModelTest.cs} | 66 +++++++++---------- 2 files changed, 34 insertions(+), 34 deletions(-) rename test/TensorFlowNET.Keras.UnitTest/{SaveTest.cs => SaveModel/SequentialModelTest.cs} (53%) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 85da920ef..a1e891f98 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -25,7 +25,7 @@ public void save(string filepath, ConcreteFunction? 
signatures = null, bool save_traces = true) { - if (save_format != "pb") + if (save_format != "tf") { saver.save(this, filepath); } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs similarity index 53% rename from test/TensorFlowNET.Keras.UnitTest/SaveTest.cs rename to test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs index 90d0a48a5..288a92b32 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -17,18 +17,13 @@ using Tensorflow.Keras.Optimizers; using Tensorflow.Operations; -namespace TensorFlowNET.Keras.UnitTest; - -public static class AutoGraphExtension -{ - -} +namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] -public class SaveTest +public class SequentialModelTest { [TestMethod] - public void Test() + public void SimpleModelFromAutoCompile() { var inputs = new KerasInterface().Input((28, 28, 1)); var x = new Flatten(new FlattenArgs()).Apply(inputs); @@ -36,10 +31,8 @@ public void Test() x = new LayersApi().Dense(units: 10).Apply(x); var outputs = new LayersApi().Softmax(axis: 1).Apply(x); var model = new KerasInterface().Model(inputs, outputs); - - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); - var g = ops.get_default_graph(); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); var num_epochs = 1; @@ -49,34 +42,41 @@ public void Test() { TrainDir = "mnist", OneHot = false, - ValidationSize = 50000, + ValidationSize = 10000, }).Result; - + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - - model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); + + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.compile", save_format: "tf"); } [TestMethod] - public void 
Temp() + public void SimpleModelFromSequential() { - var graph = new Graph(); - var g = graph.as_default(); - //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); - var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); - var wrapped_func = tf.autograph.to_graph(func); - var res = wrapped_func(input_tensor); - g.Exit(); - } + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((28, 28, 1)), + keras.layers.Flatten(), + keras.layers.Dense(100, "relu"), + keras.layers.Dense(10), + keras.layers.Softmax(1) + }); - private Tensor func(Tensor tensor) - { - return gen_ops.neg(tensor); - //return array_ops.identity(tensor); - //tf.device("cpu:0"); - //using (ops.control_dependencies(new object[] { res.op })) - //{ - // return array_ops.identity(tensor); - //} + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.sequential", save_format: "tf"); } } \ No newline at end of file From ad541a7971d5c46b46b7daac69de002f3bdc1b6b Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 4 Feb 2023 21:03:40 +0800 Subject: [PATCH 13/15] Implement serializing attributes other keras arg definitions. 
--- .../ArgsDefinition/Activation/ELUArgs.cs | 11 +- .../Activation/LeakyReLuArgs.cs | 6 +- .../ArgsDefinition/Activation/SoftmaxArgs.cs | 8 +- .../ArgsDefinition/Attention/AttentionArgs.cs | 4 + .../Attention/BaseDenseAttentionArgs.cs | 5 +- .../Attention/MultiHeadAttentionArgs.cs | 20 ++- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 6 + .../Convolution/ConvolutionalArgs.cs | 40 ++++- .../{Attention => Core}/EinsumDenseArgs.cs | 36 ++++- .../ArgsDefinition/Core/EmbeddingArgs.cs | 15 +- .../ArgsDefinition/Cropping/Cropping2DArgs.cs | 16 -- .../ArgsDefinition/Cropping/Cropping3DArgs.cs | 16 -- .../ArgsDefinition/Cropping/CroppingArgs.cs | 10 -- .../Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs | 6 - .../Keras/ArgsDefinition/Merging/MergeArgs.cs | 1 + .../Normalization/BatchNormalizationArgs.cs | 20 ++- .../Normalization/LayerNormalizationArgs.cs | 15 +- .../ArgsDefinition/Pooling/Pooling1DArgs.cs | 10 +- .../ArgsDefinition/Pooling/Pooling2DArgs.cs | 10 +- .../Preprocessing/PreprocessingLayerArgs.cs | 2 +- .../Preprocessing/RescalingArgs.cs | 12 ++ .../Preprocessing/ResizingArgs.cs | 1 + .../Preprocessing/TextVectorizationArgs.cs | 11 +- .../Regularization/DropoutArgs.cs | 9 +- .../ArgsDefinition/Rescaling/RescalingArgs.cs | 8 - .../Reshaping/Cropping2DArgs.cs | 18 +++ .../Reshaping/Cropping3DArgs.cs | 18 +++ .../ArgsDefinition/Reshaping/CroppingArgs.cs | 12 ++ .../ArgsDefinition/Reshaping/PermuteArgs.cs | 12 +- .../ArgsDefinition/Reshaping/ReshapeArgs.cs | 7 +- .../Reshaping/UpSampling2DArgs.cs | 9 +- .../Reshaping/ZeroPadding2DArgs.cs | 1 + .../ArgsDefinition/{Lstm => Rnn}/LSTMArgs.cs | 5 +- .../Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs | 7 + .../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 15 +- .../Keras/Layers/ILayersApi.Cropping.cs | 2 +- src/TensorFlowNET.Keras/Activations.cs | 82 ++++++++++ .../Activations/Activations.Linear.cs | 10 -- .../Activations/Activations.Relu.cs | 10 -- .../Activations/Activations.Sigmoid.cs | 11 -- .../Activations/Activations.Softmax.cs | 11 
-- .../Activations/Activations.Tanh.cs | 11 -- .../Layers/Attention/MultiHeadAttention.cs | 1 + .../Layers/Core/EinsumDense.cs | 2 +- .../Layers/Cropping/Cropping2D.cs | 114 ------------- .../Layers/Cropping/Cropping3D.cs | 124 --------------- .../Layers/LayersApi.Cropping.cs | 10 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 37 ++--- .../Normalization/BatchNormalization.cs | 2 +- .../{Rescaling => Preprocessing}/Rescaling.cs | 0 .../{Cropping => Reshaping}/Cropping1D.cs | 15 +- .../Layers/Reshaping/Cropping2D.cs | 140 ++++++++++++++++ .../Layers/Reshaping/Cropping3D.cs | 150 ++++++++++++++++++ .../Layers/{Lstm => Rnn}/LSTM.cs | 5 +- .../Layers/{Lstm => Rnn}/LSTMCell.cs | 4 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 1 - .../SaveModel/SequentialModelTest.cs | 82 ---------- 57 files changed, 702 insertions(+), 524 deletions(-) rename src/TensorFlowNET.Core/Keras/ArgsDefinition/{Attention => Core}/EinsumDenseArgs.cs (65%) delete mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs delete mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs delete mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs delete mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs delete mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs rename src/TensorFlowNET.Core/Keras/ArgsDefinition/{Lstm => Rnn}/LSTMArgs.cs (67%) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs create mode 100644 src/TensorFlowNET.Keras/Activations.cs delete mode 100644 
src/TensorFlowNET.Keras/Activations/Activations.Linear.cs delete mode 100644 src/TensorFlowNET.Keras/Activations/Activations.Relu.cs delete mode 100644 src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs delete mode 100644 src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs delete mode 100644 src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs delete mode 100644 src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs delete mode 100644 src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs rename src/TensorFlowNET.Keras/Layers/{Rescaling => Preprocessing}/Rescaling.cs (100%) rename src/TensorFlowNET.Keras/Layers/{Cropping => Reshaping}/Cropping1D.cs (79%) create mode 100644 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs rename src/TensorFlowNET.Keras/Layers/{Lstm => Rnn}/LSTM.cs (87%) rename src/TensorFlowNET.Keras/Layers/{Lstm => Rnn}/LSTMCell.cs (72%) delete mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs index 235523161..e830e5bf8 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs @@ -1,9 +1,12 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class ELUArgs : LayerArgs { - public float Alpha { get; set; } = 0.1f; - } + public class ELUArgs : AutoSerializeLayerArgs + { + [JsonProperty("alpha")] + public float Alpha { get; set; } = 0.1f; + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs index 6bdb294c2..6d9531346 100644 --- 
a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs @@ -1,14 +1,16 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class LeakyReLuArgs : LayerArgs + public class LeakyReLuArgs : AutoSerializeLayerArgs { /// /// Negative slope coefficient. /// + [JsonProperty("alpha")] public float Alpha { get; set; } = 0.3f; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs index a37973bc6..1c1d147f1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -4,15 +4,9 @@ using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class SoftmaxArgs : LayerArgs + public class SoftmaxArgs : AutoSerializeLayerArgs { [JsonProperty("axis")] public Axis axis { get; set; } = -1; - [JsonProperty("name")] - public override string Name { get => base.Name; set => base.Name = value; } - [JsonProperty("trainable")] - public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } - [JsonProperty("dtype")] - public override TF_DataType DType { get => base.DType; set => base.DType = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs index 73477c58f..4cdfb46bd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs @@ -1,3 +1,5 @@ +using Newtonsoft.Json; + namespace Tensorflow.Keras.ArgsDefinition { public class AttentionArgs : BaseDenseAttentionArgs @@ -6,6 +8,7 @@ public class AttentionArgs : BaseDenseAttentionArgs /// 
/// If `true`, will create a scalar variable to scale the attention scores. /// + [JsonProperty("use_scale")] public bool use_scale { get; set; } = false; /// @@ -14,6 +17,7 @@ public class AttentionArgs : BaseDenseAttentionArgs /// and key vectors. `"concat"` refers to the hyperbolic tangent of the /// concatenation of the query and key vectors. /// + [JsonProperty("score_mode")] public string score_mode { get; set; } = "dot"; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs index b2a0c3a51..0ef017370 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs @@ -1,6 +1,8 @@ +using Newtonsoft.Json; + namespace Tensorflow.Keras.ArgsDefinition { - public class BaseDenseAttentionArgs : LayerArgs + public class BaseDenseAttentionArgs : AutoSerializeLayerArgs { /// @@ -14,6 +16,7 @@ public class BaseDenseAttentionArgs : LayerArgs /// Float between 0 and 1. Fraction of the units to drop for the /// attention scores. 
/// + [JsonProperty("dropout")] public float dropout { get; set; } = 0f; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs index 21b2d218c..077dea89d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs @@ -1,22 +1,40 @@ +using Newtonsoft.Json; using System; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class MultiHeadAttentionArgs : LayerArgs + public class MultiHeadAttentionArgs : AutoSerializeLayerArgs { + [JsonProperty("num_heads")] public int NumHeads { get; set; } + [JsonProperty("key_dim")] public int KeyDim { get; set; } + [JsonProperty("value_dim")] public int? ValueDim { get; set; } = null; + [JsonProperty("dropout")] public float Dropout { get; set; } = 0f; + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; + [JsonProperty("output_shape")] public Shape OutputShape { get; set; } = null; + [JsonProperty("attention_axes")] public Shape AttentionAxis { get; set; } = null; + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } = null; + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } = null; + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } = null; + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } = null; + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: Add 
`key_shape`, `value_shape`, `query_shape`. } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs index 66b34a1ae..1a97b0135 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -5,6 +5,12 @@ namespace Tensorflow.Keras.ArgsDefinition { + /// + /// This class has nothing but the attributes different from `LayerArgs`. + /// It's used to serialize the model to `tf` format. + /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, + /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. + /// public class AutoSerializeLayerArgs: LayerArgs { [JsonProperty("name")] diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs index 4f050228b..08d563c1a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs @@ -1,31 +1,65 @@ -using System; +using Newtonsoft.Json; +using System; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class ConvolutionalArgs : LayerArgs + public class ConvolutionalArgs : AutoSerializeLayerArgs { public int Rank { get; set; } = 2; + [JsonProperty("filters")] public int Filters { get; set; } public int NumSpatialDims { get; set; } = Unknown; + [JsonProperty("kernel_size")] public Shape KernelSize { get; set; } = 5; /// /// specifying the stride length of the convolution.
/// + [JsonProperty("strides")] public Shape Strides { get; set; } = (1, 1); - + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; + [JsonProperty("data_format")] public string DataFormat { get; set; } + [JsonProperty("dilation_rate")] public Shape DilationRate { get; set; } = (1, 1); + [JsonProperty("groups")] public int Groups { get; set; } = 1; public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + [JsonProperty("use_bias")] public bool UseBias { get; set; } + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs similarity index 65% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs index 3a8642ffc..9817e9c6d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs @@ -1,9 +1,10 @@ +using Newtonsoft.Json; using System; using static Tensorflow.Binding; -namespace 
Tensorflow.Keras.ArgsDefinition +namespace Tensorflow.Keras.ArgsDefinition.Core { - public class EinsumDenseArgs : LayerArgs + public class EinsumDenseArgs : AutoSerializeLayerArgs { /// /// An equation describing the einsum to perform. This equation must @@ -11,6 +12,7 @@ public class EinsumDenseArgs : LayerArgs /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis /// expression sequence. /// + [JsonProperty("equation")] public string Equation { get; set; } /// @@ -19,6 +21,7 @@ public class EinsumDenseArgs : LayerArgs /// None for any dimension that is unknown or can be inferred from the input /// shape. /// + [JsonProperty("output_shape")] public Shape OutputShape { get; set; } /// @@ -26,41 +29,70 @@ public class EinsumDenseArgs : LayerArgs /// Each character in the `bias_axes` string should correspond to a character /// in the output portion of the `equation` string. /// + [JsonProperty("bias_axes")] public string BiasAxes { get; set; } = null; /// /// Activation function to use. /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. 
/// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. /// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs index b1f4fddd3..c462961b3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs @@ -1,11 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class EmbeddingArgs : LayerArgs + public class EmbeddingArgs : AutoSerializeLayerArgs { + [JsonProperty("input_dim")] public int InputDim { get; set; } + [JsonProperty("output_dim")] public int OutputDim { get; set; } + [JsonProperty("mask_zero")] public bool MaskZero { get; set; } + [JsonProperty("input_length")] public int InputLength { get; set; } = -1; + [JsonProperty("embeddings_initializer")] public IInitializer EmbeddingsInitializer { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: `embeddings_regularizer`, `embeddings_constraint`. 
} } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs deleted file mode 100644 index 16705063e..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class Cropping2DArgs : LayerArgs { - /// - /// channel last: (b, h, w, c) - /// channels_first: (b, c, h, w) - /// - public enum DataFormat { channels_first = 0, channels_last = 1 } - /// - /// Accept: int[1][2], int[1][1], int[2][2] - /// - public NDArray cropping { get; set; } - public DataFormat data_format { get; set; } = DataFormat.channels_last; - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs deleted file mode 100644 index 9da2adc7f..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class Cropping3DArgs : LayerArgs { - /// - /// channel last: (b, h, w, c) - /// channels_first: (b, c, h, w) - /// - public enum DataFormat { channels_first = 0, channels_last = 1 } - /// - /// Accept: int[1][3], int[1][1], int[3][2] - /// - public NDArray cropping { get; set; } - public DataFormat data_format { get; set; } = DataFormat.channels_last; - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs deleted file mode 100644 index 9d23acd43..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class CroppingArgs : LayerArgs { - /// - /// Accept length 1 or 2 - /// - public 
NDArray cropping { get; set; } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs deleted file mode 100644 index fb0868dc5..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Tensorflow.Keras.ArgsDefinition.Lstm -{ - public class LSTMCellArgs : LayerArgs - { - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs index 3e6791e3b..0140b3dd0 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs @@ -4,6 +4,7 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: complete the implementation public class MergeArgs : LayerArgs { public Tensors Inputs { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs index 954ede574..6ee91e80b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs @@ -1,21 +1,37 @@ -using static Tensorflow.Binding; +using Newtonsoft.Json; +using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class BatchNormalizationArgs : LayerArgs + public class BatchNormalizationArgs : AutoSerializeLayerArgs { + [JsonProperty("axis")] public Shape Axis { get; set; } = -1; + [JsonProperty("momentum")] public float Momentum { get; set; } = 0.99f; + [JsonProperty("epsilon")] public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] public bool Center { get; set; } = true; + [JsonProperty("scale")] public bool Scale { get; set; } = true; + 
[JsonProperty("beta_initializer")] public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("moving_mean_initializer")] public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("moving_variance_initializer")] public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] public IRegularizer GammaRegularizer { get; set; } + // TODO: `beta_constraint` and `gamma_constraint`. + [JsonProperty("renorm")] public bool Renorm { get; set; } + // TODO: `renorm_clipping` and `virtual_batch_size`. + [JsonProperty("renorm_momentum")] public float RenormMomentum { get; set; } = 0.99f; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs index 13fd98b41..1ac661b37 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs @@ -1,16 +1,27 @@ -using static Tensorflow.Binding; +using Newtonsoft.Json; +using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class LayerNormalizationArgs : LayerArgs + public class LayerNormalizationArgs : AutoSerializeLayerArgs { + [JsonProperty("axis")] public Axis Axis { get; set; } = -1; + [JsonProperty("epsilon")] public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] public bool Center { get; set; } = true; + [JsonProperty("scale")] public bool Scale { get; set; } = true; + [JsonProperty("beta_initializer")] public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] 
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] public IRegularizer GammaRegularizer { get; set; } + + // TODO: `beta_constraint` and `gamma_constraint`. } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs index 9742203d6..c5fdca675 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class Pooling1DArgs : LayerArgs + public class Pooling1DArgs : AutoSerializeLayerArgs { /// /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. @@ -10,11 +12,13 @@ public class Pooling1DArgs : LayerArgs /// /// specifying the size of the pooling window. /// + [JsonProperty("pool_size")] public int PoolSize { get; set; } /// /// specifying the strides of the pooling operation. /// + [JsonProperty("strides")] public int Strides { get { return _strides.HasValue ? _strides.Value : PoolSize; } set { _strides = value; } @@ -24,11 +28,13 @@ public int Strides { /// /// The padding method, either 'valid' or 'same'. /// + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; /// /// one of `channels_last` (default) or `channels_first`. 
/// + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs index 1260af4c6..91a372ef3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class Pooling2DArgs : LayerArgs + public class Pooling2DArgs : AutoSerializeLayerArgs { /// /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. @@ -10,21 +12,25 @@ public class Pooling2DArgs : LayerArgs /// /// specifying the size of the pooling window. /// + [JsonProperty("pool_size")] public Shape PoolSize { get; set; } /// /// specifying the strides of the pooling operation. /// + [JsonProperty("strides")] public Shape Strides { get; set; } /// /// The padding method, either 'valid' or 'same'. /// + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; /// /// one of `channels_last` (default) or `channels_first`. 
/// + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs index 28ccf9f74..97cb364d9 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs @@ -4,7 +4,7 @@ namespace Tensorflow.Keras.ArgsDefinition { - public class PreprocessingLayerArgs : LayerArgs + public class PreprocessingLayerArgs : AutoSerializeLayerArgs { } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs new file mode 100644 index 000000000..154bd8c89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs @@ -0,0 +1,12 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RescalingArgs : AutoSerializeLayerArgs + { + [JsonProperty("scale")] + public float Scale { get; set; } + [JsonProperty("offset")] + public float Offset { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs index cf11595e2..39fa52211 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs @@ -1,5 +1,6 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: no corresponding class found in keras python, maybe obsolete?
public class ResizingArgs : PreprocessingLayerArgs { public int Height { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs index ddeadc001..1a7149f5a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; @@ -6,11 +7,19 @@ namespace Tensorflow.Keras.ArgsDefinition { public class TextVectorizationArgs : PreprocessingLayerArgs { + [JsonProperty("standardize")] public Func Standardize { get; set; } + [JsonProperty("split")] public string Split { get; set; } = "standardize"; + [JsonProperty("max_tokens")] public int MaxTokens { get; set; } = -1; + [JsonProperty("output_mode")] public string OutputMode { get; set; } = "int"; + [JsonProperty("output_sequence_length")] public int OutputSequenceLength { get; set; } = -1; + [JsonProperty("vocabulary")] public string[] Vocabulary { get; set; } + + // TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding` } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs index c41c6fe85..1c85d4936 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs @@ -1,21 +1,26 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class DropoutArgs : LayerArgs + public class DropoutArgs : AutoSerializeLayerArgs { /// /// Float between 0 and 1. Fraction of the input units to drop. 
/// + [JsonProperty("rate")] public float Rate { get; set; } /// /// 1D integer tensor representing the shape of the /// binary dropout mask that will be multiplied with the input. /// + [JsonProperty("noise_shape")] public Shape NoiseShape { get; set; } /// /// random seed. /// + [JsonProperty("seed")] public int? Seed { get; set; } public bool SupportsMasking { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs deleted file mode 100644 index ec9b53150..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Tensorflow.Keras.ArgsDefinition -{ - public class RescalingArgs : LayerArgs - { - public float Scale { get; set; } - public float Offset { get; set; } - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs new file mode 100644 index 000000000..8c2626390 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping2DArgs : LayerArgs + { + /// + /// channel last: (b, h, w, c) + /// channels_first: (b, c, h, w) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][2], int[1][1], int[2][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs new file mode 100644 index 000000000..2d98e55db --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace 
Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping3DArgs : LayerArgs + { + /// + /// channels_last: (b, d1, d2, d3, c) + /// channels_first: (b, c, d1, d2, d3) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][3], int[1][1], int[3][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs new file mode 100644 index 000000000..21b85966b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs @@ -0,0 +1,12 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping1DArgs : LayerArgs + { + /// + /// Accept length 1 or 2 + /// + public NDArray cropping { get; set; } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs index 2686f6cd7..92be10ab1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs @@ -1,5 +1,9 @@ -namespace Tensorflow.Keras.ArgsDefinition { - public class PermuteArgs : LayerArgs { - public int[] dims { get; set; } - } +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { + public class PermuteArgs : AutoSerializeLayerArgs + { + [JsonProperty("dims")] + public int[] dims { get; set; } + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs index 77bca8ad0..4d1123c8a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs @@ -1,7 +1,10 @@
-namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class ReshapeArgs : LayerArgs + public class ReshapeArgs : AutoSerializeLayerArgs { + [JsonProperty("target_shape")] public Shape TargetShape { get; set; } public object[] TargetShapeObjects { get; set; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs index 7fdda32d3..b35e0e4b6 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs @@ -1,12 +1,17 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class UpSampling2DArgs : LayerArgs + public class UpSampling2DArgs : AutoSerializeLayerArgs { + [JsonProperty("size")] public Shape Size { get; set; } + [JsonProperty("data_format")] public string DataFormat { get; set; } /// /// 'nearest', 'bilinear' /// + [JsonProperty("interpolation")] public string Interpolation { get; set; } = "nearest"; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs index ed6e7cc9c..4831e435b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs @@ -2,6 +2,7 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: complete the implementation public class ZeroPadding2DArgs : LayerArgs { public NDArray Padding { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs similarity index 67% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs rename to 
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs index b08d21d88..764641474 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs @@ -1,9 +1,8 @@ -using Tensorflow.Keras.ArgsDefinition.Rnn; - -namespace Tensorflow.Keras.ArgsDefinition.Lstm +namespace Tensorflow.Keras.ArgsDefinition.Rnn { public class LSTMArgs : RNNArgs { + // TODO: maybe change the `RNNArgs` and implement this class. public bool UnitForgetBias { get; set; } public float Dropout { get; set; } public float RecurrentDropout { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs new file mode 100644 index 000000000..594c99bb0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + // TODO: complete the implementation + public class LSTMCellArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index da5279257..2585592c1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,21 +1,30 @@ -using System.Collections.Generic; +using Newtonsoft.Json; +using System.Collections.Generic; namespace Tensorflow.Keras.ArgsDefinition.Rnn { - public class RNNArgs : LayerArgs + public class RNNArgs : AutoSerializeLayerArgs { public interface IRnnArgCell : ILayer { object state_size { get; } } - + [JsonProperty("cell")] + // TODO: the cell should be serialized with `serialize_keras_object`. 
public IRnnArgCell Cell { get; set; } = null; + [JsonProperty("return_sequences")] public bool ReturnSequences { get; set; } = false; + [JsonProperty("return_state")] public bool ReturnState { get; set; } = false; + [JsonProperty("go_backwards")] public bool GoBackwards { get; set; } = false; + [JsonProperty("stateful")] public bool Stateful { get; set; } = false; + [JsonProperty("unroll")] public bool Unroll { get; set; } = false; + [JsonProperty("time_major")] public bool TimeMajor { get; set; } = false; + // TODO: Add `num_constants` and `zero_output_for_mask`. public Dictionary Kwargs { get; set; } = null; public int Units { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs index 602e7a880..3578652ee 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs @@ -1,5 +1,5 @@ using System; -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.NumPy; namespace Tensorflow.Keras.Layers diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs new file mode 100644 index 000000000..444c783e0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Activations.cs @@ -0,0 +1,82 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public class Activations + { + private static Dictionary<string, Activation> _nameActivationMap; + private static Dictionary<Activation, string> _activationNameMap; + + private static Activation _linear = (features, name) => features; + private static Activation _relu = (features, name) + => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); + private static Activation _sigmoid = (features, name) + => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features)); + private static Activation
_softmax = (features, name) + => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features)); + private static Activation _tanh = (features, name) + => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features)); + + /// + /// Register the name-activation mapping in this static class. + /// + /// + /// + private static void RegisterActivation(string name, Activation activation) + { + _nameActivationMap[name] = activation; + _activationNameMap[activation] = name; + } + + static Activations() + { + _nameActivationMap = new Dictionary<string, Activation>(); + _activationNameMap= new Dictionary<Activation, string>(); + + RegisterActivation("relu", _relu); + RegisterActivation("linear", _linear); + RegisterActivation("sigmoid", _sigmoid); + RegisterActivation("softmax", _softmax); + RegisterActivation("tanh", _tanh); + } + + public Activation Linear => _linear; + + public Activation Relu => _relu; + + public Activation Sigmoid => _sigmoid; + + public Activation Softmax => _softmax; + + public Activation Tanh => _tanh; + + + public static Activation GetActivationByName(string name) + { + if (!_nameActivationMap.TryGetValue(name, out var res)) + { + throw new Exception($"Activation {name} not found"); + } + else + { + return res; + } + } + + public static string GetNameByActivation(Activation activation) + { + if(!_activationNameMap.TryGetValue(activation, out var name)) + { + throw new Exception($"Activation {activation} not found"); + } + else + { + return name; + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs b/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs deleted file mode 100644 index acd4de6e7..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace Tensorflow.Keras -{ - public partial class Activations - { - /// - /// Linear activation function (pass-through).
- /// - public Activation Linear = (features, name) => features; - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs deleted file mode 100644 index dfebfb297..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs +++ /dev/null @@ -1,10 +0,0 @@ -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Relu = (features, name) - => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs deleted file mode 100644 index ad900bdef..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Sigmoid = (features, name) - => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs b/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs deleted file mode 100644 index 02d86acea..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Softmax = (features, name) - => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs deleted file mode 100644 index 33dc5ba62..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial 
class Activations - { - public Activation Tanh = (features, name) - => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs index 1b82e0a96..701724d5b 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -1,4 +1,5 @@ using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Core; using Tensorflow.Keras.Engine; using Tensorflow.NumPy; using static Tensorflow.Binding; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index 0f387570b..af71ddf9f 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -4,8 +4,8 @@ using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; -using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.ArgsDefinition.Core; namespace Tensorflow.Keras.Layers { diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs deleted file mode 100644 index 1f33ee3af..000000000 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ /dev/null @@ -1,114 +0,0 @@ -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; - -namespace Tensorflow.Keras.Layers { - /// - /// Crop the input along axis 1 and 2. 
- /// For example: - /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) - /// - public class Cropping2D : Layer { - Cropping2DArgs args; - public Cropping2D ( Cropping2DArgs args ) : base(args) { - this.args = args; - } - public override void build(Shape input_shape) { - built = true; - _buildInputShape = input_shape; - } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - if ( output.rank != 4 ) { - // throw an ValueError exception - throw new ValueError("Expected dim=4, found dim=" + output.rank); - } - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop, ( int ) output.shape[1] - crop), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop)]; - } - } - // a tuple of 2 integers - else if ( args.cropping.shape == new Shape(2) ) { - int crop_1 = args.cropping[0]; - int crop_2 = args.cropping[1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop_1, ( int ) output.shape[1] - crop_1), - new Slice(crop_2, ( int ) output.shape[2] - crop_2), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop_1, ( int ) output.shape[2] - crop_1), - new Slice(crop_2, ( int ) output.shape[3] - crop_2)]; - } - } - else if ( args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2 ) { - int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(x_start, ( int ) output.shape[1] - x_end), - new 
Slice(y_start, ( int ) output.shape[2] - y_end), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(x_start, ( int ) output.shape[2] - x_end), - new Slice(y_start, ( int ) output.shape[3] - y_end) - ]; - } - } - return output; - } - - public override Shape ComputeOutputShape ( Shape input_shape ) { - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2); - } - } - // a tuple of 2 integers - else if ( args.cropping.shape == new Shape(2) ) { - int crop_1 = args.cropping[0], crop_2 = args.cropping[1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1 * 2, ( int ) input_shape[2] - crop_2 * 2, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_1 * 2, ( int ) input_shape[3] - crop_2 * 2); - } - } - else if ( args.cropping.shape == new Shape(2, 2) ) { - int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1]; - int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1_start - crop_1_end, - ( int ) input_shape[2] - crop_2_start - crop_2_end, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], - ( int ) input_shape[2] - crop_1_start - crop_1_end, ( int ) input_shape[3] - crop_2_start - crop_2_end); - } - } - else { - throw new ValueError(); - } - } - } -} diff 
--git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs deleted file mode 100644 index 838a50434..000000000 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ /dev/null @@ -1,124 +0,0 @@ -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; - -namespace Tensorflow.Keras.Layers { - /// - /// Similar to copping 2D - /// - public class Cropping3D : Layer { - Cropping3DArgs args; - public Cropping3D ( Cropping3DArgs args ) : base(args) { - this.args = args; - } - - public override void build(Shape input_shape) { - built = true; - _buildInputShape = input_shape; - } - - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - if ( output.rank != 5 ) { - // throw an ValueError exception - throw new ValueError("Expected dim=5, found dim=" + output.rank); - } - - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop, ( int ) output.shape[1] - crop), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop), - new Slice(crop, ( int ) output.shape[4] - crop)]; - } - - } - // int[1][3] equivalent to a tuple of 3 integers - else if ( args.cropping.shape == new Shape(3) ) { - var crop_1 = args.cropping[0]; - var crop_2 = args.cropping[1]; - var crop_3 = args.cropping[2]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop_1, ( int ) output.shape[1] - crop_1), - new Slice(crop_2, ( int ) output.shape[2] - crop_2), - new Slice(crop_3, ( int ) output.shape[3] - crop_3), - new Slice()]; - } - else { - 
output = output[new Slice(), - new Slice(), - new Slice(crop_1, ( int ) output.shape[2] - crop_1), - new Slice(crop_2, ( int ) output.shape[3] - crop_2), - new Slice(crop_3, ( int ) output.shape[4] - crop_3)]; - } - } - else if ( args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2 ) { - int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; - int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(x, ( int ) output.shape[1] - x_end), - new Slice(y, ( int ) output.shape[2] - y_end), - new Slice(z, ( int ) output.shape[3] - z_end), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(x, ( int ) output.shape[2] - x_end), - new Slice(y, ( int ) output.shape[3] - y_end), - new Slice(z, ( int ) output.shape[4] - z_end) - ]; - } - } - return output; - } - public override Shape ComputeOutputShape ( Shape input_shape ) { - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4] - crop * 2); - } - } - // int[1][3] equivalent to a tuple of 3 integers - else if ( args.cropping.shape == new Shape(3) ) { - var crop_start_1 = args.cropping[0]; - var crop_start_2 = args.cropping[1]; - var crop_start_3 = args.cropping[2]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_start_1 * 2, ( int ) input_shape[2] - crop_start_2 * 2, ( int ) 
input_shape[3] - crop_start_3 * 2, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_start_1 * 2, ( int ) input_shape[3] - crop_start_2 * 2, ( int ) input_shape[4] - crop_start_3 * 2); - } - } - else if ( args.cropping.shape == new Shape(3, 2) ) { - int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; - int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - x - x_end, ( int ) input_shape[2] - y - y_end, ( int ) input_shape[3] - z - z_end, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - x - x_end, ( int ) input_shape[3] - y - y_end, ( int ) input_shape[4] - z - z_end); - } - } - else { - throw new ValueError(); - } - } - } -} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs index 339ddb85b..3e3442f25 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs @@ -2,16 +2,18 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers.Reshaping; +using Tensorflow.Keras.ArgsDefinition.Reshaping; -namespace Tensorflow.Keras.Layers { - public partial class LayersApi { +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi { /// /// Cropping layer for 1D input /// /// cropping size public ILayer Cropping1D ( NDArray cropping ) - => new Cropping1D(new CroppingArgs { + => new Cropping1D(new Cropping1DArgs { cropping = cropping }); diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 769beea0a..76634918d 100644 --- 
a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -1,9 +1,8 @@ using System; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Core; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Lstm; using Tensorflow.Keras.Layers.Rnn; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -108,7 +107,7 @@ public ILayer Conv1D(int filters, DilationRate = dilation_rate, Groups = groups, UseBias = use_bias, - Activation = GetActivationByName(activation), + Activation = Activations.GetActivationByName(activation), KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer) }); @@ -163,7 +162,7 @@ public ILayer Conv2D(int filters, BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, BiasRegularizer = bias_regularizer, ActivityRegularizer = activity_regularizer, - Activation = activation ?? keras.activations.Linear + Activation = activation ?? 
keras.activations.Linear, }); /// @@ -210,7 +209,8 @@ public ILayer Conv2D(int filters, UseBias = use_bias, KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer), - Activation = GetActivationByName(activation) + Activation = Activations.GetActivationByName(activation), + ActivationName = activation }); /// @@ -255,7 +255,7 @@ public ILayer Conv2DTranspose(int filters, UseBias = use_bias, KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer), - Activation = GetActivationByName(activation) + Activation = Activations.GetActivationByName(activation) }); /// @@ -300,7 +300,7 @@ public ILayer Dense(int units) => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName("linear"), + Activation = Activations.GetActivationByName("linear"), ActivationName = "linear" }); @@ -321,7 +321,7 @@ public ILayer Dense(int units, => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName(activation), + Activation = Activations.GetActivationByName(activation), ActivationName = activation, InputShape = input_shape }); @@ -666,7 +666,7 @@ public ILayer SimpleRNN(int units, => new SimpleRNN(new SimpleRNNArgs { Units = units, - Activation = GetActivationByName(activation), + Activation = Activations.GetActivationByName(activation), KernelInitializer = GetInitializerByName(kernel_initializer), RecurrentInitializer = GetInitializerByName(recurrent_initializer), BiasInitializer = GetInitializerByName(bias_initializer), @@ -814,24 +814,7 @@ public ILayer GlobalMaxPooling1D(string data_format = "channels_last") public ILayer GlobalMaxPooling2D(string data_format = "channels_last") => new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); - - /// - /// Get an activation function layer from its name. - /// - /// The name of the activation function. One of linear, relu, sigmoid, and tanh. 
- /// - - Activation GetActivationByName(string name) - => name switch - { - "linear" => keras.activations.Linear, - "relu" => keras.activations.Relu, - "sigmoid" => keras.activations.Sigmoid, - "tanh" => keras.activations.Tanh, - "softmax" => keras.activations.Softmax, - _ => throw new Exception($"Activation {name} not found") - }; - + Activation GetActivationByName(string name) => Activations.GetActivationByName(name); /// /// Get an weights initializer from its name. /// diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index c0b16c812..3b8e1ee8d 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -58,7 +58,7 @@ public override void build(Shape input_shape) var ndims = input_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) - axis[idx] = ndims + x; + args.Axis.dims[idx] = axis[idx] = ndims + x; fused = ndims == 4; diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs rename to src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs similarity index 79% rename from src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs rename to src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 44b338c25..10c15b698 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -1,11 +1,12 @@ -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; -namespace Tensorflow.Keras.Layers { +namespace Tensorflow.Keras.Layers.Reshaping +{ public class 
Cropping1D : Layer { - CroppingArgs args; - public Cropping1D(CroppingArgs args) : base(args) + Cropping1DArgs args; + public Cropping1D(Cropping1DArgs args) : base(args) { this.args = args; } @@ -41,7 +42,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train else { int crop_start = args.cropping[0], crop_end = args.cropping[1]; - output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()]; + output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_end), new Slice()]; } return output; } @@ -51,12 +52,12 @@ public override Shape ComputeOutputShape(Shape input_shape) if (args.cropping.shape[0] == 1) { int crop = args.cropping[0]; - return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2])); + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop * 2), (int)input_shape[2]); } else { int crop_start = args.cropping[0], crop_end = args.cropping[1]; - return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2])); + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop_start - crop_end), (int)input_shape[2]); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs new file mode 100644 index 000000000..a8d7043ed --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -0,0 +1,140 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Crop the input along axis 1 and 2. 
+ /// For example: + /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) + /// + public class Cropping2D : Layer + { + Cropping2DArgs args; + public Cropping2D(Cropping2DArgs args) : base(args) + { + this.args = args; + } + public override void build(Shape input_shape) + { + built = true; + _buildInputShape = input_shape; + } + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor output = inputs; + if (output.rank != 4) + { + // throw an ValueError exception + throw new ValueError("Expected dim=4, found dim=" + output.rank); + } + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop, (int)output.shape[1] - crop), + new Slice(crop, (int)output.shape[2] - crop), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop)]; + } + } + // a tuple of 2 integers + else if (args.cropping.shape == new Shape(2)) + { + int crop_1 = args.cropping[0]; + int crop_2 = args.cropping[1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop_1, (int)output.shape[1] - crop_1), + new Slice(crop_2, (int)output.shape[2] - crop_2), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop_1, (int)output.shape[2] - crop_1), + new Slice(crop_2, (int)output.shape[3] - crop_2)]; + } + } + else if (args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2) + { + int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(x_start, (int)output.shape[1] - x_end), + new Slice(y_start, 
(int)output.shape[2] - y_end), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(x_start, (int)output.shape[2] - x_end), + new Slice(y_start, (int)output.shape[3] - y_end) + ]; + } + } + return output; + } + + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2); + } + } + // a tuple of 2 integers + else if (args.cropping.shape == new Shape(2)) + { + int crop_1 = args.cropping[0], crop_2 = args.cropping[1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1 * 2, (int)input_shape[2] - crop_2 * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_1 * 2, (int)input_shape[3] - crop_2 * 2); + } + } + else if (args.cropping.shape == new Shape(2, 2)) + { + int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1]; + int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1_start - crop_1_end, + (int)input_shape[2] - crop_2_start - crop_2_end, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], + (int)input_shape[2] - crop_1_start - crop_1_end, (int)input_shape[3] - crop_2_start - crop_2_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs 
b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs new file mode 100644 index 000000000..796c2dd33 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -0,0 +1,150 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Similar to copping 2D + /// + public class Cropping3D : Layer + { + Cropping3DArgs args; + public Cropping3D(Cropping3DArgs args) : base(args) + { + this.args = args; + } + + public override void build(Shape input_shape) + { + built = true; + _buildInputShape = input_shape; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor output = inputs; + if (output.rank != 5) + { + // throw an ValueError exception + throw new ValueError("Expected dim=5, found dim=" + output.rank); + } + + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop, (int)output.shape[1] - crop), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop, (int)output.shape[2] - crop), + new Slice(crop, (int)output.shape[3] - crop), + new Slice(crop, (int)output.shape[4] - crop)]; + } + + } + // int[1][3] equivalent to a tuple of 3 integers + else if (args.cropping.shape == new Shape(3)) + { + var crop_1 = args.cropping[0]; + var crop_2 = args.cropping[1]; + var crop_3 = args.cropping[2]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(crop_1, (int)output.shape[1] - crop_1), + new Slice(crop_2, (int)output.shape[2] - crop_2), + new Slice(crop_3, (int)output.shape[3] - crop_3), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(crop_1, 
(int)output.shape[2] - crop_1), + new Slice(crop_2, (int)output.shape[3] - crop_2), + new Slice(crop_3, (int)output.shape[4] - crop_3)]; + } + } + else if (args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2) + { + int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; + int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + output = output[new Slice(), + new Slice(x, (int)output.shape[1] - x_end), + new Slice(y, (int)output.shape[2] - y_end), + new Slice(z, (int)output.shape[3] - z_end), + new Slice()]; + } + else + { + output = output[new Slice(), + new Slice(), + new Slice(x, (int)output.shape[2] - x_end), + new Slice(y, (int)output.shape[3] - y_end), + new Slice(z, (int)output.shape[4] - z_end) + ]; + } + } + return output; + } + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape == new Shape(1)) + { + int crop = args.cropping[0]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4] - crop * 2); + } + } + // int[1][3] equivalent to a tuple of 3 integers + else if (args.cropping.shape == new Shape(3)) + { + var crop_start_1 = args.cropping[0]; + var crop_start_2 = args.cropping[1]; + var crop_start_3 = args.cropping[2]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_start_1 * 2, (int)input_shape[2] - crop_start_2 * 2, (int)input_shape[3] - crop_start_3 * 2, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], 
(int)input_shape[2] - crop_start_1 * 2, (int)input_shape[3] - crop_start_2 * 2, (int)input_shape[4] - crop_start_3 * 2); + } + } + else if (args.cropping.shape == new Shape(3, 2)) + { + int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; + int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; + int z = args.cropping[2, 0], z_end = args.cropping[2, 1]; + if (args.data_format == Cropping3DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - x - x_end, (int)input_shape[2] - y - y_end, (int)input_shape[3] - z - z_end, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - x - x_end, (int)input_shape[3] - y - y_end, (int)input_shape[4] - z - z_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs similarity index 87% rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index b7d973847..59555e62b 100644 --- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -1,9 +1,8 @@ using System.Linq; -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; -namespace Tensorflow.Keras.Layers.Lstm +namespace Tensorflow.Keras.Layers.Rnn { /// /// Long Short-Term Memory layer - Hochreiter 1997. 
diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs similarity index 72% rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs index 3cd35a091..a622c91a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs @@ -1,7 +1,7 @@ -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -namespace Tensorflow.Keras.Layers.Lstm +namespace Tensorflow.Keras.Layers.Rnn { public class LSTMCell : Layer { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 877c35994..6b755ecee 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -3,7 +3,6 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Lstm; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs deleted file mode 100644 index 288a92b32..000000000 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs +++ /dev/null @@ -1,82 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow.NumPy; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Tensorflow; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; -using Tensorflow.Keras; -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers; -using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Metrics; -using Tensorflow.Keras.Optimizers; -using 
Tensorflow.Operations; - -namespace TensorFlowNET.Keras.UnitTest.SaveModel; - -[TestClass] -public class SequentialModelTest -{ - [TestMethod] - public void SimpleModelFromAutoCompile() - { - var inputs = new KerasInterface().Input((28, 28, 1)); - var x = new Flatten(new FlattenArgs()).Apply(inputs); - x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x); - x = new LayersApi().Dense(units: 10).Apply(x); - var outputs = new LayersApi().Softmax(axis: 1).Apply(x); - var model = new KerasInterface().Model(inputs, outputs); - - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); - - var data_loader = new MnistModelLoader(); - var num_epochs = 1; - var batch_size = 50; - - var dataset = data_loader.LoadAsync(new ModelLoadSetting - { - TrainDir = "mnist", - OneHot = false, - ValidationSize = 10000, - }).Result; - - model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - - model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.compile", save_format: "tf"); - } - - [TestMethod] - public void SimpleModelFromSequential() - { - Model model = KerasApi.keras.Sequential(new List() - { - keras.layers.InputLayer((28, 28, 1)), - keras.layers.Flatten(), - keras.layers.Dense(100, "relu"), - keras.layers.Dense(10), - keras.layers.Softmax(1) - }); - - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); - - var data_loader = new MnistModelLoader(); - var num_epochs = 1; - var batch_size = 50; - - var dataset = data_loader.LoadAsync(new ModelLoadSetting - { - TrainDir = "mnist", - OneHot = false, - ValidationSize = 10000, - }).Result; - - model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - - model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.sequential", save_format: "tf"); - } -} \ No newline at end of file From 88fe402b7fa39d180cfa6385b4493af82c2e7be2 Mon Sep 17 00:00:00 2001 From: AsakusaRinne 
Date: Sat, 4 Feb 2023 21:04:00 +0800 Subject: [PATCH 14/15] Add alexnet pb save test. --- .../SaveModel/SequentialModelTest.cs | 202 ++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs new file mode 100644 index 000000000..c31453448 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -0,0 +1,202 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.Keras; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Operations; +using System.Diagnostics; + +namespace TensorFlowNET.Keras.UnitTest.SaveModel; + +[TestClass] +public class SequentialModelTest +{ + [TestMethod] + public void SimpleModelFromAutoCompile() + { + var inputs = new KerasInterface().Input((28, 28, 1)); + var x = new Flatten(new FlattenArgs()).Apply(inputs); + x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x); + x = new LayersApi().Dense(units: 10).Apply(x); + var outputs = new LayersApi().Softmax(axis: 1).Apply(x); + var model = new KerasInterface().Model(inputs, outputs); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + 
ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_compile", save_format: "tf"); + } + + [TestMethod] + public void SimpleModelFromSequential() + { + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((28, 28, 1)), + keras.layers.Flatten(), + keras.layers.Dense(100, "relu"), + keras.layers.Dense(10), + keras.layers.Softmax(1) + }); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_sequential", save_format: "tf"); + } + + [TestMethod] + public void AlexModelFromSequential() + { + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((227, 227, 3)), + keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), + + keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), (2, 2)), + + keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + + keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + + keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), (2, 2)), + + keras.layers.Flatten(), + keras.layers.Dense(4096, activation: "relu"), + keras.layers.Dropout(0.5f), + + keras.layers.Dense(4096, activation: "relu"), + 
keras.layers.Dropout(0.5f), + + keras.layers.Dense(1000, activation: "linear"), + keras.layers.Softmax(1) + }); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 16; + + var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); + + model.save("./pb_elex_sequential", save_format: "tf"); + + // The saved model can be test with the following python code: + #region alexnet_python_code + //import pathlib + //import tensorflow as tf + + //def func(a): + // return -a + + //if __name__ == '__main__': + // model = tf.keras.models.load_model("./pb_elex_sequential") + // model.summary() + + // num_classes = 5 + // batch_size = 128 + // img_height = 227 + // img_width = 227 + // epochs = 100 + + // dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" + // data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True) + // data_dir = pathlib.Path(data_dir) + + // train_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "training", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + // val_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "validation", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + + // model.compile(optimizer = 'adam', + // loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), + // metrics =['accuracy']) + + // model.build((None, img_height, img_width, 3)) + + // history = model.fit( + // train_ds, + // validation_data = val_ds, + // epochs = epochs + // ) + #endregion + } + + public class RandomDataSet : DataSetBase + { + private Shape _shape; + + public 
RandomDataSet(Shape shape, int count) + { + _shape = shape; + Debug.Assert(_shape.ndim == 3); + long[] dims = new long[4]; + dims[0] = count; + for (int i = 1; i < 4; i++) + { + dims[i] = _shape[i - 1]; + } + Shape s = new Shape(dims); + Data = np.random.normal(0, 2, s); + Labels = np.random.uniform(0, 1, (count, 1)); + } + } +} \ No newline at end of file From 3a6a59e18cbfd3a948f3e4c16259e3ee07c73878 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sun, 5 Feb 2023 01:08:37 +0800 Subject: [PATCH 15/15] Check and refine the code. --- .../Checkpoint/CheckPointUtils.cs | 6 +- .../Checkpoint/CheckpointOptions.cs | 2 +- .../Checkpoint/ObjectGraphView.cs | 4 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 4 +- .../Checkpoint/SaveUtilV1.cs | 19 ++- .../Checkpoint/TrackableView.cs | 2 +- .../Checkpoint/functional_saver.cs | 161 ++++++++---------- src/TensorFlowNET.Core/DisposableObject.cs | 2 +- .../Exceptions/AssertionError.cs | 2 +- .../Training/AutoTrackable.cs | 6 +- .../Training/Saving/SaveableObject.cs | 4 +- .../Training/Saving/SavedModel/AssetInfo.cs | 2 +- .../Saving/SavedModel/AugmentedGraphView.cs | 4 +- .../Training/Saving/SavedModel/Constants.cs | 2 +- .../Saving/SavedModel/RevivedTypes.cs | 2 +- .../Training/Saving/SavedModel/SaveType.cs | 2 +- .../Saving/SavedModel/SaveableView.cs | 16 +- .../Saving/SavedModel/TagConstants.cs | 2 +- .../Training/Saving/SavedModel/builder.cs | 2 +- .../Training/Saving/SavedModel/save.cs | 6 +- .../SavedModel/signature_serialization.cs | 2 +- .../Training/Saving/SavedModel/utils.cs | 2 +- .../Saving/saveable_object_util.py.cs | 21 ++- .../Saving/SavedModel/Constants.cs | 2 +- .../Saving/SavedModel/KerasObjectWrapper.cs | 11 -- .../Saving/SavedModel/Save.cs | 66 ++++++- .../Saving/SavedModel/SaveImpl.cs | 66 ------- .../Saving/SavedModel/base_serialization.cs | 3 +- .../Saving/SavedModel/layer_serialization.cs | 2 +- .../Saving/SavedModel/utils.cs | 2 +- .../SaveModel/SequentialModelTest.cs | 8 +- 31 files changed, 194 
insertions(+), 241 deletions(-) delete mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs delete mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index cd37703b6..8ae2dae8f 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -12,9 +12,9 @@ namespace Tensorflow.Checkpoint; public static class CheckPointUtils { private static string _ESCAPE_CHAR = "."; - public static (List, Dictionary>, Dictionary, + public static (IList, IDictionary>, IDictionary, IDictionary>, - Dictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + IDictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); Dictionary object_names = new(); @@ -149,4 +149,4 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs index f14b5ce78..75b392af8 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -2,4 +2,4 @@ public record class CheckpointOptions( string? 
experimental_io_device = null, - bool experimental_enable_async_checkpoint = false); \ No newline at end of file + bool experimental_enable_async_checkpoint = false); diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs index cb01b539a..f435dd88b 100644 --- a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -45,7 +45,7 @@ public IEnumerable? AttachedDependencies get => _attached_dependencies; } - public virtual (List, Dictionary>) breadth_first_traversal() + public virtual (IList, IDictionary>) breadth_first_traversal() { return base._descendants_with_paths(); } @@ -61,4 +61,4 @@ public void frozen_saveable_objects(object? object_map = null, object? to_graph { throw new NotImplementedException(); } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index 84e0ca4e1..c54cc93f6 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -58,7 +58,7 @@ public static (IDictionary, Dictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? object_map) + private static (IList, IDictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? 
object_map) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); Dictionary object_names = new(); @@ -173,7 +173,7 @@ private static IDictionary>> g tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor.GetValueA() is SaveSpec) + if(maybe_tensor.IsTypeOrDeriveFrom()) { throw new NotImplementedException(); //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 4f1d04d2e..3267ae126 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Checkpoint; public static class SaveUtilV1 { - public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + public static (IDictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, IDictionary? object_map = null) { // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, @@ -44,7 +44,7 @@ public static (Dictionary>, object return (checkpoint_factory_map, null); } - public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, + public static (IList, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { @@ -73,7 +73,7 @@ public static (List, IDictionary, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, + public static (IList, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? 
saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); @@ -129,7 +129,8 @@ private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView grap return object_graph_proto; } - private static (List, object?, IDictionary>?) add_attributes_to_object_graph(IList trackable_objects, + private static (IList, object?, IDictionary>?) add_attributes_to_object_graph( + IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -150,7 +151,7 @@ private static (List, object?, IDictionary, object?) generate_saveable_objects( + public static (IList, object?) generate_saveable_objects( IDictionary> checkpoint_factory_map, TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -178,13 +179,13 @@ public static (List, object?) generate_saveable_objects( // TODO: oneflow python has a process with callable `saveable_factory`. 
List saveables = new(); - if (maybe_saveable.DataType == typeof(MySaveableObject)) + if (maybe_saveable.TryGet(out var s)) { - saveables.Add(maybe_saveable.GetValueB()); + saveables.Add(s); } else { - saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue() as Trackable, key)); } foreach (var saveable in saveables) @@ -219,4 +220,4 @@ public record class CheckpointFactoryData Maybe factory, string name, string checkpoint_key -); \ No newline at end of file +); diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index f89dc10d7..dab6d5d97 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -52,7 +52,7 @@ public Trackable Root /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths /// - protected (List, Dictionary>) _descendants_with_paths() + protected (IList, IDictionary>) _descendants_with_paths() { List bfs_sorted = new(); Queue to_visit = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 90bbccf07..09904d684 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -14,112 +14,91 @@ using Tensorflow.Graphs; using System.Xml.Linq; using System.Diagnostics; +using RestoreFunc = System.Func; namespace Tensorflow.Checkpoint { - /// - /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. - /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. 
- /// - public interface IFunctionHolder - { - int ArgCount { get; } - object DynamicInvoke(params object[] args); - } - internal record class FunctionHolder(Func Func): IFunctionHolder - { - public int ArgCount => 0; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - public TR Invoke() - { - return Func.Invoke(); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 1; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 2; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 3; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } public class Maybe { private TA? _valueA = default(TA); private TB? _valueB = default(TB); private Type _type; - private bool _assigned = false; + private bool _assignedTA; public Maybe(TA value) { _valueA = value; _type= typeof(TA); - _assigned = true; + _assignedTA = true; } public Maybe(TB value) { _valueB = value; _type = typeof(TB); - _assigned = true; + _assignedTA = false; } public Type DataType => _type; - public TA GetValueA() + /// + /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. + /// It returns + /// + /// + /// + /// + public bool TryGet(out T? 
res) { - if(!_assigned || DataType != typeof(TA)) + if(_valueA is T && _valueB is not T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + res = (T)(object)_valueA; + return _assignedTA; } - return _valueA; - } - public TB GetValueB() - { - if (!_assigned || DataType != typeof(TB)) + else if(_valueA is not T && _valueB is T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + res = (T)(object)_valueB; + return !_assignedTA; } - return _valueB; + res = default(T); + return false; } - public object GetValue() + + public bool IsTypeOrDeriveFrom() { - if (!_assigned) + if (_valueA is T && _valueB is not T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + return _assignedTA; } - if(DataType == typeof(TA) && _valueA is not null) + else if (_valueA is not T && _valueB is T) { - return _valueA; + return !_assignedTA; } - else if(DataType == typeof(TB) && _valueB is not null) + else if (_valueA is T && _valueB is T) { - return _valueB; + return true; } - else if(DataType == typeof(TA)) + else { - return _valueA; + return false; + } + } + + public T GetValue() + { + if (_valueA is T && _valueB is not T) + { + return (T)(object)_valueA; + } + else if (_valueA is not T && _valueB is T) + { + return (T)(object)_valueB; + } + else if (_valueA is T && _valueB is T) + { + throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); } else { - return _valueB; + throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); } } @@ -170,9 +149,8 @@ public SingleDeviceSaver(IDictionary> tens { var slice_spec = slice.Key; var maybe_tensor = slice.Value; - if(maybe_tensor.DataType == typeof(SaveSpec)) + if(maybe_tensor.TryGet(out var spec)) { - var spec = maybe_tensor.GetValueB(); var tensor_value = spec.tensor; if (tensor_value is not null) { @@ -183,7 +161,7 @@ public SingleDeviceSaver(IDictionary> tens } else { - var tensor = 
maybe_tensor.GetValueA(); + var tensor = maybe_tensor.GetValue(); tensor_names.Add(checkpoint_key); tensors.Add(tensor); slice_specs.Add(slice_spec); @@ -215,16 +193,15 @@ public IDictionary> restore(Tensor file_pref var slice_spec = slice.Key; var maybe_tensor = slice.Value; // TODO: deal with other types. Currently only `SaveSpec` is allowed. - if(maybe_tensor.DataType == typeof(SaveSpec)) + if(maybe_tensor.TryGet(out var spec)) { - var spec = maybe_tensor.GetValueB(); tensor_dtypes.Add(spec.dtype); slice_specs.Add(spec.slice_spec); tensor_names.Add(spec.name); } else { - var tensor = maybe_tensor.GetValueA(); + var tensor = maybe_tensor.GetValue(); tensor_dtypes.Add(tensor.dtype); slice_specs.Add(slice_spec); tensor_names.Add(checkpoint_key); @@ -268,9 +245,9 @@ public IDictionary> restore(Tensor file_pref public class MultiDeviceSaver { private Dictionary _single_device_savers; - private IDictionary _registered_savers; - private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; - private Dictionary> _restore_fn_to_keys; + private IDictionary _registered_savers; + private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; /// /// /// @@ -280,24 +257,28 @@ public class MultiDeviceSaver public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? 
registered_savers = null, bool call_with_mapped_capture = false) { - _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); - _restore_fn_to_keys = new Dictionary>(); + _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); + _restore_fn_to_keys = new Dictionary>(); Dictionary>> tensors_by_device= new(); foreach(var pair in serialized_tensors) { var obj = pair.Key; var tensor_dict = pair.Value; - IFunctionHolder restore_fn; + RestoreFunc restore_fn; if(obj == Trackable.None) { - restore_fn = new FunctionHolder(() => null); + restore_fn = new RestoreFunc(x => null); } else { - restore_fn = new FunctionHolder>>, IDictionary>(x => + restore_fn = new RestoreFunc(x => { - return obj._restore_from_tensors(x); + if(x is IDictionary>>) + { + return obj._restore_from_tensors(x as IDictionary>>); + } + throw new TypeError($"Expected `IDictionary>>` as input, got{x.GetType()}."); }); } @@ -305,14 +286,14 @@ public MultiDeviceSaver(IDictionary spec_to_tensor; - if(item.Value.DataType != typeof(IDictionary)) + if(item.Value.TryGet(out var t)) { spec_to_tensor = new Dictionary(); - spec_to_tensor[""] = item.Value.GetValueA(); + spec_to_tensor[""] = t; } else { - spec_to_tensor = item.Value.GetValueB(); + spec_to_tensor = item.Value.GetValue>(); } foreach(var spec in spec_to_tensor) @@ -342,7 +323,7 @@ public MultiDeviceSaver(IDictionary x.Key, x => new SingleDeviceSaver(x.Value)); - _registered_savers = new Dictionary(); + _registered_savers = new Dictionary(); if(registered_savers is not null && registered_savers.Count > 0) { // TODO: complete the implementation. 
@@ -418,8 +399,8 @@ public IDictionary restore(Tensor file_prefix, CheckpointOpti IDictionary restore_func() { - Dictionary>>> restore_fn_inputs = new(); - Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); Dictionary restore_ops = new(); foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) @@ -449,7 +430,7 @@ IDictionary restore_func() } else { - internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; + internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; } } else diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 7fac3d0f1..c3c677fff 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -158,4 +158,4 @@ public void Dispose() Dispose(false); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs index 84ec24cbf..977fe2340 100644 --- a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -11,4 +11,4 @@ public AssertionError(string message) : base(message) { } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 4d5a664ec..4ba3e4074 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -37,10 +37,10 @@ public override IDictionary _trackable_children(SaveType save var properties = this.GetType().GetProperties(); foreach ( var property in properties ) { - string name = property.Name; - object value = property.GetValue(this, null); - if(value is Function || value is ConcreteFunction) + 
if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) { + string name = property.Name; + object value = property.GetValue(this, null); functions[name] = (Trackable)value; } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 43d36dba3..1309a6174 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -25,9 +25,9 @@ public Tensor op { get { - if(_op.DataType == typeof(Tensor)) + if(_op.TryGet(out var tensor)) { - return _op.GetValueA(); + return tensor; } else { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs index 24c8f2f05..d10257822 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -8,4 +8,4 @@ public record class AssetInfo Dictionary asset_initializers_by_resource, Dictionary asset_filename_map, Dictionary asset_index -); \ No newline at end of file +); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 97162651a..a91933357 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -86,7 +86,7 @@ private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concre return concrete_function; } - public override (List, Dictionary>) breadth_first_traversal() + public override (IList, IDictionary>) breadth_first_traversal() { Trackable get_merged_trackable(Trackable x) { @@ -130,4 +130,4 @@ public Trackable get_child(Trackable obj, string name) { return _children_cache[obj][name]; } -} \ No newline at end of file +} diff --git 
a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs index cb7abadad..726f6cfd4 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -30,4 +30,4 @@ public static class Constants public static readonly string VARIABLES_DIRECTORY = "variables"; public static readonly string VARIABLES_FILENAME = "variables"; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index fa9d6e504..fe0403c30 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -14,4 +14,4 @@ public class RevivedTypes // TODO: complete the implementation. return null; } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs index b973fd417..8dd4f008f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -6,4 +6,4 @@ public enum SaveType { SAVEDMODEL, CHECKPOINT -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6132e0254..1be54287e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -18,13 +18,13 @@ public class SaveableView { private AugmentedGraphView _augmented_graph_view; private SaveOptions _options; - private List _trackable_objects; + private IList _trackable_objects; private List _nodes; - private Dictionary> _node_paths; - private Dictionary _node_ids; + private 
IDictionary> _node_paths; + private IDictionary _node_ids; private IDictionary> _slot_variables; - private Dictionary _object_names; + private IDictionary _object_names; private List _gradient_functions; // to be completed private List _gradient_defs; // to be completed private List _concrete_functions; @@ -45,7 +45,7 @@ public List Nodes { get => _nodes; } - public Dictionary NodeIds + public IDictionary NodeIds { get => _node_ids; } @@ -53,7 +53,7 @@ public List GradientDefs { get => _gradient_defs; } - public Dictionary> NodePaths + public IDictionary> NodePaths { get => _node_paths; } @@ -84,7 +84,7 @@ private void initialize_save_and_restore_functions() private void initialize_nodes_and_concrete_functions() { - _nodes = _trackable_objects.ConvertAll(x => x); // deep copy + _nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy _gradient_functions = new(); _gradient_defs = new(); @@ -296,4 +296,4 @@ public void fill_object_graph_proto(SavedObjectGraph proto) proto.Nodes.Add(object_proto); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs index 9a066eed7..6aa1fbde1 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -7,4 +7,4 @@ public static class TagConstants public static readonly string EVAL = "eval"; public static readonly string GPU = "gpu"; public static readonly string TPU = "tpu"; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs index bcd3ae05a..dbbab91d8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -19,4 +19,4 @@ public static void copy_assets_to_destination_dir(IDictionary throw new 
NotImplementedException(); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index d82d49d8f..94760e3df 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -81,8 +81,8 @@ public static (IList, IDictionary, - Dictionary>) _build_meta_graph(Trackable obj, + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList, + IDictionary>) _build_meta_graph(Trackable obj, ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) { using (SaveContext.save_context(options)) @@ -266,4 +266,4 @@ public static void byte_swap_tensor_content(TensorProto tensor) } } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 0d34907f7..4a0d3b002 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -104,4 +104,4 @@ public override IDictionary _trackable_children(SaveType save return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs index 2deff0275..b0e6411c9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -54,4 +54,4 @@ public static string get_assets_dir(string export_dir) { return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); } -} \ No newline at end of file +} diff --git 
a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 582e2431e..a6e21e3e5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -136,9 +136,8 @@ public static IEnumerable saveable_objects_for_op(Trackable ob { full_name = name + "_" + attr; } - if(factory.DataType == typeof(ResourceVariable)) + if(factory.TryGet(out var variable)) { - var variable = factory.GetValueA(); foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) { yield return op; @@ -146,8 +145,8 @@ public static IEnumerable saveable_objects_for_op(Trackable ob } else { - var variable = factory.GetValueB(); - foreach (var op in saveable_objects_for_op(variable, variable.name)) + var saveable = factory.GetValue(); + foreach (var op in saveable_objects_for_op(saveable, saveable.name)) { yield return op; } @@ -236,14 +235,14 @@ public static IDictionary> string spec_name = name + TrackableUtils.escape_local_name(tensor_name); IDictionary internal_dict; - if(maybe_tensor.DataType == typeof(Tensor)) + if(maybe_tensor.TryGet(out var tensor)) { internal_dict= new Dictionary(); - internal_dict[""] = maybe_tensor.GetValueA(); + internal_dict[""] = tensor; } else { - internal_dict = maybe_tensor.GetValueB(); + internal_dict = maybe_tensor.GetValue>(); } foreach(var item in internal_dict) @@ -287,7 +286,7 @@ public static Dictionary>> sav var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; + tensor_dict.SetDefault(name, new Dictionary()).GetValue>()[slice_spec] = spec.tensor; } else { @@ -318,14 +317,14 @@ public static Func var maybe_tensor = restored_tensors[name]; IDictionary dict; - if(maybe_tensor.DataType == typeof(Tensor)) + if(maybe_tensor.TryGet(out var tensor)) { 
dict = new Dictionary(); - dict[""] = maybe_tensor.GetValueA(); + dict[""] = tensor; } else { - dict = maybe_tensor.GetValueB(); + dict = maybe_tensor.GetValue>(); } saveable_restored_tensors.Add(dict[slice_spec]); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs index ea6853fde..3ea4f067e 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -38,4 +38,4 @@ public static class Constants RNN_LAYER_IDENTIFIER, SEQUENTIAL_IDENTIFIER }; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs deleted file mode 100644 index a5f315bb3..000000000 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace Tensorflow.Keras.Saving.SavedModel; - -public class KerasObjectWrapper -{ - -} - -public class KerasObjectWrapper -{ - public T Item { get; set; } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 9d1c9609a..c7b7e52f4 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -3,19 +3,15 @@ using System.IO; using System.Linq; using Google.Protobuf; -using ICSharpCode.SharpZipLib.Zip; -using Tensorflow.Checkpoint; -using Tensorflow.Contexts; using Tensorflow.Functions; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Utils; using Tensorflow.ModelSaving; using Tensorflow.Train; -using Tensorflow.Exceptions; -using Tensorflow.IO; using Tensorflow.Keras.Optimizers; using ThirdParty.Tensorflow.Python.Keras.Protobuf; using static Tensorflow.Binding; +using Tensorflow.Training; + namespace Tensorflow.Keras.Saving.SavedModel; @@ -108,5 +104,59 @@ public static SavedMetadata 
generate_keras_metadata(IList saved_nodes return metadata; } - -} \ No newline at end of file + public static bool should_skip_serialization(object layer) + { + return false; + } + + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. + + // TODO: change the inherits of `Variable` and revise the implmentation. + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + + Dictionary res = new(); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. + /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. 
+ + // skip the process because of lack of APIs of `Layer`. + + return new Dictionary(); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs deleted file mode 100644 index f7e1bf45c..000000000 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Keras.Engine; -using Tensorflow.Train; -using Tensorflow.Training; - -namespace Tensorflow.Keras.Saving.SavedModel; - -public partial class KerasSavedModelUtils -{ - public static bool should_skip_serialization(object layer) - { - return false; - } - - /// - /// Returns extra trackable objects to attach to the serialized layer. - /// - /// - /// - /// - public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) - { - // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. - - // TODO: change the inherits of `Variable` and revise the implmentation. 
- var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - - Dictionary res = new(); - res["variables"] = variables; - res["trainable_variables"] = trainable_variables; - res["non_trainable_variables"] = non_trainable_variables; - res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); - - return res; - } - - /// - /// Returns dict of wrapped layer call function and losses in tf.functions. - /// - /// - /// - /// - public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) - { - // TODO: deal with type `RevivedLayer` and `Sequential`. - - // skip the process because of lack of APIs of `Layer`. 
- - return new Dictionary(); - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 60c4ee5b8..eb88c8953 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -34,5 +34,4 @@ public IDictionary trackable_children(IDictionary x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } - -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 8675ea65b..03693cb57 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -162,4 +162,4 @@ public override string TrackingMetadata return JsonConvert.SerializeObject(info); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index 3054271ae..51f8d2c91 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -44,4 +44,4 @@ public void Dispose() { KerasSavedModelUtils.ShouldHaveTraces = _old_value; } -} \ No newline at end of file +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs index c31453448..269b9c058 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -73,7 +73,7 @@ public void SimpleModelFromSequential() { TrainDir = "mnist", OneHot = false, - ValidationSize = 10000, + ValidationSize = 50000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); @@ -119,13 
+119,13 @@ public void AlexModelFromSequential() model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); var num_epochs = 1; - var batch_size = 16; + var batch_size = 8; var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); - model.save("./pb_elex_sequential", save_format: "tf"); + model.save("./pb_alex_sequential", save_format: "tf"); // The saved model can be test with the following python code: #region alexnet_python_code @@ -136,7 +136,7 @@ public void AlexModelFromSequential() // return -a //if __name__ == '__main__': - // model = tf.keras.models.load_model("./pb_elex_sequential") + // model = tf.keras.models.load_model("./pb_alex_sequential") // model.summary() // num_classes = 5