/// <summary>
/// Returns <paramref name="bytes_or_text"/> unchanged; the value is already text.
/// </summary>
/// <param name="bytes_or_text">The text to return.</param>
/// <param name="encoding">Accepted for signature parity with the byte[] overload; not used for decoding here.</param>
/// <returns>The same string instance that was passed in.</returns>
internal string as_text(string bytes_or_text, Encoding? encoding = null)
{
    // Nothing to decode for a string input.
    return bytes_or_text;
}

/// <summary>
/// Decodes <paramref name="bytes_or_text"/> into a string.
/// </summary>
/// <param name="bytes_or_text">The raw bytes to decode.</param>
/// <param name="encoding">The encoding used for decoding; UTF-8 when null.</param>
/// <returns>The decoded string.</returns>
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
{
    var effective_encoding = encoding ?? Encoding.UTF8;
    return effective_encoding.GetString(bytes_or_text);
}
/// <summary>
/// Looks up, for every <see cref="Optimizer"/> in <paramref name="trackable_objects"/>, the slot
/// variables it holds for the tracked variables.
/// </summary>
/// <param name="trackable_objects">The tracked objects to scan; enumerated once into a list.</param>
/// <param name="node_ids">Node id of each trackable (not read yet; reserved for the unfinished serialization).</param>
/// <param name="object_names">Object name of each trackable (not read yet; reserved for the unfinished serialization).</param>
/// <returns>
/// A map from optimizer to its slot variable references. Because serialization of a found slot
/// is not implemented, the map is currently always empty when the method returns.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown as soon as an optimizer returns a non-null slot variable.
/// </exception>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
public static
    IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
    serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
        IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
{
    var non_slot_objects = trackable_objects.ToList();
    Dictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        slot_variables = new();
    foreach (var trackable in non_slot_objects)
    {
        if (trackable is not Optimizer optim)
        {
            continue;
        }

        var slot_names = optim.get_slot_names();
        foreach (var slot_name in slot_names)
        {
            // The index is kept because the (not yet implemented) proto needs the original variable's node id.
            for (int original_variable_node_id = 0;
                 original_variable_node_id < non_slot_objects.Count;
                 original_variable_node_id++)
            {
                // Only variables can own slots. Previously a non-variable fell through to an
                // unconditional cast and raised InvalidCastException.
                if (non_slot_objects[original_variable_node_id] is not IVariableV1 original_variable)
                {
                    continue;
                }

                var slot_variable = optim.get_slot(original_variable, slot_name);
                if (slot_variable is null)
                {
                    continue;
                }

                // There are unresolved problems with the inheritance of `Variable` and `Trackable`.
                throw new NotImplementedException();
            }
        }
    }

    return slot_variables;
}
/// <summary>
/// Options accepted by the checkpoint save path (see <c>TrackableSaver.save</c>).
/// </summary>
/// <param name="experimental_io_device">
/// Optional device name, null by default.
/// NOTE(review): presumably the device used for checkpoint file I/O, as in TensorFlow's Python API — confirm.
/// </param>
/// <param name="experimental_enable_async_checkpoint">
/// False by default. When true, <c>TrackableSaver.save_cached_when_graph_building</c> throws
/// <c>NotImplementedException</c>.
/// </param>
public record class CheckpointOptions(
    string? experimental_io_device = null,
    bool experimental_enable_async_checkpoint = false);
AttachedDependencies + { + get => _attached_dependencies; + } + + public virtual (IList, IDictionary>) breadth_first_traversal() + { + return base._descendants_with_paths(); + } + + // TODO: complete the implementation + public void serialize_object_graph(object? saveables_cache = null) + { + throw new NotImplementedException(); + } + + // TODO: complete the implementation + public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null) + { + throw new NotImplementedException(); + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs new file mode 100644 index 000000000..c54cc93f6 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -0,0 +1,255 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint +{ + internal record class TrackableData( + // A trackable in the root Trackable object graph. + Trackable trackable, + // The index at which the Trackable appears in TrackableObjectGraph.nodes. + int node_id, + // The BFS-generated path from the root object / used to generate readable checkpoint keys. + string object_name, + // A list of ObjectReference for each child connected to this Trackable. + pbc::RepeatedField children_proto, + // A list of SlotVariableReference to save to the object (only valid for Optimizer objects). + pbc::RepeatedField slot_variable_proto, + // The object to save to checkpoint. Usually this is the same as `trackable`, + // but can differ when the the caller wants to specify a different object to + // save. For example, when saving checkpoints asynchronously, variables are + // copied to the CPU. `object_to_save` is set as the copied variable. 
/// <summary>
/// Gathers the trackable data of <paramref name="graph_view"/>, builds the object graph proto and
/// collects the tensors and registered savers to write.
/// </summary>
/// <param name="graph_view">The object graph to serialize.</param>
/// <param name="object_map">Optional replacement objects to save in place of the tracked ones.</param>
/// <param name="call_with_mapped_captures">Forwarded to the tensor-gathering helpers.</param>
/// <param name="cache">Must be null; the cached path is not implemented.</param>
/// <returns>
/// The serialized tensors, the feed additions (always null on the supported path), the registered
/// savers and the object graph proto.
/// </returns>
/// <exception cref="NotImplementedException">Thrown when <paramref name="cache"/> is not null.</exception>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
    serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
    var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
    var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data);

    var object_graph_proto = fill_object_graph_proto(trackable_data);

    var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
    var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

    Dictionary<Tensor, object> feed_additions;
    if (cache is null)
    {
        feed_additions = null;
        serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures,
            cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value);
    }
    else
    {
        // TODO: deal with cache.
        // This is a missing feature, not an arithmetic failure, so NotImplementedException
        // replaces the former NotFiniteNumberException.
        throw new NotImplementedException();
    }

    CheckPointUtils.add_checkpoint_values_check(object_graph_proto);

    return (serialized_tensors, feed_additions, registered_savers, object_graph_proto);
}
/// <summary>
/// Creates dictionary of tensors to checkpoint, and updates the proto.
/// </summary>
/// <param name="tensor_trackables">The trackable data entries whose tensors are gathered.</param>
/// <param name="node_ids">Node id of each trackable; forwarded to the legacy-saveable path.</param>
/// <param name="call_with_mapped_captures">Forwarded to the tensor-gathering helpers.</param>
/// <param name="cache">Currently unused (see TODO).</param>
/// <param name="object_graph_proto">The proto updated by the helpers.</param>
/// <returns>The tensor dictionary of each trackable, keyed by <c>Trackable.None</c> when the helper yields no trackable.</returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
    bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
    Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> result = new();
    foreach (var data in tensor_trackables)
    {
        // TODO: deal with cache.
        string saveable_name = SaveableCompat.get_saveable_name(data.object_to_save) ?? "";

        // The legacy path is taken when the object cannot serialize itself to tensors,
        // or when it carries a legacy saveable name.
        bool use_legacy_path =
            !saveable_object_util.trackable_has_serialize_to_tensor(data.object_to_save)
            || saveable_name.Length > 0;

        Trackable owner;
        IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensors;
        if (use_legacy_path)
        {
            (owner, tensors) = get_tensors_from_legacy_saveable(data, node_ids, call_with_mapped_captures, object_graph_proto);
        }
        else
        {
            tensors = get_tensors_from_trackable(data, call_with_mapped_captures, object_graph_proto);
            owner = data.object_to_save;
        }

        result[owner ?? Trackable.None] = tensors;
    }
    return result;
}
/// <summary>
/// Groups the objects to save by registered saver name, then by object name.
/// </summary>
/// <param name="registered_trackables">Trackable data entries keyed by registered saver name.</param>
/// <param name="object_graph_proto">The object graph proto; it is only indexed here, not modified yet (see TODO).</param>
/// <returns>For every saver name, a map from object name to the object to save.</returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto)
{
    Dictionary<string, IDictionary<string, Trackable>> registered_savers = new();
    foreach (var pair in registered_trackables)
    {
        foreach (var td in pair.Value)
        {
            // Create the inner dictionary on first use, then always record the entry.
            // The earlier condition was inverted: it reset the inner dictionary when the key
            // existed and indexed a missing key otherwise (KeyNotFoundException).
            if (!registered_savers.TryGetValue(pair.Key, out var savers_by_object_name))
            {
                savers_by_object_name = new Dictionary<string, Trackable>();
                registered_savers[pair.Key] = savers_by_object_name;
            }
            savers_by_object_name[td.object_name] = td.object_to_save;

            var object_proto = object_graph_proto.Nodes[td.node_id];
            // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`.
        }
    }
    return registered_savers;
}
/// <summary>
/// Serializes the objects of <paramref name="graph_view"/> and appends a saveable holding the
/// object graph proto under <c>Trackable.Constants.OBJECT_GRAPH_PROTO_KEY</c>.
/// </summary>
/// <param name="graph_view">The object graph to serialize.</param>
/// <param name="object_map">Replacement objects to save in place of the tracked ones.</param>
/// <param name="to_graph">When not null, the work is done with this graph entered as default.</param>
/// <param name="call_with_mapped_captures">Forwarded to <c>serialize_gathered_objects</c>.</param>
/// <param name="saveables_cache">Forwarded to <c>serialize_gathered_objects</c>.</param>
/// <returns>The named saveable objects and the registered savers.</returns>
// NOTE(review): generic type arguments were lost in the reviewed text; the saveable element type
// (`MySaveableObject`) is restored from the surrounding project, not from visible usage — confirm.
public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
    IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
    object? saveables_cache = null)
{
    if (to_graph is not null)
    {
        var g = to_graph.as_default();
        try
        {
            var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
            tf.device("/cpu:0");
            var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
            named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            return (named_saveable_objects, registered_savers);
        }
        finally
        {
            // Always leave the graph context, even when serialization throws;
            // previously an exception left `to_graph` entered as the default graph.
            g.Exit();
        }
    }
    else
    {
        using (new ops.NullContextManager())
        {
            var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
            tf.device("/cpu:0");
            // NOTE(review): this branch stores `graph_proto.ToString()` while the branch above stores
            // `ToByteArray()`; the two are not the same payload — confirm which one readers expect.
            var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
            named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            return (named_saveable_objects, registered_savers);
        }
    }
}
/// <summary>
/// Builds a <c>TrackableObjectGraph</c> with one node per trackable, in list order, and fills
/// each node's children from <paramref name="graph_view"/>.
/// </summary>
/// <param name="graph_view">Supplies the children of every trackable.</param>
/// <param name="trackable_objects">The trackables, ordered by node id.</param>
/// <param name="node_ids">Node id of each trackable; expected to equal its index in <paramref name="trackable_objects"/>.</param>
/// <param name="slot_variables">Slot variable references per trackable, when it has any.</param>
/// <returns>The populated object graph proto.</returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
    IDictionary<Trackable, int> node_ids,
    IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        slot_variables)
{
    var graph = new TrackableObjectGraph();
    for (int index = 0; index < trackable_objects.Count; index++)
    {
        var current = trackable_objects[index];
        Debug.Assert(node_ids[current] == index);

        // Nodes that own slot variables are created with them; the others start empty.
        var node = slot_variables.TryGetValue(current, out var slots)
            ? new TrackableObjectGraph.Types.TrackableObject(slots)
            : new TrackableObjectGraph.Types.TrackableObject();
        graph.Nodes.Add(node);

        foreach (var child in graph_view.list_children(current))
        {
            var reference = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
            {
                NodeId = node_ids[child.Refer],
                LocalName = child.Name
            };
            node.Children.Add(reference);
        }
    }

    return graph;
}
IDictionary>?) add_attributes_to_object_graph( + IList trackable_objects, + TrackableObjectGraph object_graph_proto, IDictionary node_ids, + IDictionary object_names, IDictionary object_map, + bool call_with_mapped_captures, object? saveables_cache = null) + { + int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + Debug.Assert(node_ids[trackable_objects[i]] == i); + } + + var (checkpoint_factory_map, unmmaped_registered_savers) = + get_checkpoint_factories_and_keys(object_names, object_map); + + // skip the process of registered savers + + var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, + object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); + return (named_saveable_objects, feed_additions, null); + } + + public static (IList, object?) generate_saveable_objects( + IDictionary> checkpoint_factory_map, + TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + List named_saveable_objects = new(); + foreach (var pair in checkpoint_factory_map) + { + var trackable = pair.Key; + var factory_data_list = pair.Value; + bool fill_object_proto = object_graph_proto is not null && node_ids is not null; + TrackableObjectGraph.Types.TrackableObject object_proto = null!; + if (fill_object_proto) + { + object_proto = object_graph_proto.Nodes[node_ids[trackable]]; + } + + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + // skip cache + + foreach (var factory_data in factory_data_list) + { + var name = factory_data.name; + var key = factory_data.checkpoint_key; + var maybe_saveable = factory_data.factory; + + // TODO: oneflow python has a process with callable `saveable_factory`. 
+ List saveables = new(); + if (maybe_saveable.TryGet(out var s)) + { + saveables.Add(s); + } + else + { + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue() as Trackable, key)); + } + + foreach (var saveable in saveables) + { + if (!saveable.name.Contains(key)) + { + throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + + $"'{saveable.name}' for attribute '{name}'. Expected a name" + + $" containing '{key}'."); + } + } + + // skip the process of PythonState + + named_saveable_objects.AddRange(saveables); + + if(!fill_object_proto) continue; + + // skip the process of `TrackableSaveable` because of lack of APIs. + + object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); + } + } + + return (named_saveable_objects, null); + } +} + +public record class CheckpointFactoryData +( + Maybe factory, + string name, + string checkpoint_key +); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs new file mode 100644 index 000000000..fa441d799 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveableCompat.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint +{ + internal static class SaveableCompat + { + public static string? get_saveable_name(Trackable cls_or_obj) + { + // TODO: implement it with Attribute. 
/// <summary>
/// Walks the object graph breadth-first from <c>Root</c> and returns every reachable node
/// together with the path of references that first reached it.
/// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
/// </summary>
/// <returns>
/// The nodes in visiting order, and for each node the reference path from the root
/// (empty for the root itself).
/// </returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
{
    var visit_order = new List<Trackable>();
    var pending = new Queue<Trackable>();
    pending.Enqueue(Root);
    var paths = new Dictionary<Trackable, IEnumerable<TrackableReference>>();
    paths[this.Root] = new List<TrackableReference>();

    while (!pending.empty())
    {
        var node = pending.Dequeue();
        visit_order.Add(node);

        var node_children = this.children(node);
        foreach (var child_name in node_children.Keys)
        {
            var child = node_children[child_name];
            if (paths.ContainsKey(child))
            {
                // Already reached through an earlier path; keep the first one.
                continue;
            }

            // The child's path is the parent's path extended by the edge that reached it.
            var child_path = new List<TrackableReference>(paths[node])
            {
                new TrackableReference(child_name, child)
            };
            paths[child] = child_path;
            pending.Enqueue(child);
        }
    }

    return (visit_order, paths);
}
/// <summary>
/// Serializes the graph view and stores the object graph proto tensor under
/// <c>Trackable.None</c> / <c>Trackable.Constants.OBJECT_GRAPH_PROTO_KEY</c>.
/// </summary>
/// <param name="object_graph_tensor">
/// When null, a new constant tensor holding the proto bytes is created. Otherwise the proto bytes
/// are registered as a feed for this tensor.
/// </param>
/// <returns>The serialized tensors, the feed additions, the registered savers and the object graph proto.</returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
    gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
    var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache: _cache);

    // TODO: cache.

    if (object_graph_tensor is null)
    {
        tf.device("/cpu:0");
        object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
    }
    else
    {
        // `serialize_graph_view` returns null feed additions when no cache is used, so the
        // dictionary is created on demand; it used to be dereferenced while still null.
        feed_additions ??= new Dictionary<Tensor, object>();
        feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
    }

    Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
    if (!serialized_tensors.TryGetValue(Trackable.None, out var unowned_tensors))
    {
        unowned_tensors = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
        serialized_tensors[Trackable.None] = unowned_tensors;
    }
    unowned_tensors[Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
    return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}
// TODO: parameter write_done_callback
/// <summary>
/// Saves a checkpoint of the tracked object graph.
/// </summary>
/// <param name="file_prefix">Prefix of the checkpoint files.</param>
/// <param name="checkpoint_number">When not null, appended to <paramref name="file_prefix"/> as <c>-{number}</c>.</param>
/// <param name="session">
/// Session used to run the save op when neither executing eagerly nor inside a function;
/// a new <c>Session</c> is created when null. Ignored otherwise.
/// </param>
/// <param name="options">Checkpoint options; defaults to a new <c>CheckpointOptions</c>.</param>
/// <returns>
/// The save-path tensor when running eagerly or inside a function; otherwise the result of
/// running that tensor in the session.
/// </returns>
// NOTE(review): generic type arguments were lost in the reviewed text and are restored here from usage — confirm against the repository.
public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
    CheckpointOptions? options = null)
{
    options ??= new CheckpointOptions();

    bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
    if (checkpoint_number is not null)
    {
        file_prefix = $"{file_prefix}-{checkpoint_number.Value}";
    }

    if (!use_session)
    {
        // Eager / function mode: the prefix is baked into the save op and no session is involved.
        // In python there is `with ops.device("/cpu:0")`.
        var (eager_save_path, _) = save_cached_when_graph_building(file_prefix, (Tensor)null, options);
        return eager_save_path;
    }

    if (_object_graph_feed_tensor is null)
    {
        // In python there is `with ops.device("/cpu:0")`.
        _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
        _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
    }

    Dictionary<Tensor, object> feed_dict = new();
    feed_dict[_file_prefix_feed_tensor] = file_prefix;

    // Build the save op on the fed prefix tensor so the value supplied through `feed_dict`
    // is the one actually used. Previously the op was built from the literal "" prefix and
    // the fed tensor was not connected to it.
    var (save_path, new_feed_additions) =
        save_cached_when_graph_building(_file_prefix_feed_tensor, _object_graph_feed_tensor, options);

    if (new_feed_additions is not null)
    {
        foreach (var pair in new_feed_additions)
        {
            feed_dict[pair.Key] = pair.Value;
        }
    }

    // In python it uses `get_session`. Because a session is always created here, the former
    // "no default session" error branch could never be reached and has been removed.
    // NOTE(review): a session created here is never disposed — confirm the intended ownership.
    session ??= new Session();

    var feeds = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
    return session.run((Tensor)save_path, feeds);
}
res) + { + if(_valueA is T && _valueB is not T) + { + res = (T)(object)_valueA; + return _assignedTA; + } + else if(_valueA is not T && _valueB is T) + { + res = (T)(object)_valueB; + return !_assignedTA; + } + res = default(T); + return false; + } + + public bool IsTypeOrDeriveFrom() + { + if (_valueA is T && _valueB is not T) + { + return _assignedTA; + } + else if (_valueA is not T && _valueB is T) + { + return !_assignedTA; + } + else if (_valueA is T && _valueB is T) + { + return true; + } + else + { + return false; + } + } + + public T GetValue() + { + if (_valueA is T && _valueB is not T) + { + return (T)(object)_valueA; + } + else if (_valueA is not T && _valueB is T) + { + return (T)(object)_valueB; + } + else if (_valueA is T && _valueB is T) + { + throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); + } + else + { + throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); + } + } + + public static implicit operator Maybe(TA a) + { + return new Maybe(a); + } + public static implicit operator Maybe(TB b) + { + return new Maybe(b); + } + } + internal class SingleDeviceSaver + { + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict; + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public Operation? save(Tensor file_prefix, CheckpointOptions? 
options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensors = new(); + List slice_specs = new(); + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + if(maybe_tensor.TryGet(out var spec)) + { + var tensor_value = spec.tensor; + if (tensor_value is not null) + { + tensor_names.Add(spec.name); + tensors.Add(tensor_value); + slice_specs.Add(spec.slice_spec); + } + } + else + { + var tensor = maybe_tensor.GetValue(); + tensor_names.Add(checkpoint_key); + tensors.Add(tensor); + slice_specs.Add(slice_spec); + } + } + } + // TODO: specify the device. + return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); + } + + public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); + + public IDictionary> restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensor_dtypes = new(); + List slice_specs = new(); + + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.TryGet(out var spec)) + { + tensor_dtypes.Add(spec.dtype); + slice_specs.Add(spec.slice_spec); + tensor_names.Add(spec.name); + } + else + { + var tensor = maybe_tensor.GetValue(); + tensor_dtypes.Add(tensor.dtype); + slice_specs.Add(slice_spec); + tensor_names.Add(checkpoint_key); + } + } + } + + string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? 
"cpu:0": options.experimental_io_device!; + + // tf python has code `with ops.device(restore_device):` here. + tf.device(restore_device); // may be risky. + var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + + Dictionary> restored_tensor_dict = new(); + int idx = 0; + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice_spec in tensor_slices.Keys) + { + var restored_tensor = restored_tensors[idx++]; + if (!restored_tensor_dict.ContainsKey(checkpoint_key)) + { + restored_tensor_dict[checkpoint_key] = new Dictionary(); + } + restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; + } + } + return restored_tensor_dict; + } + + public IDictionary> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix), options); /* forward options so experimental_io_device is honoured, as save(string, options) does */ + } + /// + /// Saves checkpoints directly from multiple devices. + /// Note that this is a low-level utility which stores Tensors in the keys + /// specified by `SaveableObject`s. Higher-level utilities for object-based + /// checkpointing are built on top of it. + /// + public class MultiDeviceSaver + { + private Dictionary _single_device_savers; + private IDictionary _registered_savers; + private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; + /// + /// + /// + /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. + /// + /// + public MultiDeviceSaver(IDictionary>>> serialized_tensors, + IDictionary>? 
registered_savers = null, bool call_with_mapped_capture = false) + { + _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); + _restore_fn_to_keys = new Dictionary>(); + Dictionary>> tensors_by_device= new(); + + foreach(var pair in serialized_tensors) + { + var obj = pair.Key; + var tensor_dict = pair.Value; + RestoreFunc restore_fn; + if(obj == Trackable.None) + { + restore_fn = new RestoreFunc(x => null); + } + else + { + restore_fn = new RestoreFunc(x => + { + if(x is IDictionary>>) + { + return obj._restore_from_tensors(x as IDictionary>>); + } + throw new TypeError($"Expected `IDictionary>>` as input, got{x.GetType()}."); + }); + } + + foreach(var item in tensor_dict) + { + var checkpoint_key = item.Key; + IDictionary spec_to_tensor; + if(item.Value.TryGet(out var t)) + { + spec_to_tensor = new Dictionary(); + spec_to_tensor[""] = t; + } + else + { + spec_to_tensor = item.Value.GetValue>(); + } + + foreach(var spec in spec_to_tensor) + { + var slice_spec = spec.Key; + var tensor = spec.Value; + if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) + { + throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + + $"slice spec. This is invalid because one will overwrite the " + + $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); + } + _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; + _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); + + // skip the process of device name because lack of API. 
+ var host_device = tensor.Device; + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>()); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + internal_dict[checkpoint_key] = new Dictionary(); + } + internal_dict[checkpoint_key][slice_spec] = tensor; + } + } + } + + _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); + + _registered_savers = new Dictionary(); + if(registered_savers is not null && registered_savers.Count > 0) + { + // TODO: complete the implementation. + throw new NotImplementedException(); + } + } + + public Operation save(Tensor file_prefix, CheckpointOptions? options= null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + tf.device("CPU"); // may be risky. + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + constant_op.constant(".part"), constant_op.constant("_temp/part")); + var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + + Operation save_fn() + { + List saved_prefixes= new(); + foreach(var saver in _registered_savers) + { + // TODO: implementi it later. + throw new NotImplementedException(); + } + + int num_shards = _single_device_savers.Count; + List sharded_saves = new(); + var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); + string? last_device = null; + int shard = 0; + foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) + { + var device = pair.Key; + var saver = pair.Value; + last_device = device; + // skip the extra process of device name because of lack of API. 
+ tf.device(device); + var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard++, num_shards_tensor); /* post-increment: each device saver must write to its own shard index */ + saved_prefixes.Add(shard_prefix); + sharded_saves.Add(saver.save(shard_prefix, options)); + } + using (var controller = ops.control_dependencies(sharded_saves.ToArray())) + { + string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; + tf.device(merge_device); + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + } + } + + if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return save_fn(); + } + } + + public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); + + public IDictionary restore(Tensor file_prefix, CheckpointOptions? 
options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + IDictionary restore_func() + { + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary restore_ops = new(); + + foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) + { + var device = single_saver.Key; + var saver = single_saver.Value; + tf.device(device); + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach(var pair in restored_tensor_dict) + { + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach(var item in slice_and_tensor) + { + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) + { + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(dict); + } + else + { + internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; + } + } + else + { + internal_dict[checkpoint_key] = new Maybe>(tensor); + } + restore_fn_input_count[restore_fn]--; + + if (restore_fn_input_count[restore_fn] == 0) + { + Dictionary>> restored_tensors = new(); + foreach(var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if(ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } + } + } + } + } + + foreach(var item in _registered_savers) + { + throw new NotImplementedException(); + } + return restore_ops; + } + + // TODO: complete the implementation. Currently skip it because of lack of API. 
+ bool has_custom_device_saver = false; + + if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return restore_func(); + } + } + + /// + /// Serializes to a SaverDef referencing the current graph. + /// + public SaverDef to_proto() + { + var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); + var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); + var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); + var save_tensor = traced_save_func(filename_tensor); + var restore_op = traced_restore_func(filename_tensor).op; + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + Version = SaverDef.Types.CheckpointFormatVersion.V2 + }; + } + + private Tensor _traced_save(Tensor file_prefix) + { + var save_op = save(file_prefix); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[]{ save_op })) + { + return array_ops.identity(file_prefix); + } + } + + private Tensor _traced_restore(Tensor file_prefix) + { + var restore_op = restore(file_prefix); + tf.device("cpu:0"); + using (ops.control_dependencies(restore_op.Values.ToArray())) + { + return array_ops.identity(file_prefix); + } + } + + public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? 
registered_savers = null, bool call_with_mapped_captures = false) + { + Dictionary>>> serialized_tensors = new(); + foreach (var saveable in saveables) + { + var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); + serialized_tensors[trackable] = trackable.serialize_to_tensors(); + } + return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); + } + + private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) + { + return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); + } + private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) + { + return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); + } + } +} diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 3c70739bd..c3c677fff 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -17,6 +17,7 @@ limitations under the License. using System; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Tensorflow.Train; namespace Tensorflow { @@ -90,4 +91,71 @@ public void Dispose() Dispose(false); } } -} \ No newline at end of file + + public abstract class DisposableTrackableObject: Trackable, IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableTrackableObject() + { } + + protected DisposableTrackableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. 
+ // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. + // Therefore, you should call GC.SuppressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableTrackableObject() + { + Dispose(false); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs new file mode 100644 index 000000000..cb3ea4d3c --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + internal class execute + { + public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) + { + var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); + var types = v.Select(t => t.dtype.as_datatype_enum()); + return (types.ToArray(), v.ToArray()); + } + public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + string device_name = ctx.DeviceName; + + ctx.ensure_initialized(); + var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, 
num_outputs); + + return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs new file mode 100644 index 000000000..977fe2340 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Exceptions; + +public class AssertionError : TensorflowException +{ + public AssertionError() : base() + { + + } + + public AssertionError(string message) : base(message) + { + + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 6ce3bf3c5..c3616fafd 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -304,7 +304,7 @@ private static void add_collection_def(MetaGraphDef meta_graph_def, } } - private static OpList stripped_op_list_for_graph(GraphDef graph_def) + public static OpList stripped_op_list_for_graph(GraphDef graph_def) { var used_ops = ops_used_by_graph_def(graph_def); @@ -345,5 +345,89 @@ private static string[] ops_used_by_graph_def(GraphDef graph_def) return used_ops.ToArray(); } + + private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) + { + foreach (var attr_def in op_def.Attr) + { + if (attr_def.Name == attr_name) + { + if (attr_def.DefaultValue is null) return false; + // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. 
+ return true; + } + } + + return false; + } + + public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) + { + Dictionary op_name_to_function = new(); + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + op_name_to_function[function_def.Signature.Name] = function_def; + } + + Action _strip_node_default_valued_attrs = (node_def) => + { + if (op_name_to_function.ContainsKey(node_def.Op)) return; + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is null) return; + + HashSet attrs_to_strip = new(); + foreach (var attr in node_def.Attr) + { + if (is_default_attr_value(op_def, attr.Key, attr.Value)) + { + attrs_to_strip.Add(attr.Key); + } + } + + foreach (var attr in attrs_to_strip) + { + node_def.Attr.Remove(attr); + } + }; + + foreach (var node_def in meta_graph_def.GraphDef.Node) + { + _strip_node_default_valued_attrs(node_def); + } + + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + foreach (var function_node_def in function_def.NodeDef) + { + _strip_node_default_valued_attrs(function_node_def); + } + } + + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + } + + /// + /// Extract the Op name from a Tensor name. + /// + /// + /// + public static string op_name(string tensor_name) + { + if (string.IsNullOrEmpty(tensor_name)) + { + throw new ValueError($"Tensor name cannot be empty or None. 
Received: {tensor_name}."); + } + + if (tensor_name.StartsWith("^")) + { + tensor_name = tensor_name.Substring(1); + } + if (tensor_name.Contains(":")) + { + return tensor_name.Split(':')[0]; + } + return tensor_name; + } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index c52d0b5f5..bac9cedbf 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -3,6 +3,7 @@ using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -10,7 +11,7 @@ namespace Tensorflow.Functions /// /// /// - public class ConcreteFunction + public class ConcreteFunction: Trackable { FuncGraph func_graph; ForwardBackwardCall forward_backward; diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index d57097ae9..056d15f4d 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -1,16 +1,23 @@ using System; +using Tensorflow.Train; namespace Tensorflow { - public class Function + public class Function: Trackable { #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - + + public string Name { get; set; } public Function() { } + + public Function(string name) + { + Name = name; + } } } diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 2af1a3720..ceeca8abf 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs { public class AutoGraph { - public Func 
to_graph(Func func) + public Func to_graph(Func func, TF_DataType dtype = TF_DataType.TF_INT32) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input = tf.placeholder(tf.int32); + var input = tf.placeholder(dtype); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -26,25 +27,33 @@ public Func to_graph(Func func) return (Tensor input) => { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - func_name, - new[] { input }, - null, - 1); - return result[0]; + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { input }, + null, + 1); + return result[0]; + } + using (var s = tf.Session(input.graph)) + { + var output = func(input); + return output; + } }; } - public Func to_graph(Func func) + public Func to_graph(Func func, params TF_DataType[] dtypes) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input1 = tf.placeholder(tf.int32); - var input2 = tf.placeholder(tf.int32); + var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); + var input2 = tf.placeholder(dtypes.Length >= 2 ? 
dtypes[1] : tf.int32); var output = func(input1, input2); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -56,13 +65,22 @@ public Func to_graph(Func func) return (Tensor a, Tensor b) => { - var result = tf.Runner.TFE_Execute(tf.Context, + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { a, b }, null, 1); - return result[0]; + return result[0]; + } + using (var s = tf.Session(a.graph)) + { + Debug.Assert(a.graph == b.graph); + var output = func(a, b); + return output; + } }; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs index 235523161..e830e5bf8 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs @@ -1,9 +1,12 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class ELUArgs : LayerArgs { - public float Alpha { get; set; } = 0.1f; - } + public class ELUArgs : AutoSerializeLayerArgs + { + [JsonProperty("alpha")] + public float Alpha { get; set; } = 0.1f; + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs index 6bdb294c2..6d9531346 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs @@ -1,14 +1,16 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class LeakyReLuArgs : LayerArgs + public class LeakyReLuArgs : AutoSerializeLayerArgs { /// /// Negative slope coefficient. 
/// + [JsonProperty("alpha")] public float Alpha { get; set; } = 0.3f; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs index ca35d75d5..1c1d147f1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -1,9 +1,12 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class SoftmaxArgs : LayerArgs { - public Axis axis { get; set; } = -1; - } + public class SoftmaxArgs : AutoSerializeLayerArgs + { + [JsonProperty("axis")] + public Axis axis { get; set; } = -1; + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs index 73477c58f..4cdfb46bd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs @@ -1,3 +1,5 @@ +using Newtonsoft.Json; + namespace Tensorflow.Keras.ArgsDefinition { public class AttentionArgs : BaseDenseAttentionArgs @@ -6,6 +8,7 @@ public class AttentionArgs : BaseDenseAttentionArgs /// /// If `true`, will create a scalar variable to scale the attention scores. /// + [JsonProperty("use_scale")] public bool use_scale { get; set; } = false; /// @@ -14,6 +17,7 @@ public class AttentionArgs : BaseDenseAttentionArgs /// and key vectors. `"concat"` refers to the hyperbolic tangent of the /// concatenation of the query and key vectors. 
/// + [JsonProperty("score_mode")] public string score_mode { get; set; } = "dot"; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs index b2a0c3a51..0ef017370 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs @@ -1,6 +1,8 @@ +using Newtonsoft.Json; + namespace Tensorflow.Keras.ArgsDefinition { - public class BaseDenseAttentionArgs : LayerArgs + public class BaseDenseAttentionArgs : AutoSerializeLayerArgs { /// @@ -14,6 +16,7 @@ public class BaseDenseAttentionArgs : LayerArgs /// Float between 0 and 1. Fraction of the units to drop for the /// attention scores. /// + [JsonProperty("dropout")] public float dropout { get; set; } = 0f; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs index 21b2d218c..077dea89d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs @@ -1,22 +1,40 @@ +using Newtonsoft.Json; using System; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class MultiHeadAttentionArgs : LayerArgs + public class MultiHeadAttentionArgs : AutoSerializeLayerArgs { + [JsonProperty("num_heads")] public int NumHeads { get; set; } + [JsonProperty("key_dim")] public int KeyDim { get; set; } + [JsonProperty("value_dim")] public int? 
ValueDim { get; set; } = null; + [JsonProperty("dropout")] + public float Dropout { get; set; } = 0f; + [JsonProperty("use_bias")] + public bool UseBias { get; set; } = true; + [JsonProperty("output_shape")] + public Shape OutputShape { get; set; } = null; + [JsonProperty("attention_axes")] + public Shape AttentionAxis { get; set; } = null; + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } = null; + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } = null; + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } = null; + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } = null; + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: Add `key_shape`, `value_shape`, `query_shape`. } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs new file mode 100644 index 000000000..1a97b0135 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -0,0 +1,25 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + /// + /// This class has nothing but the attributes different from `LayerArgs`. + /// It's used to serialize the model to `tf` format. + /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`, + /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`. 
+ /// + public class AutoSerializeLayerArgs: LayerArgs + { + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs index 4f050228b..08d563c1a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs @@ -1,31 +1,65 @@ -using System; +using Newtonsoft.Json; +using System; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class ConvolutionalArgs : LayerArgs + public class ConvolutionalArgs : AutoSerializeLayerArgs { public int Rank { get; set; } = 2; + [JsonProperty("filters")] public int Filters { get; set; } public int NumSpatialDims { get; set; } = Unknown; + [JsonProperty("kernel_size")] public Shape KernelSize { get; set; } = 5; /// /// specifying the stride length of the convolution. 
/// + [JsonProperty("strides")] public Shape Strides { get; set; } = (1, 1); - + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; + [JsonProperty("data_format")] public string DataFormat { get; set; } + [JsonProperty("dilation_rate")] public Shape DilationRate { get; set; } = (1, 1); + [JsonProperty("groups")] public int Groups { get; set; } = 1; public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + [JsonProperty("use_bias")] public bool UseBias { get; set; } + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs index e9b3c2fd9..8f4facbd4 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -1,13 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; +using System.Xml.Linq; +using Tensorflow.Operations.Initializers; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { + // TODO: `activity_regularizer` public class DenseArgs : LayerArgs { /// /// Positive integer, dimensionality of the output 
space. /// + [JsonProperty("units")] public int Units { get; set; } /// @@ -15,39 +20,74 @@ public class DenseArgs : LayerArgs /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + /// /// Whether the layer uses a bias vector. /// + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. 
/// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs similarity index 65% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs index 3a8642ffc..9817e9c6d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs @@ -1,9 +1,10 @@ +using Newtonsoft.Json; using System; using static Tensorflow.Binding; -namespace Tensorflow.Keras.ArgsDefinition +namespace Tensorflow.Keras.ArgsDefinition.Core { - public class EinsumDenseArgs : LayerArgs + public class EinsumDenseArgs : AutoSerializeLayerArgs { /// /// An equation describing the einsum to perform. This equation must @@ -11,6 +12,7 @@ public class EinsumDenseArgs : LayerArgs /// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis /// expression sequence. /// + [JsonProperty("equation")] public string Equation { get; set; } /// @@ -19,6 +21,7 @@ public class EinsumDenseArgs : LayerArgs /// None for any dimension that is unknown or can be inferred from the input /// shape. /// + [JsonProperty("output_shape")] public Shape OutputShape { get; set; } /// @@ -26,41 +29,70 @@ public class EinsumDenseArgs : LayerArgs /// Each character in the `bias_axes` string should correspond to a character /// in the output portion of the `equation` string. 
/// + [JsonProperty("bias_axes")] public string BiasAxes { get; set; } = null; /// /// Activation function to use. /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. 
/// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs index b1f4fddd3..c462961b3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs @@ -1,11 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class EmbeddingArgs : LayerArgs + public class EmbeddingArgs : AutoSerializeLayerArgs { + [JsonProperty("input_dim")] public int InputDim { get; set; } + [JsonProperty("output_dim")] public int OutputDim { get; set; } + [JsonProperty("mask_zero")] public bool MaskZero { get; set; } + [JsonProperty("input_length")] public int InputLength { get; set; } = -1; + [JsonProperty("embeddings_initializer")] public IInitializer EmbeddingsInitializer { get; set; } + [JsonProperty("activity_regularizer")] + public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; } + + // TODO: `embeddings_regularizer`, `embeddings_constraint`. 
} } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs index 723109c27..be43e0a62 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -1,9 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Tensorflow.Keras.Common; + +namespace Tensorflow.Keras.ArgsDefinition { public class InputLayerArgs : LayerArgs { + [JsonIgnore] public Tensor InputTensor { get; set; } - public bool Sparse { get; set; } + [JsonProperty("sparse")] + public virtual bool Sparse { get; set; } + [JsonProperty("ragged")] public bool Ragged { get; set; } + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs deleted file mode 100644 index 16705063e..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class Cropping2DArgs : LayerArgs { - /// - /// channel last: (b, h, w, c) - /// channels_first: (b, c, h, w) - /// - public enum DataFormat { channels_first = 0, channels_last = 1 } - /// - /// Accept: int[1][2], int[1][1], int[2][2] - /// - public NDArray cropping { get; set; } - public DataFormat data_format { get; set; } = DataFormat.channels_last; - } -} diff --git 
a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs deleted file mode 100644 index 9da2adc7f..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class Cropping3DArgs : LayerArgs { - /// - /// channel last: (b, h, w, c) - /// channels_first: (b, c, h, w) - /// - public enum DataFormat { channels_first = 0, channels_last = 1 } - /// - /// Accept: int[1][3], int[1][1], int[3][2] - /// - public NDArray cropping { get; set; } - public DataFormat data_format { get; set; } = DataFormat.channels_last; - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs deleted file mode 100644 index 9d23acd43..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Tensorflow.NumPy; - -namespace Tensorflow.Keras.ArgsDefinition { - public class CroppingArgs : LayerArgs { - /// - /// Accept length 1 or 2 - /// - public NDArray cropping { get; set; } - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index f3cca438f..8ce1ec655 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataAdapterArgs + public class DataAdapterArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs 
b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index b6e6849bc..fd603a85e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataHandlerArgs + public class DataHandlerArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index 4df4fb2b4..febf14176 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -1,51 +1,54 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class LayerArgs + [JsonObject(MemberSerialization.OptIn)] + public class LayerArgs: IKerasConfig { /// /// Indicates whether the layer's weights are updated during training /// and whether the layer's updates are run during training. /// - public bool Trainable { get; set; } = true; - - public string Name { get; set; } + public virtual bool Trainable { get; set; } = true; + public virtual string Name { get; set; } /// /// Only applicable to input layers. /// - public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; + public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; /// /// Whether the `call` method can be used to build a TF graph without issues. /// This attribute has no effect if the model is created using the Functional /// API. Instead, `model.dynamic` is determined based on the internal layers. /// - public bool Dynamic { get; set; } = false; + public virtual bool Dynamic { get; set; } = false; /// /// Only applicable to input layers. 
/// - public Shape InputShape { get; set; } + public virtual Shape InputShape { get; set; } /// /// Only applicable to input layers. /// - public Shape BatchInputShape { get; set; } + public virtual Shape BatchInputShape { get; set; } - public int BatchSize { get; set; } = -1; + public virtual int BatchSize { get; set; } = -1; /// /// Initial weight values. /// - public float[] Weights { get; set; } + public virtual float[] Weights { get; set; } /// /// Regularizer function applied to the output of the layer(its "activation"). /// - public IRegularizer ActivityRegularizer { get; set; } + public virtual IRegularizer ActivityRegularizer { get; set; } - public bool Autocast { get; set; } + public virtual bool Autocast { get; set; } - public bool IsFromConfig { get; set; } + public virtual bool IsFromConfig { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs deleted file mode 100644 index fb0868dc5..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Tensorflow.Keras.ArgsDefinition.Lstm -{ - public class LSTMCellArgs : LayerArgs - { - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs index 3e6791e3b..0140b3dd0 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs @@ -4,6 +4,7 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: complete the implementation public class MergeArgs : LayerArgs { public Tensors Inputs { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 0d9e26ac4..ad55ff612 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ 
b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class NodeArgs + public class NodeArgs: IKerasConfig { public ILayer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs index 954ede574..6ee91e80b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs @@ -1,21 +1,37 @@ -using static Tensorflow.Binding; +using Newtonsoft.Json; +using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class BatchNormalizationArgs : LayerArgs + public class BatchNormalizationArgs : AutoSerializeLayerArgs { + [JsonProperty("axis")] public Shape Axis { get; set; } = -1; + [JsonProperty("momentum")] public float Momentum { get; set; } = 0.99f; + [JsonProperty("epsilon")] public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] public bool Center { get; set; } = true; + [JsonProperty("scale")] public bool Scale { get; set; } = true; + [JsonProperty("beta_initializer")] public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("moving_mean_initializer")] public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("moving_variance_initializer")] public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] public IRegularizer GammaRegularizer { get; 
set; } + // TODO: `beta_constraint` and `gamma_constraint`. + [JsonProperty("renorm")] public bool Renorm { get; set; } + // TODO: `renorm_clipping` and `virtual_batch_size`. + [JsonProperty("renorm_momentum")] public float RenormMomentum { get; set; } = 0.99f; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs index 13fd98b41..1ac661b37 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs @@ -1,16 +1,27 @@ -using static Tensorflow.Binding; +using Newtonsoft.Json; +using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class LayerNormalizationArgs : LayerArgs + public class LayerNormalizationArgs : AutoSerializeLayerArgs { + [JsonProperty("axis")] public Axis Axis { get; set; } = -1; + [JsonProperty("epsilon")] public float Epsilon { get; set; } = 1e-3f; + [JsonProperty("center")] public bool Center { get; set; } = true; + [JsonProperty("scale")] public bool Scale { get; set; } = true; + [JsonProperty("beta_initializer")] public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + [JsonProperty("gamma_initializer")] public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + [JsonProperty("beta_regularizer")] public IRegularizer BetaRegularizer { get; set; } + [JsonProperty("gamma_regularizer")] public IRegularizer GammaRegularizer { get; set; } + + // TODO: `beta_constraint` and `gamma_constraint`. 
} } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs index e2a0e43c8..6256fd329 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class OptimizerV2Args + public class OptimizerV2Args: IKerasConfig { public string Name { get; set; } public float LearningRate { get; set; } = 0.001f; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs index 9742203d6..c5fdca675 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class Pooling1DArgs : LayerArgs + public class Pooling1DArgs : AutoSerializeLayerArgs { /// /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. @@ -10,11 +12,13 @@ public class Pooling1DArgs : LayerArgs /// /// specifying the size of the pooling window. /// + [JsonProperty("pool_size")] public int PoolSize { get; set; } /// /// specifying the strides of the pooling operation. /// + [JsonProperty("strides")] public int Strides { get { return _strides.HasValue ? _strides.Value : PoolSize; } set { _strides = value; } @@ -24,11 +28,13 @@ public int Strides { /// /// The padding method, either 'valid' or 'same'. /// + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; /// /// one of `channels_last` (default) or `channels_first`. 
/// + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs index 1260af4c6..91a372ef3 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class Pooling2DArgs : LayerArgs + public class Pooling2DArgs : AutoSerializeLayerArgs { /// /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. @@ -10,21 +12,25 @@ public class Pooling2DArgs : LayerArgs /// /// specifying the size of the pooling window. /// + [JsonProperty("pool_size")] public Shape PoolSize { get; set; } /// /// specifying the strides of the pooling operation. /// + [JsonProperty("strides")] public Shape Strides { get; set; } /// /// The padding method, either 'valid' or 'same'. /// + [JsonProperty("padding")] public string Padding { get; set; } = "valid"; /// /// one of `channels_last` (default) or `channels_first`. 
/// + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs index 28ccf9f74..97cb364d9 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs @@ -4,7 +4,7 @@ namespace Tensorflow.Keras.ArgsDefinition { - public class PreprocessingLayerArgs : LayerArgs + public class PreprocessingLayerArgs : AutoSerializeLayerArgs { } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs new file mode 100644 index 000000000..154bd8c89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs @@ -0,0 +1,12 @@ +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RescalingArgs : AutoSerializeLayerArgs + { + [JsonProperty("scale")] + public float Scale { get; set; } + [JsonProperty("offset")] + public float Offset { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs index cf11595e2..39fa52211 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs @@ -1,5 +1,6 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: no corresponding class found in keras python, maybe obsolete?
public class ResizingArgs : PreprocessingLayerArgs { public int Height { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs index ddeadc001..1a7149f5a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; @@ -6,11 +7,19 @@ namespace Tensorflow.Keras.ArgsDefinition { public class TextVectorizationArgs : PreprocessingLayerArgs { + [JsonProperty("standardize")] public Func Standardize { get; set; } + [JsonProperty("split")] public string Split { get; set; } = "standardize"; + [JsonProperty("max_tokens")] public int MaxTokens { get; set; } = -1; + [JsonProperty("output_mode")] public string OutputMode { get; set; } = "int"; + [JsonProperty("output_sequence_length")] public int OutputSequenceLength { get; set; } = -1; + [JsonProperty("vocabulary")] public string[] Vocabulary { get; set; } + + // TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding` } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs index c41c6fe85..1c85d4936 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs @@ -1,21 +1,26 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class DropoutArgs : LayerArgs + public class DropoutArgs : AutoSerializeLayerArgs { /// /// Float between 0 and 1. Fraction of the input units to drop. 
/// + [JsonProperty("rate")] public float Rate { get; set; } /// /// 1D integer tensor representing the shape of the /// binary dropout mask that will be multiplied with the input. /// + [JsonProperty("noise_shape")] public Shape NoiseShape { get; set; } /// /// random seed. /// + [JsonProperty("seed")] public int? Seed { get; set; } public bool SupportsMasking { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs deleted file mode 100644 index ec9b53150..000000000 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Tensorflow.Keras.ArgsDefinition -{ - public class RescalingArgs : LayerArgs - { - public float Scale { get; set; } - public float Offset { get; set; } - } -} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs new file mode 100644 index 000000000..8c2626390 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping2DArgs : LayerArgs + { + /// + /// channel last: (b, h, w, c) + /// channels_first: (b, c, h, w) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][2], int[1][1], int[2][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs new file mode 100644 index 000000000..2d98e55db --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs @@ -0,0 +1,18 @@ +using Tensorflow.NumPy; + +namespace 
Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping3DArgs : LayerArgs + { + /// + /// channel last: (b, h, w, c) + /// channels_first: (b, c, h, w) + /// + public enum DataFormat { channels_first = 0, channels_last = 1 } + /// + /// Accept: int[1][3], int[1][1], int[3][2] + /// + public NDArray cropping { get; set; } + public DataFormat data_format { get; set; } = DataFormat.channels_last; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs new file mode 100644 index 000000000..21b85966b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs @@ -0,0 +1,12 @@ +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition.Reshaping +{ + public class Cropping1DArgs : LayerArgs + { + /// + /// Accept length 1 or 2 + /// + public NDArray cropping { get; set; } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs index c2b48cc2f..91ffc2058 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs @@ -1,7 +1,10 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class FlattenArgs : LayerArgs + public class FlattenArgs : AutoSerializeLayerArgs { + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs index 2686f6cd7..92be10ab1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs @@ -1,5 +1,9 @@ -namespace 
Tensorflow.Keras.ArgsDefinition { - public class PermuteArgs : LayerArgs { - public int[] dims { get; set; } - } +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { + public class PermuteArgs : AutoSerializeLayerArgs + { + [JsonProperty("dims")] + public int[] dims { get; set; } + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs index 77bca8ad0..4d1123c8a 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs @@ -1,7 +1,10 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class ReshapeArgs : LayerArgs + public class ReshapeArgs : AutoSerializeLayerArgs { + [JsonProperty("target_shape")] public Shape TargetShape { get; set; } public object[] TargetShapeObjects { get; set; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs index 7fdda32d3..b35e0e4b6 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs @@ -1,12 +1,17 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class UpSampling2DArgs : LayerArgs + public class UpSampling2DArgs : AutoSerializeLayerArgs { + [JsonProperty("size")] public Shape Size { get; set; } + [JsonProperty("data_format")] public string DataFormat { get; set; } /// /// 'nearest', 'bilinear' /// + [JsonProperty("interpolation")] public string Interpolation { get; set; } = "nearest"; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs 
b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs index ed6e7cc9c..4831e435b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs @@ -2,6 +2,7 @@ namespace Tensorflow.Keras.ArgsDefinition { + // TODO: complete the implementation public class ZeroPadding2DArgs : LayerArgs { public NDArray Padding { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs similarity index 67% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs index b08d21d88..764641474 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs @@ -1,9 +1,8 @@ -using Tensorflow.Keras.ArgsDefinition.Rnn; - -namespace Tensorflow.Keras.ArgsDefinition.Lstm +namespace Tensorflow.Keras.ArgsDefinition.Rnn { public class LSTMArgs : RNNArgs { + // TODO: maybe change the `RNNArgs` and implement this class. 
public bool UnitForgetBias { get; set; } public float Dropout { get; set; } public float RecurrentDropout { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs new file mode 100644 index 000000000..594c99bb0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs @@ -0,0 +1,7 @@ +namespace Tensorflow.Keras.ArgsDefinition.Rnn +{ + // TODO: complete the implementation + public class LSTMCellArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index da5279257..2585592c1 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -1,21 +1,30 @@ -using System.Collections.Generic; +using Newtonsoft.Json; +using System.Collections.Generic; namespace Tensorflow.Keras.ArgsDefinition.Rnn { - public class RNNArgs : LayerArgs + public class RNNArgs : AutoSerializeLayerArgs { public interface IRnnArgCell : ILayer { object state_size { get; } } - + [JsonProperty("cell")] + // TODO: the cell should be serialized with `serialize_keras_object`. public IRnnArgCell Cell { get; set; } = null; + [JsonProperty("return_sequences")] public bool ReturnSequences { get; set; } = false; + [JsonProperty("return_state")] public bool ReturnState { get; set; } = false; + [JsonProperty("go_backwards")] public bool GoBackwards { get; set; } = false; + [JsonProperty("stateful")] public bool Stateful { get; set; } = false; + [JsonProperty("unroll")] public bool Unroll { get; set; } = false; + [JsonProperty("time_major")] public bool TimeMajor { get; set; } = false; + // TODO: Add `num_constants` and `zero_output_for_mask`. 
public Dictionary Kwargs { get; set; } = null; public int Units { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs new file mode 100644 index 000000000..1bc13caf3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs @@ -0,0 +1,50 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedActivationJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Activation); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(""); + token.WriteTo(writer); + } + else if (value is not Activation) + { + throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Activation)!.GetType().Name); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + throw new NotImplementedException(); + //var dims = serializer.Deserialize(reader, typeof(string)); + //if (dims is null) + //{ + // throw new ValueError("Cannot deserialize 'null' to `Activation`."); + //} + //return new Shape((long[])(dims!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs new file mode 100644 index 000000000..4e190605c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -0,0 +1,48 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedAxisJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Axis); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(new int[] { }); + token.WriteTo(writer); + } + else if (value is not Axis) + { + throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Axis)!.axis); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + var axis = serializer.Deserialize(reader, typeof(long[])); + if (axis is null) + { + throw new ValueError("Cannot deserialize 'null' to `Axis`."); + } + return new Axis((int[])(axis!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs new file mode 100644 index 000000000..1ad19fc89 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedNodeConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(NodeConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if (value is not NodeConfig) + { + throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var config = value as NodeConfig; + var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + var values = serializer.Deserialize(reader, typeof(object[])) as object[]; + if (values is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + if(values.Length != 3) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + if (values[0] is not string) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); + } + if (values[1] is not int) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); + } + if (values[2] is not int) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); + } + return new NodeConfig() + { + Name = values[0] as string, + NodeIndex = (int)values[1], + TensorIndex = (int)values[2] + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs new file mode 100644 index 000000000..300cb2f28 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -0,0 +1,67 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedShapeJsonConverter: JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Shape); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? 
value, JsonSerializer serializer) + { + if(value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if(value is not Shape) + { + throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var shape = (value as Shape)!; + long?[] dims = new long?[shape.ndim]; + for(int i = 0; i < dims.Length; i++) + { + if (shape.dims[i] == -1) + { + dims[i] = null; + } + else + { + dims[i] = shape.dims[i]; + } + } + var token = JToken.FromObject(dims); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + if(dims is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + long[] convertedDims = new long[dims.Length]; + for(int i = 0; i < dims.Length; i++) + { + convertedDims[i] = dims[i] ?? (-1); + } + return new Shape(convertedDims); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 7280594b7..6743935c8 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -16,23 +16,27 @@ limitations under the License. using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { /// /// Specifies the ndim, dtype and shape of every input to a layer. /// - public class InputSpec + public class InputSpec: IKerasConfigable { public int? ndim; + public int? max_ndim; public int? min_ndim; Dictionary axes; Shape shape; + TF_DataType dtype; public int[] AllAxisDim; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, int? min_ndim = null, + int? 
max_ndim = null, Dictionary axes = null, Shape shape = null) { @@ -41,7 +45,9 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, axes = new Dictionary(); this.axes = axes; this.min_ndim = min_ndim; + this.max_ndim = max_ndim; this.shape = shape; + this.dtype = dtype; if (ndim == null && shape != null) this.ndim = shape.ndim; @@ -49,7 +55,30 @@ public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, AllAxisDim = axes.Select(x => x.Value).ToArray(); } + public IKerasConfig get_config() + { + return new Config() + { + DType = dtype == TF_DataType.DtInvalid ? null : dtype, + Shape = shape, + Ndim = ndim, + MinNdim = min_ndim, + MaxNdim = max_ndim, + Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) + }; + } + public override string ToString() => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; + + public class Config: IKerasConfig + { + public TF_DataType? DType { get; set; } + public Shape Shape { get; set; } + public int? Ndim { get; set; } + public int? MinNdim { get;set; } + public int? 
MaxNdim { get;set; } + public IDictionary Axes { get; set; } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 1ec4a2c6e..036291076 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,10 +1,12 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; +using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer + public interface ILayer: IWithTrackable, IKerasConfigable { string Name { get; } bool Trainable { get; } @@ -19,8 +21,8 @@ public interface ILayer List NonTrainableWeights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } + TensorShapeConfig BuildInputShape { get; } TF_DataType DType { get; } int count_params(); - LayerArgs get_config(); } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs index 602e7a880..3578652ee 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs @@ -1,5 +1,5 @@ using System; -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.NumPy; namespace Tensorflow.Keras.Layers diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs new file mode 100644 index 000000000..1217e1e52 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public interface IKerasConfig + { + } + + public interface IKerasConfigable + { + IKerasConfig get_config(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs index 
b8b8cab40..4ce290c83 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -6,11 +7,15 @@ namespace Tensorflow.Keras.Saving { - public class LayerConfig + public class LayerConfig: IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("class_name")] public string ClassName { get; set; } + [JsonProperty("config")] public LayerArgs Config { get; set; } + [JsonProperty("inbound_nodes")] public List InboundNodes { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index abfb235be..cac19180f 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,15 +1,20 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class ModelConfig + public class ModelConfig : IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("layers")] public List Layers { get; set; } + [JsonProperty("input_layers")] public List InputLayers { get; set; } + [JsonProperty("output_layers")] public List OutputLayers { get; set; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 3132248ef..20e2fef59 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -1,10 +1,13 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow.Keras.Saving { - public class 
NodeConfig + [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] + public class NodeConfig : IKerasConfig { public string Name { get; set; } public int NodeIndex { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs new file mode 100644 index 000000000..ae8a1ab13 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public interface ISerializedAttributes + { + IDictionary Functions { get; } + + IDictionary CheckpointableObjects { get; } + + /// + /// Returns functions to attach to the root object during serialization. + /// + IDictionary FunctionsToSerialize { get; } + + /// + /// Returns objects to attach to the root object during serialization. + /// + IDictionary ObjectsToSerialize{get; } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + IDictionary set_and_validate_functions(IDictionary function_dict); + + /// + /// Saves objects to a dictionary, and validates the values. 
+ /// + /// + IDictionary set_and_validate_objects(IDictionary object_dict); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs new file mode 100644 index 000000000..7abcfde26 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs @@ -0,0 +1,21 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Saving +{ + public class TensorShapeConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } = "TensorShape"; + [JsonProperty("items")] + public long?[] Items { get; set; } + + public static implicit operator Shape(TensorShapeConfig shape) + => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + + public static implicit operator TensorShapeConfig(Shape shape) + => new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? null : x).ToArray() }; + } +} diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index e25537d80..45ebd884f 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -9,10 +9,52 @@ namespace Tensorflow.ModelSaving /// public class SaveOptions { - bool save_debug_info; + public bool save_debug_info = false; + public IList? namespace_white_list { get; set; } = null; + public IDictionary? function_aliases { get; set; } = null; + public string? 
experimental_io_device { get; set; } = null; + // TODO: experimental + public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; + public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; } } + + public class VariablePolicy + { + public string Policy { get; } + private VariablePolicy(string policy) + { + Policy = policy; + } + public static VariablePolicy None = new(null); + public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); + public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); + + public bool save_variable_devices() + { + return this != VariablePolicy.None; + } + + /// + /// Tries to convert `obj` to a VariablePolicy instance. + /// + /// + /// + public static VariablePolicy from_obj(object obj) + { + if (obj is null) return VariablePolicy.None; + if (obj is VariablePolicy) return (VariablePolicy)obj; + var key = obj.ToString().ToLower(); + return key switch + { + null => VariablePolicy.None, + "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") + }; + } + } } diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 6c7189df1..709ca9b27 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -14,20 +14,29 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow { - public record Axis(params int[] axis) + [JsonConverter(typeof(CustomizedAxisJsonConverter))] + public class Axis { + public int[] axis { get; set; } public int size => axis == null ? -1 : axis.Length; public bool IsScalar { get; init; } public int this[int index] => axis[index]; + public Axis(params int[] axis) + { + this.axis = axis; + } + public static implicit operator int[]?(Axis axis) => axis?.axis; diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index bc79fefca..ecf735869 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -14,14 +14,17 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; using Tensorflow.NumPy; namespace Tensorflow { + [JsonConverter(typeof(CustomizedShapeJsonConverter))] public class Shape { public int ndim => _dims == null ? -1 : _dims.Length; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index fdcb5aff0..e7e9955c0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Constant : IInitializer @@ -22,11 +24,19 @@ public class Constant : IInitializer T value; bool _verify_shape; + private readonly Dictionary _config; + + public string ClassName => "Constant"; + public IDictionary Config => _config; + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) { this.value = value; this.dtype = dtype; _verify_shape = verify_shape; + + _config = new Dictionary(); + _config["value"] = this.value; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index d97d88308..def1cb7a0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -14,10 +14,17 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class GlorotUniform : VarianceScaling { + private readonly Dictionary _config; + + public override string ClassName => "GlorotUniform"; + public override IDictionary Config => _config; + public GlorotUniform(float scale = 1.0f, string mode = "FAN_AVG", bool uniform = true, @@ -28,7 +35,8 @@ public GlorotUniform(float scale = 1.0f, seed: seed, dtype: dtype) { - + _config = new Dictionary(); + _config["seed"] = _seed; } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 50d4d5037..9748b1004 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -14,10 +14,17 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; +using System.Collections.Generic; + namespace Tensorflow { public interface IInitializer { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } Tensor Apply(InitializerArgs args); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 02d3c93b2..3077a1e0e 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -14,12 +14,19 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Ones : IInitializer { private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "Ones"; + public IDictionary Config => new Dictionary(); + public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) { this.dtype = dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 045b02c5a..492047c9f 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2023 Haiping Chen. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,6 +19,7 @@ limitations under the License. using static Tensorflow.Binding; namespace Tensorflow.Operations.Initializers; +using System.Collections.Generic; public class Orthogonal : IInitializer { @@ -31,6 +32,10 @@ public Orthogonal(float gain = 1.0f, int? seed = null) _seed = seed; } + private readonly Dictionary _config; + + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); public Tensor Apply(InitializerArgs args) { return _generate_init_val(args.Shape, args.DType == TF_DataType.DtInvalid ? 
TF_DataType.TF_FLOAT : args.DType); diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 029b311bb..21fa7e2b2 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomNormal : IInitializer @@ -23,6 +25,11 @@ public class RandomNormal : IInitializer private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomNormal"; + public IDictionary Config => _config; + public RandomNormal(float mean = 0.0f, float stddev = 0.05f, int? seed = null, @@ -32,6 +39,11 @@ public RandomNormal(float mean = 0.0f, this.stddev = stddev; this.seed = seed; this.dtype = dtype; + + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index a49d59212..87404708c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomUniform : IInitializer @@ -23,12 +25,22 @@ public class RandomUniform : IInitializer private float maxval; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomUniform"; + public IDictionary Config => _config; + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) { this.dtype = dtype; this.minval = minval; this.maxval = maxval; this.seed = seed; + + _config = new Dictionary(); + _config["minval"] = this.minval; + _config["maxval"] = this.maxval; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 048c11e7a..c1c3e9996 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class TruncatedNormal : IInitializer @@ -23,6 +25,11 @@ public class TruncatedNormal : IInitializer private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "TruncatedNormal"; + public IDictionary Config => _config; + public TruncatedNormal(float mean = 0.0f, float stddev = 1.0f, int? 
seed = null, @@ -32,6 +39,10 @@ public TruncatedNormal(float mean = 0.0f, this.stddev = stddev; this.seed = seed; this.dtype = dtype; + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index d313f4c9a..f104e8e83 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -15,7 +15,9 @@ limitations under the License. ******************************************************************************/ using System; +using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; namespace Tensorflow.Operations.Initializers { @@ -30,6 +32,11 @@ public class VarianceScaling : IInitializer protected int? _seed; protected TF_DataType _dtype; protected bool _uniform; + private readonly Dictionary _config; + + public virtual string ClassName => "VarianceScaling"; + + public virtual IDictionary Config => _config; public VarianceScaling(float factor = 2.0f, string mode = "FAN_IN", @@ -50,6 +57,12 @@ public VarianceScaling(float factor = 2.0f, _seed = seed; _dtype = dtype; _uniform = uniform; + + _config = new(); + _config["scale"] = _scale; + _config["mode"] = _mode; + _config["distribution"] = _distribution; + _config["seed"] = _seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index 5d045292f..c4ed25a17 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Zeros : IInitializer @@ -21,6 +23,9 @@ public class Zeros : IInitializer Shape shape; TF_DataType dtype; + public string ClassName => "Zeros"; + public IDictionary Config => new Dictionary(); + public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { this.shape = shape; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index d63d0311b..2b83dd1d1 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -20,7 +20,9 @@ limitations under the License. using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Operations; +using Tensorflow.Train; using Tensorflow.Util; using static Tensorflow.Binding; @@ -75,6 +77,8 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell public Shape BatchInputShape => throw new NotImplementedException(); + public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); + public TF_DataType DType => throw new NotImplementedException(); protected bool built = false; public bool Built => built; @@ -143,7 +147,7 @@ public int count_params() throw new NotImplementedException(); } - public LayerArgs get_config() + public IKerasConfig get_config() { throw new NotImplementedException(); } @@ -152,5 +156,7 @@ public void build(Shape input_shape) { throw new NotImplementedException(); } + + public Trackable GetTrackable() { throw new NotImplementedException(); } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 11cb6de8e..956be96b5 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ 
b/src/TensorFlowNET.Core/Operations/gen_ops.cs
@@ -1,6 +1,9 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Xml.Linq;
+using Tensorflow.Contexts;
+using Tensorflow.Eager;
 using static Tensorflow.Binding;
 
 namespace Tensorflow.Operations
@@ -17182,17 +17185,47 @@ public static Tensor merge_summary(Tensor[] inputs, string name = "MergeSummary"
         /// path in the input checkpoint_prefixes. This is useful when those paths are non
         /// user-facing temporary locations.
         ///
-        public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints")
-        {
+        public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints")
+        {
+            var ctx = tf.Context;
+            if (ctx.executing_eagerly())
+            {
+                // The fast-path result is discarded: eager callers always receive null.
+                tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
+                    checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files));
+                return null;
+                //try
+                //{
+                //    var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
+                //        new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }));
+                //    return null;
+                //}
+                //catch (System.Exception)
+                //{
+                //    return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs,
+                //        allow_missing_files: allow_missing_files, name: name, ctx: ctx);
+                //}
+            }
             var dict = new Dictionary();
             dict["checkpoint_prefixes"] = checkpoint_prefixes;
             dict["destination_prefix"] = destination_prefix;
-            if (delete_old_dirs.HasValue)
-                dict["delete_old_dirs"] = delete_old_dirs.Value;
+            dict["delete_old_dirs"] = delete_old_dirs;
+            dict["allow_missing_files"] = allow_missing_files; // forwarded so graph mode honors it like the eager path
             var op = 
tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); return op; } + //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) + //{ + // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); + // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); + // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; + // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; + // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); + // result = null; + // return null; + //} + /// /// Transforms a spectrogram into a form that's useful for speech recognition. /// @@ -24259,6 +24292,12 @@ public static (Tensor output_false, Tensor output_true) ref_switch(Tensor data, /// public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); + return result[0]; + } var dict = new Dictionary(); dict["input"] = input; dict["pattern"] = pattern; @@ -29744,6 +29783,12 @@ public static Tensor[] shape_n(Tensor[] input, TF_DataType? 
out_type = null, str /// public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); + return result[0]; + } var dict = new Dictionary(); dict["basename"] = basename; dict["shard"] = shard; @@ -34668,6 +34713,12 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, /// public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); + return result[0]; + } var dict = new Dictionary(); dict["inputs"] = inputs; if (separator != null) diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 4f276e36c..35c5877f3 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,7 +14,9 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/
 
+using System.Linq;
 using Tensorflow.Contexts;
+using Tensorflow.Eager;
 using static Tensorflow.Binding;
 
 namespace Tensorflow
@@ -23,11 +25,47 @@ public class io_ops
     {
         public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
         {
+            var ctx = tf.Context;
+            if (ctx.executing_eagerly())
+            {
+                try
+                {
+                    var result = tf.Runner.TFE_FastPathExecute(
+                        new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors }));
+                    result = null;
+                    return null;
+                }
+                catch (System.Exception)
+                {
+                    // Fast path failed: retry through the slower quick_execute route.
+                    return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx);
+                }
+            }
             var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });
 
             return _op;
         }
 
+        /// <summary>
+        /// Runs SaveV2 through <c>execute.quick_execute</c>; <c>save_v2</c> calls this when the eager fast path throws.
+        /// </summary>
+        public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx)
+        {
+            DataType[] attr_dtypes;
+            (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx);
+            prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
+            var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
+            var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
+            // Same input order as the fast path and the graph path in save_v2:
+            // prefix, tensor_names, shape_and_slices, then the tensors being saved.
+            var inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }.Concat(tensors).ToArray();
+            var attrs = new object[] { "dtypes", attr_dtypes };
+
+            // The call is made for its side effect only; callers always receive null.
+            execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name);
+            return null;
+        }
+
         public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
         {
             var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });
diff --git
a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index ee751acf4..1b1fa0037 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,6 +17,9 @@ limitations under the License. using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -38,6 +41,11 @@ public static bool is_resource_variable(IVariableV1 var) { return var is ResourceVariable; } + + public static bool is_resource_variable(Trackable var) + { + return var is BaseResourceVariable; + } /// /// Creates a variable handle with information to do shape inference. @@ -171,5 +179,57 @@ private static HandleData get_eager_safe_handle_data(Tensor handle) return HandleData.Parser.ParseFrom(handle.BufferToArray()); } } + + /// + /// Copies an existing variable to a new graph, with no initializer. + /// + /// + public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) + { + var new_variable = new UninitializedVariable( + trainable: variable.Trainable, + shape: variable.shape, + dtype: variable.dtype, + name: variable.SharedName, + aggregation: variable.Aggregation, + extra_handle_data: null); + new_variable._maybe_initialize_trackable(); + return new_variable; + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// + /// + /// + /// + /// + public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) + { + // lack of API: `proto.Variable.SetInParent()`. 
+ if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) + { + throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + + $"unexpected suffix in the name (expected ':0') which won't be restored."); + } + if(proto.Variable is null) + { + proto.Variable = new SavedVariable(); + } + proto.Variable.Name = meta_graph.op_name(resource_variable.Name); + proto.Variable.Trainable = resource_variable.Trainable; + proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); + // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. + proto.Variable.Aggregation = resource_variable.Aggregation; + proto.Variable.Shape = resource_variable.shape.as_proto(); + + if (options.experimental_variable_policy.save_variable_devices()) + { + if (!string.IsNullOrEmpty(resource_variable.Device)) + { + proto.Variable.Device = resource_variable.Device; + } + } + } } } diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 9d3e854ac..f2597574b 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -156,7 +156,7 @@ public SavedObjectGraph Clone() { /// Nodes[0] is considered the root node. 
/// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Nodes { + public pbc::RepeatedField Nodes { get { return nodes_; } } @@ -286,6 +286,7 @@ public SavedObject() { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public SavedObject(SavedObject other) : this() { children_ = other.children_.Clone(); + dependencies_ = other.dependencies_.Clone(); slotVariables_ = other.slotVariables_.Clone(); saveableObjects_ = other.saveableObjects_.Clone(); switch (other.KindCase) { @@ -328,6 +329,7 @@ public SavedObject Clone() { private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// /// Objects which this object depends on: named edges in the dependency /// graph. @@ -338,6 +340,11 @@ public SavedObject Clone() { public pbc::RepeatedField Children { get { return children_; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dependencies { + get { return dependencies_; } + } /// Field number for the "slot_variables" field. 
public const int SlotVariablesFieldNumber = 3; @@ -617,6 +624,7 @@ public void MergeFrom(SavedObject other) { return; } children_.Add(other.children_); + dependencies_.Add(other.dependencies_); slotVariables_.Add(other.slotVariables_); saveableObjects_.Add(other.saveableObjects_); switch (other.KindCase) { diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 3aa747c20..fb197eca2 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -198,6 +198,22 @@ public sealed partial class TrackableObject : pb::IMessage { public TrackableObject() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot) { + OnConstruction(); + slotVariables_ = slot; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot, + pbc::RepeatedField children + ) + { + OnConstruction(); + slotVariables_ = slot; + children_ = children; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index a7db6eee1..ede72a6ae 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io + diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 372ac6762..deeb9e4b5 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -202,6 +202,24 @@ public static string as_numpy_name(this TF_DataType type) _ => type.ToString() }; + public static string as_python_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "str", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => 
"uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + public static int get_datatype_size(this TF_DataType type) => type.as_base_dtype() switch { diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d2198e37e..4ba3e4074 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,6 +1,71 @@ -namespace Tensorflow.Train +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; + +namespace Tensorflow.Train { - public abstract class AutoTrackable : Trackable + public class AutoTrackable : Trackable { + public void _delete_tracking(string name) + { + _maybe_initialize_trackable(); + if (_unconditional_dependency_names.ContainsKey(name)) + { + _unconditional_dependency_names.Remove(name); + for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) + { + if (_unconditional_checkpoint_dependencies[i].Name == name) + { + _unconditional_checkpoint_dependencies.RemoveAt(i); + } + } + } + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) + { + if(save_type != SaveType.SAVEDMODEL) + { + return base._trackable_children(save_type, cache); + } + + Dictionary functions = new(); + // TODO: process of logs. 
+            var properties = this.GetType().GetProperties();
+            foreach ( var property in properties )
+            {
+                if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction))
+                {
+                    // An unset property has no child to track; do not record a null entry for it.
+                    object value = property.GetValue(this, null);
+                    if (value is not null) functions[property.Name] = (Trackable)value;
+                }
+            }
+
+            // TODO: process the type `core_types.GenericFunction`.
+
+            Dictionary children = new();
+            foreach(var pair in CheckpointDependencies)
+            {
+                var name = pair.Name;
+                var child = pair.Refer;
+                if(child is ConcreteFunction) // or Generic function
+                {
+                    continue;
+                }
+                if(functions.ContainsKey(name) && functions[name] != child)
+                {
+                    throw new ValueError($"Can't save object because it has multiple children with the same " +
+                        $"name. Object: {this}, attribute name: {name}, child 1: " +
+                        $"{child}, child 2: {functions[name]}");
+                }
+                children[name] = child;
+            }
+            // A name can sit in both maps when it refers to the same object; merge with the function entry winning instead of throwing on the duplicate key.
+            return children.Concat(functions).GroupBy(x => x.Key).ToDictionary(g => g.Key, g => g.Last().Value);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Training/IWithTrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs
new file mode 100644
index 000000000..87eda8795
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs
@@ -0,0 +1,12 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Train;
+
+namespace Tensorflow.Training
+{
+    public interface IWithTrackable
+    {
+        Trackable GetTrackable();
+    }
+}
diff --git a/src/TensorFlowNET.Core/Training/LayerUtils.cs b/src/TensorFlowNET.Core/Training/LayerUtils.cs
new file mode 100644
index 000000000..211419651
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/LayerUtils.cs
@@ -0,0 +1,9 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Train;
+
+namespace Tensorflow.Training
+{
+
+}
diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs
index f985c6566..e656fe96d 100644
--- a/src/TensorFlowNET.Core/Training/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -351,7 +351,7 @@ public virtual void _prepare() /// /// /// - protected IVariableV1 get_slot(IVariableV1 var, string name) + internal IVariableV1 get_slot(IVariableV1 var, string name) { var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; if (named_slots == null) @@ -360,6 +360,11 @@ protected IVariableV1 get_slot(IVariableV1 var, string name) return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; } + internal IEnumerable get_slot_names() + { + return _slots.Keys; + } + private string _var_key(IVariableV1 var) { return $"{var.Op.graph.graph_key}.{var.Op.name}"; diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 167c635a8..2d23a325f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -14,6 +14,8 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public class ResourceVariableSaveable : MySaveableObject @@ -35,6 +37,32 @@ public ResourceVariableSaveable(Tensor var, string slice_spec, string name) this.name = name; } + public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + + Tensor _read_variable_closure(BaseResourceVariable v) + { + tf.device(v.Device); + if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + tf.device("/device:CPU:0"); + return array_ops.identity(x); + } + + this.handle_op = var.Handle; + var tensor = _read_variable_closure(var); + + var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + _op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs index 1ae912ce6..393a6a981 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -28,7 +28,7 @@ public class SaveSpec public string slice_spec => _slice_spec; private string _name; - public string name => _name; + public string name { get => _name; set => _name = value; } private TF_DataType _dtype; public TF_DataType dtype => _dtype; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index c86075f86..1309a6174 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -14,11 +14,31 @@ You may obtain a copy of the License at limitations under 
the License. ******************************************************************************/ +using Tensorflow.Checkpoint; + namespace Tensorflow { public class MySaveableObject { - public Tensor op; + protected Maybe _op; + public Tensor op + { + get + { + if(_op.TryGet(out var tensor)) + { + return tensor; + } + else + { + throw new TypeError("The _op is not a tensor."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; @@ -35,7 +55,7 @@ public MySaveableObject(Tensor var, string slice_spec, string name) public MySaveableObject(Tensor op, SaveSpec[] specs, string name) { - this.op = op; + this._op = op; this.specs = specs; this.name = name; } @@ -48,4 +68,18 @@ public virtual Operation restore(Tensor[] restored_tensors, Shape[] restored_sha validate_shape: restored_shapes == null && op.shape.IsFullyDefined); } } + + public class NoRestoreSaveable: MySaveableObject + { + public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? 
device = null) : base(tensor, + new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) + { + + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + return control_flow_ops.no_op(); + } + } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs new file mode 100644 index 000000000..d10257822 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Tensorflow; + +public record class AssetInfo +( + List asset_defs, + Dictionary asset_initializers_by_resource, + Dictionary asset_filename_map, + Dictionary asset_index +); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs new file mode 100644 index 000000000..a91933357 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -0,0 +1,133 @@ +using System; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow; + +public class AugmentedGraphView: ObjectGraphView +{ + private Dictionary> _children_cache; + private Dictionary> _serialization_cache; + private List _untraces_functions; + private Dictionary _wrapped_functions; + public AugmentedGraphView(Trackable root): base(root) + { + _children_cache= new Dictionary>(); + _serialization_cache = new Dictionary>(); + _untraces_functions = new List(); + _wrapped_functions = new Dictionary(); + } + + public void set_signature(SignatureMap signature_map, IDictionary wrapped_functions) + { + list_children(Root); + var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; + if (!_children_cache.ContainsKey(Root)) + { + _children_cache[Root] = new 
Dictionary(); + } + _children_cache[Root][name] = signature_map; + _wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); + } + + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary>? serialization_cache = null) + { + if(serialization_cache is not null) + { + throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); + } + + if (!_children_cache.ContainsKey(obj)) + { + Dictionary children = new Dictionary(); + _children_cache[obj] = children; + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) + { + child = maybe_uncache_variable_captures((ConcreteFunction)child); + } + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + } + + List res = new(); + foreach(var pair in _children_cache[obj]) + { + res.Add(new TrackableReference(pair.Key, pair.Value)); + } + + return res; + } + + private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) + { + if (_wrapped_functions.ContainsKey(concrete_function)) + { + return _wrapped_functions[concrete_function]; + } + // skip the process here because of lack of feature. + // In the future, we may add an attribute which could specify if the variable is supposed to be cached. + //foreach(var capture in concrete_function.CapturedInputs) + //{ + + //} + return concrete_function; + } + + public override (IList, IDictionary>) breadth_first_traversal() + { + Trackable get_merged_trackable(Trackable x) + { + // TODO: complete it with new definitions `Asset` and `TrackableConstant`. 
+            return x;
+        }
+        var trackable_objects = base.breadth_first_traversal();
+        // Enumerate a snapshot of each child map: writing through the indexer while enumerating the live dictionary can invalidate the enumerator.
+        foreach(var obj in _children_cache.Keys)
+        {
+            // skip the deletion of cache (maybe do it later).
+            foreach(var pair in _children_cache[obj].ToList())
+            {
+                _children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
+            }
+        }
+
+        return base.breadth_first_traversal();
+    }
+
+    public List<(string, Trackable)> list_dependencies(Trackable obj)
+    {
+        IDictionary children;
+        if (!_children_cache.ContainsKey(obj))
+        {
+            children= new Dictionary();
+        }
+        else
+        {
+            children= _children_cache[obj];
+        }
+        List<(string, Trackable)> res = new();
+        foreach(var pair in obj.deserialization_dependencies(children))
+        {
+            res.Add((pair.Key, pair.Value));
+        }
+        return res;
+    }
+
+    public Trackable get_child(Trackable obj, string name)
+    {
+        return _children_cache[obj][name];
+    }
+}
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs
new file mode 100644
index 000000000..726f6cfd4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs
@@ -0,0 +1,33 @@
+namespace Tensorflow;
+
+public static class Constants
+{
+    public static readonly string ASSETS_DIRECTORY = "assets";
+    public static readonly string ASSETS_KEY = "saved_model_assets";
+
+    public static readonly string DEBUG_DIRECTORY = "debug";
+
+    public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb";
+
+    public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra";
+
+    public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb";
+
+    public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";
+
+    public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op";
+
+    public static readonly string MAIN_OP_KEY = "saved_model_main_op";
+
+    public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb";
+    public static readonly string SAVED_MODEL_FILENAME_PBTXT = 
"saved_model.pbtxt"; + + public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; + + public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; + + public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; + + public static readonly string VARIABLES_DIRECTORY = "variables"; + public static readonly string VARIABLES_FILENAME = "variables"; +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs new file mode 100644 index 000000000..fe0403c30 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -0,0 +1,17 @@ +using Tensorflow.Train; + +namespace Tensorflow; + +public class RevivedTypes +{ + /// + /// Create a SavedUserObject from a trackable object. + /// + /// + /// + public static SavedUserObject? serialize(Trackable obj) + { + // TODO: complete the implementation. + return null; + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs new file mode 100644 index 000000000..8dd4f008f --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tensorflow; + +public enum SaveType +{ + SAVEDMODEL, + CHECKPOINT +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs new file mode 100644 index 000000000..1be54287e --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; 
+using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow; + +public class SaveableView +{ + private AugmentedGraphView _augmented_graph_view; + private SaveOptions _options; + private IList _trackable_objects; + private List _nodes; + private IDictionary> _node_paths; + private IDictionary _node_ids; + private IDictionary> + _slot_variables; + private IDictionary _object_names; + private List _gradient_functions; // to be completed + private List _gradient_defs; // to be completed + private List _concrete_functions; + private Dictionary _captured_tensor_node_ids; + private Dictionary> _saveable_objects_map; + private Dictionary _obj_to_registered_saver; + + public AugmentedGraphView AugmentedGraphView + { + get => _augmented_graph_view; + } + + public Trackable Root + { + get => _nodes[0]; + } + public List Nodes + { + get => _nodes; + } + public IDictionary NodeIds + { + get => _node_ids; + } + public List GradientDefs + { + get => _gradient_defs; + } + public IDictionary> NodePaths + { + get => _node_paths; + } + public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) + { + _augmented_graph_view = augmented_graph_view; + _options = options; + + (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = + CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); + + // TODO: deal with untraced functions. + + initialize_save_and_restore_functions(); + initialize_nodes_and_concrete_functions(); + + _captured_tensor_node_ids = new(); + } + + private void initialize_save_and_restore_functions() + { + // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. + var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. 
+        _obj_to_registered_saver = new();
+        _saveable_objects_map = new();
+    }
+
+    /// <summary>
+    /// Copies the traversal result into `_nodes`, resets the gradient lists and collects the ConcreteFunction nodes.
+    /// </summary>
+    private void initialize_nodes_and_concrete_functions()
+    {
+        _nodes = _trackable_objects.ToList().ConvertAll(x => x); // new list, same element references
+        _gradient_functions = new();
+        _gradient_defs = new();
+        _concrete_functions = new(); // must exist before the loop below appends to it
+
+        // TODO: deal with the condition that obj in `_saveable_objects_map`.
+
+        foreach (var obj in _nodes)
+        {
+            if (obj is ConcreteFunction)
+            {
+                _concrete_functions.Add((ConcreteFunction)obj);
+            }
+        }
+    }
+
+    public List get_concrete_resource_initializers()
+    {
+        // TODO: complete the implementation.
+        return new List();
+    }
+
+    public (Dictionary, Dictionary, AssetInfo) map_resources()
+    {
+        Debug.Assert(!tf.Context.executing_eagerly());
+
+        Dictionary object_map = new();
+        Dictionary tensor_map = new();
+
+        AssetInfo assetInfo = new(new List(), new Dictionary(),
+            new Dictionary(), new Dictionary());
+
+        foreach (var node_id in dependency_sorted_node_ids())
+        {
+            var obj = _nodes[node_id];
+            var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options);
+            // TODO: deal with Asset (if obj is Asset)
+            foreach (var tensor in tensors)
+            {
+                _captured_tensor_node_ids[tensor] = node_id;
+            }
+        }
+
+        return (object_map, tensor_map, assetInfo);
+    }
+
+    ///
+    /// Returns topologically sorted nodes, sorted by dependencies.
+    ///
+    public List dependency_sorted_node_ids()
+    {
+        Dictionary> dependency_map = new();
+        foreach (var node in _nodes)
+        {
+            var node_id = _node_ids[node];
+            List deps = new List();
+            dependency_map.Add(node_id, deps);
+
+            // TODO: deal with captured tensor.
+
+            foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
+            {
+                if (!_node_ids.ContainsKey(dep))
+                {
+                    var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
+                    throw new ValueError(
+                        $"Found an untracked dependency. Object {node_path} depends on {dep}, " +
+                        $"but this dependency isn't listed as a child. 
Please track this child by " + + $"overriding `_trackable_children` or use `._track_trackable`."); + } + deps.Add(_node_ids[dep]); + } + } + + try + { + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError err) + { + List pretty_printed_nodes = new(); + List pretty_printed_dependencies = new(); + + foreach (var pair in err.LeftOverDependencyMap) + { + var x = pair.Key; + var deps = pair.Value; + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); + pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); + pretty_printed_dependencies.Add( + $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); + } + + throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + + $"Saving cannot continue until this cycle is resolved." + + $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + + $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); + } + } + + /// + /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph + /// + /// + /// + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index) + { + SavedObjectGraph proto = new(); + fill_object_graph_proto(proto); + + // TODO: complete the process of concrete functions. 
+ + int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + var obj = _nodes[i]; + var obj_proto = proto.Nodes[i]; + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); + } + + return proto; + } + + private static void write_object_proto(Trackable obj, SavedObject proto, + IDictionary asset_file_def_index, Func> list_children_fn) + { + // skip the process of type Asset + if (resource_variable_ops.is_resource_variable(obj)) + { + var options = SaveContext.get_save_options(); + (obj as BaseResourceVariable).write_object_proto(proto, options); + } + else if (obj is Function) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is ConcreteFunction) + { + // TODO: complete it. + throw new NotImplementedException(); + } + // skip the process of type `_CapturedTensor` and `CapturableResource`. + else + { + var registered_type_proto = RevivedTypes.serialize(obj); + if (registered_type_proto is null) + { + registered_type_proto = new SavedUserObject() + { + Identifier = obj.ObjectIdentifier, + Version = new VersionDef() + { + Producer = 1, + MinConsumer = 1, + BadConsumers = { } + } + }; + } + + proto.UserObject = new SavedUserObject(registered_type_proto); + } + + // TODO: try get the registered_name from `registration`. 
+ } + + public void fill_object_graph_proto(SavedObjectGraph proto) + { + for (int node_id = 0; node_id < _nodes.Count; node_id++) + { + var node = _nodes[node_id]; + Debug.Assert(_node_ids[node] == node_id); + SavedObject object_proto = new(); + if (_slot_variables.TryGetValue(node, out var value)) + { + object_proto.SlotVariables.AddRange(value); + } + // skip the check of type `_CapturedTensor` + foreach (var child in _augmented_graph_view.list_children(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[child.Refer]; + child_proto.LocalName = child.Name; + object_proto.Children.Add(child_proto); + } + + foreach (var pair in _augmented_graph_view.list_dependencies(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[pair.Item2]; + child_proto.LocalName = pair.Item1; + object_proto.Dependencies.Add(child_proto); + } + + if (_saveable_objects_map.ContainsKey(node)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if(_obj_to_registered_saver.ContainsKey(node)) + { + // TODO: complete it. + // We now skip it for the lack of `SavedObject.registered_saver` API. 
+ throw new NotImplementedException(); + } + + proto.Nodes.Add(object_proto); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs new file mode 100644 index 000000000..6aa1fbde1 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -0,0 +1,10 @@ +namespace Tensorflow; + +public static class TagConstants +{ + public static readonly string SERVING = "serve"; + public static readonly string TRAINING = "train"; + public static readonly string EVAL = "eval"; + public static readonly string GPU = "gpu"; + public static readonly string TPU = "tpu"; +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs new file mode 100644 index 000000000..dbbab91d8 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BuilderUtils +{ + public static void copy_assets_to_destination_dir(IDictionary asset_filename_map, + string destination_dir, HashSet? saved_files = null) + { + if (saved_files is null) saved_files = new HashSet(); + + var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); + + // TODO: complete the implementation of this function. 
+ if (asset_filename_map is not null && asset_filename_map.Count > 0) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs new file mode 100644 index 000000000..94760e3df --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -0,0 +1,269 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.Protobuf; +using Tensorflow.Checkpoint; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + private static readonly IEnumerable byte_swappable = new List() + { + dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, + dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, + dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 + }.Select(x => (int)x); + + public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, + string export_dir, ConcreteFunction? signatures, SaveOptions? 
options = null, bool experimental_skip_checkpoint = false) + { + if (options is null) + { + options = new SaveOptions(); + } + + var saved_model = new Tensorflow.SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + + var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = + _build_meta_graph(obj, signatures, options, meta_graph_def); + saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; + + if (!experimental_skip_checkpoint) + { + SavedModelUtils.get_or_create_variables_dir(export_dir); + CheckpointOptions ckpt_options = new(options.experimental_io_device); + object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options); + } + BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); + + if (tf.Context.executing_eagerly()) + { + // tensorflow python has a check of `context.async_wait()` here. + } + + // TODO: deal with `pywrap_saved_model.Save(export_dir)`. + + var saved_model_serialized = saved_model.ToString(); + + // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. + if (true) + { + var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), + tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); + // TODO: add c api and complete the fingerprint def. + var fingerprint_proto = ""; + File.WriteAllText(fingerprint_path, fingerprint_proto); + } + + var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); + File.WriteAllBytes(path, saved_model.ToByteArray()); + //File.WriteAllText(path, saved_model.ToString()); + + if (options.save_debug_info) + { + throw new NotImplementedException(); + } + + ops.dismantle_graph(exported_graph); + + return (saved_nodes, node_paths); + } + + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList, + IDictionary>) _build_meta_graph(Trackable obj, + ConcreteFunction? 
signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) + { + using (SaveContext.save_context(options)) + { + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } + + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } + + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is null) + { + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); + } + + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) + { + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. 
+ throw new NotImplementedException(); + } + } + + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } + } + + private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, + ConcreteFunction signatures, IEnumerable namespace_whitelist, + bool save_custom_gradients) + { + var resource_initializers = saveable_view.get_concrete_resource_initializers(); + var exported_graph = new Graph(); + + Dictionary object_map; + Dictionary tensor_map; + AssetInfo asset_info; + var g = exported_graph.as_default(); + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. + if (save_custom_gradients) + { + // TODO: trace gradient functions. + } + + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers + } + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + + g.Exit(); + + foreach (var obj in object_map.Values) + { + obj._maybe_initialize_trackable(); + } + + // TODO: add the implementation of `call_with_mapped_functions`. 
+ var (named_saveable_objects, registered_savers) = + SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); + var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); + + var eg = exported_graph.as_default(); + var saver_def = saver.to_proto(); + meta_graph_def.SaverDef = saver_def; + eg.Exit(); + + + saveable_view.dependency_sorted_node_ids(); + + var graph_def = exported_graph.as_graph_def(true); + graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); + verify_ops(graph_def, namespace_whitelist); + + meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef = new(); + meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); + meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; + // TODO: add git version. + meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList = new(); + meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); + meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); + + // TODO: deal with signatures here. + + meta_graph.strip_graph_default_valued_attrs(meta_graph_def); + + if (!BitConverter.IsLittleEndian) + { + swap_function_tensor_content(meta_graph_def); + } + + return (asset_info, exported_graph); + } + + private static void verify_ops(GraphDef graph_def, IEnumerable? namespace_whitelist) + { + return; + // if (namespace_whitelist is null || !namespace_whitelist.Any()) + // { + // return; + // } + + // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. 
+ } + + public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) + { + var functions = meta_graph_def.GraphDef.Library.Function; + foreach (var function in functions) + { + var node_def = function.NodeDef; + foreach (var node in node_def) + { + if (node.Op == "Const") + { + var tensor = node.Attr["value"].Tensor; + byte_swap_tensor_content(tensor); + } + } + } + } + + public static void byte_swap_tensor_content(TensorProto tensor) + { + if (byte_swappable.Contains((int)tensor.Dtype)) + { + var tshape = tensor.TensorShape.Dim; + var tensor_bytes = tensor.TensorContent; + if (tensor_bytes is not null && !tensor_bytes.IsEmpty) + { + long tensor_size = 1; + foreach (var sz in tshape) + { + tensor_size *= sz.Size; + } + + var chunksize = tensor_bytes.Length / tensor_size; + List reversed_bytes = new(); + for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) + { + var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); + reversed_bytes.AddRange(current); + } + tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs new file mode 100644 index 000000000..4cfe0b69b --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.ModelSaving; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A context for building a graph of SavedModel. + /// + public static class SaveContext + { + // TODO: make it thead safe. 
+ private static bool _in_save_context = false; + private static SaveOptions _save_options = null; + + public static bool in_save_context() => _in_save_context; + public static SaveOptions get_save_options() + { + if (!in_save_context()) + { + throw new ValueError("Not in a SaveContext."); + } + return _save_options; + } + public static SaveContextHandler save_context(SaveOptions options) + { + return new SaveContextHandler(options); + } + + public class SaveContextHandler: IDisposable + { + private bool _old_in_save_context; + private SaveOptions _old_save_options; + public SaveContextHandler(SaveOptions options) + { + if (SaveContext.in_save_context()) + { + throw new ValueError("Already in a SaveContext."); + } + _old_in_save_context = SaveContext._in_save_context; + SaveContext._in_save_context = true; + _old_save_options = SaveContext._save_options; + SaveContext._save_options = options; + } + public void Dispose() + { + SaveContext._in_save_context = _old_in_save_context; + SaveContext._save_options = _old_save_options; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs new file mode 100644 index 000000000..4a0d3b002 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow; + +public static class SignatureSerializationUtils +{ + internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; + internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; + internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new 
SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + Debug.Assert(func is ConcreteFunction); + // TODO: assert the `func.structured_outputs` and arg_keywords. + signature_map._add_signature(name, (ConcreteFunction)func); + } + + return signature_map; + } + + public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) + { + var children = graph_view.list_children(graph_view.Root); + List possible_signatures = new(); + foreach (var item in children) + { + var name = item.Name; + var child = item.Refer; + if(child is not (Function or ConcreteFunction)) + { + continue; + } + if(name == DEFAULT_SIGNATURE_ATTR) + { + Debug.Assert(child is ConcreteFunction); + return (ConcreteFunction)child; + } + ConcreteFunction concrete = get_signature(child); + if(concrete is not null && valid_signature(concrete)) + { + possible_signatures.Add(concrete); + } + } + + if(possible_signatures.Count == 1) + { + var signature = get_signature(possible_signatures[0]); + if(signature is not null && valid_signature(signature)) + { + return signature; + } + } + return null; + } + + private static ConcreteFunction get_signature(Trackable function) + { + // TODO: implement it. + return null; + } + + private static bool valid_signature(ConcreteFunction concreate_function) + { + // TODO: implement it. + return false; + } +} + +public class SignatureMap: Trackable +{ + private Dictionary _signatures; + + public SignatureMap() + { + _signatures = new(); + } + + public void _add_signature(string name, ConcreteFunction concrete_function) + { + _signatures[name] = concrete_function; + } + + public void _add_signature(string name, Function concrete_function) + { + _signatures[name] = concrete_function; + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? 
cache = null) + { + if (save_type != SaveType.SAVEDMODEL) + { + return new Dictionary(); + } + + return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs new file mode 100644 index 000000000..b0e6411c9 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -0,0 +1,57 @@ +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + /// + /// Return variables sub-directory, or create one if it doesn't exist. + /// + /// + public static string get_or_create_variables_dir(string export_dir) + { + var variables_dir = get_variables_dir(export_dir); + Directory.CreateDirectory(variables_dir); + return variables_dir; + } + + /// + /// Return variables sub-directory in the SavedModel. + /// + /// + /// + public static string get_variables_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); + } + + public static string get_variables_path(string export_dir) + { + return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); + } + + /// + /// Return assets sub-directory, or create one if it doesn't exist. + /// + /// + /// + public static string get_or_create_assets_dir(string export_dir) + { + var assets_destination_dir = get_assets_dir(export_dir); + Directory.CreateDirectory(assets_destination_dir); + return assets_destination_dir; + } + + /// + /// Return path to asset directory in the SavedModel. 
+ /// + /// + /// + public static string get_assets_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 3a6647880..a6e21e3e5 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -16,12 +16,38 @@ limitations under the License. using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Operations.Activation; +using Tensorflow.Train; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow { - public class saveable_object_util + /// + /// A SaveableObject that defines `Trackable` checkpointing steps. + /// + public class TrackableSaveable : MySaveableObject + { + private string _prefix; + private IEnumerable _local_names; + private Trackable _trackable; + private bool _call_with_mapped_captures; + // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. + public TrackableSaveable(Trackable obj, IEnumerable specs, string name, IEnumerable local_names, + string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) + { + _prefix = prefix; + _trackable = obj; + _local_names = local_names; + _call_with_mapped_captures = call_with_mapped_captures; + } + + // TODO: complete this class. + } + public static class saveable_object_util { /// /// Returns the variables and names that will be used for a Saver. @@ -52,7 +78,7 @@ private static void _add_saveable(List saveables, List seen_ops, T } /// - /// Create `SaveableObject`s from an operation. + /// Create `SaveableObject`s from an operation. 
Note that the `op` should not be implicitly converted from `Variable`. /// /// /// @@ -74,6 +100,72 @@ public static IEnumerable saveable_objects_for_op(Tensor op, s } } + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Trackable obj, string name) + { + // The `op` maybe `Variable` or `Trackable`. + if (obj is BaseResourceVariable) + { + var variable = obj as BaseResourceVariable; + if (variable.InGraphMode) + { + yield return new ResourceVariableSaveable(variable.GraphElement, "", name); + } + else + { + yield return new ResourceVariableSaveable(variable, "", name); + } + } + else + { + foreach(var pair in saveable_objects_from_trackable(obj)) + { + var attr = pair.Key; + var factory = pair.Value; + string full_name; + if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) + { + full_name = name; + } + else + { + full_name = name + "_" + attr; + } + if(factory.TryGet(out var variable)) + { + foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + { + yield return op; + } + } + else + { + var saveable = factory.GetValue(); + foreach (var op in saveable_objects_for_op(saveable, saveable.name)) + { + yield return op; + } + } + } + } + } + + /// + /// Create `SaveableObject`s from an operation. 
+ /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(MySaveableObject obj, string name) + { + yield return obj; + } + public static Dictionary op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) { op_list = op_list.OrderBy(x => x.Name).ToArray(); @@ -121,5 +213,164 @@ public static Dictionary op_list_to_dict(IVariableV1[] op_list, return names_to_saveables; } + + public static IDictionary> saveable_objects_from_trackable(Trackable obj) + { + // skip the process of type `PythonState` + + if (trackable_has_serialize_to_tensor(obj)) + { + var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; + // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. + var tensor_dict = obj.serialize_to_tensors(); + + List specs = new(); + List local_names = new(); + string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; + foreach(var pair in tensor_dict) + { + var tensor_name = pair.Key; + var maybe_tensor = pair.Value; + local_names.Add(tensor_name); + string spec_name = name + TrackableUtils.escape_local_name(tensor_name); + + IDictionary internal_dict; + if(maybe_tensor.TryGet(out var tensor)) + { + internal_dict= new Dictionary(); + internal_dict[""] = tensor; + } + else + { + internal_dict = maybe_tensor.GetValue>(); + } + + foreach(var item in internal_dict) + { + specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); + } + } + Dictionary> res = new(); + res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return res; + } + else + { + return obj.gather_saveables_for_checkpoint(); + } + } + + public static bool trackable_has_serialize_to_tensor(Trackable obj) + { + return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); + } + + internal static string convert_to_string(string x) + { + return tf.compat.as_str(x); + } + + /// + /// Converts a list of SaveableObjects to a tensor dictionary. 
+ /// + /// + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) + { + Dictionary>> tensor_dict = new(); + foreach (var saveable in saveables) + { + foreach (var spec in saveable.specs) + { + // skip the check that if `spec` is callable. + var name = convert_to_string(spec.name); + var slice_spec = convert_to_string(spec.slice_spec); + if (!string.IsNullOrEmpty(slice_spec)) + { + tensor_dict.SetDefault(name, new Dictionary()).GetValue>()[slice_spec] = spec.tensor; + } + else + { + tensor_dict[name] = spec.tensor; + } + } + } + return tensor_dict; + } + + /// + /// Generates `Trackable._restore_from_tensors` from SaveableObjects. + /// + /// + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + { + return (restored_tensors) => + { + Dictionary restored_ops = new(); + + foreach(var saveable in saveables) + { + List saveable_restored_tensors = new(); + foreach(var spec in saveable.specs) + { + var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + + var maybe_tensor = restored_tensors[name]; + IDictionary dict; + if(maybe_tensor.TryGet(out var tensor)) + { + dict = new Dictionary(); + dict[""] = tensor; + } + else + { + dict = maybe_tensor.GetValue>(); + } + saveable_restored_tensors.Add(dict[slice_spec]); + } + restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); + } + return restored_ops; + }; + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private object _obj; + private IList _saveables; + public SaveableCompatibilityConverter(object obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public object Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary>> serialize_to_tensors() + { + return saveable_object_util.saveable_object_to_tensor_dict(_saveables); + } + + /// + /// Returns 
the restore ops defined in the Saveables. + /// + /// + /// + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + List expected_keys = new(); + foreach(var saveable in _saveables) + { + expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); + } + if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) + { + throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + + $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); + } + return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 79d6dca92..132571f2a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -14,13 +14,63 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.ModelSaving; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class Trackable + public abstract class Trackable: IWithTrackable { + /// + /// Corresponding to tensorflow/python/trackable/constants.py + /// + public static class Constants + { + public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; + public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; + public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; + } protected int _self_update_uid; + protected IDictionary _unconditional_dependency_names; + + protected IList _unconditional_checkpoint_dependencies; + + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); + private bool _manual_tracking = true; + + private static Trackable _none = new AutoTrackable(); + /// + /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. + /// The `None` can be any object that inherits `Trackable`. + /// This Property is supposed to be used only internal. 
+ /// + public static Trackable None + { + get + { + return _none; + } + } + public Trackable GetTrackable() + { + return this; + } + public virtual string ObjectIdentifier + { + get => "_generic_user_object"; + } + public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; } + public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } + public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } + public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } /// /// Restore-on-create for a variable be saved with this `Checkpointable`. @@ -47,9 +97,13 @@ protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args // assign again. It will add this variable to our dependencies, and if there // is a non-trivial restoration queued, it will handle that. This also // handles slot variables. - if (!args.Overwrite || new_variable is RefVariable) - return _track_checkpointable(new_variable, name: args.Name, - overwrite: args.Overwrite); + if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) + { + var temp = new_variable as Trackable; + var res = _track_trackable(temp, args.Name, args.Overwrite); + Debug.Assert(res is IVariableV1); + return res as IVariableV1; + } else return new_variable; } @@ -73,10 +127,136 @@ protected IVariableV1 _track_checkpointable(IVariableV1 checkpointable, string n /// /// Initialize dependency management. /// - protected void _maybe_initialize_trackable() + public void _maybe_initialize_trackable() { - // _self_unconditional_checkpoint_dependencies = [] + if(_unconditional_checkpoint_dependencies is not null) + { + return; + } _self_update_uid = -1; + _unconditional_checkpoint_dependencies = new List(); + _unconditional_dependency_names = new Dictionary(); + } + + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? 
/// <summary>
/// Declares a named checkpoint dependency of this object on <paramref name="trackable"/>.
/// </summary>
/// <param name="trackable">The object to depend on.</param>
/// <param name="name">Local name of the dependency; used as the key for later lookups.</param>
/// <param name="overwrite">
/// Accepted for API compatibility. This implementation does not consult it: a dependency that
/// already exists under <paramref name="name"/> is always replaced.
/// </param>
/// <returns>The <paramref name="trackable"/> argument.</returns>
public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false)
{
    _maybe_initialize_trackable();
    if (!_manual_tracking)
    {
        return trackable;
    }

    var new_reference = new TrackableReference(name, trackable);
    var current_object = _lookup_dependency(name);

    if (current_object is null)
    {
        _unconditional_checkpoint_dependencies.Add(new_reference);
        _handle_deferred_dependencies(name, trackable);
    }
    else if (!ReferenceEquals(current_object, trackable))
    {
        // The name is being re-bound to a different object. Replace the stored reference so
        // the dependency list stays in sync with the name map updated below; otherwise
        // `_trackable_children` would keep returning the old object.
        for (int i = 0; i < _unconditional_checkpoint_dependencies.Count; i++)
        {
            if (_unconditional_checkpoint_dependencies[i].Name == name)
            {
                _unconditional_checkpoint_dependencies[i] = new_reference;
            }
        }
    }

    _unconditional_dependency_names[name] = trackable;
    return trackable;
}
/// <summary>
/// Returns <paramref name="obj"/> typed as a <see cref="Trackable"/>.
/// </summary>
/// <param name="obj">The object to convert.</param>
/// <param name="parent">Not used by the current implementation.</param>
/// <returns><paramref name="obj"/> itself when it is a <see cref="Trackable"/>.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="obj"/> is not a <see cref="Trackable"/> (including <c>null</c>).
/// </exception>
public static Trackable convert_to_trackable(object obj, object? parent = null)
{
    if (obj is Trackable already_trackable)
    {
        return already_trackable;
    }

    throw new NotImplementedException();
}
+ /// + /// + /// + public virtual IDictionary>> serialize_to_tensors() + { + throw new NotImplementedException(); + } + + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + throw new NotImplementedException(); } } + + public record class TrackableReference(string Name, Trackable Refer); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs new file mode 100644 index 000000000..390d95c75 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -0,0 +1,172 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Exceptions; +using Tensorflow.Train; + +namespace Tensorflow.Training; + +public static class TrackableUtils +{ + public class CyclicDependencyError: System.Exception + { + public IDictionary> LeftOverDependencyMap { get; } + public CyclicDependencyError(IDictionary> leftover_dependency_map): base() + { + LeftOverDependencyMap = leftover_dependency_map; + } + public CyclicDependencyError(IDictionary> leftover_dependency_map): base() + { + LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); + } + } + private static string _ESCAPE_CHAR = "."; + private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + public static string object_path_to_string(IEnumerable node_path_arr) + { + return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); + } + + public static string escape_local_name(string name) + { + return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); + } + + public static string checkpoint_key(string object_path, string local_name) + { + var key_suffix = escape_local_name(local_name); + if (local_name == 
SERIALIZE_TO_TENSORS_NAME) + { + key_suffix = ""; + } + + return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; + } + + /// + /// Topologically sorts the keys of a map so that dependencies appear first. + /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + /// + /// + /// + public static List order_by_dependency(IDictionary> dependency_map) + { + Dictionary> reverse_dependency_map = new(); + foreach (var pair in dependency_map) + { + foreach (var dep in pair.Value) + { + if (reverse_dependency_map.ContainsKey(dep)) + { + reverse_dependency_map[dep].Add(pair.Key); + } + else + { + reverse_dependency_map[dep] = new HashSet(); + reverse_dependency_map[dep].Add(pair.Key); + } + } + } + + // Validate that all values in the dependency map are also keys. + var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); + if (unknown_keys.Count() > 0) + { + throw new ValueError( + $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); + } + + // Generate the list sorted by objects without dependencies -> dependencies. + // The returned list will reverse this. 
/// <summary>
/// Returns the substring after the "/.ATTRIBUTES/" marker (plus the optional prefix) in a checkpoint key.
/// </summary>
/// <param name="key">The checkpoint key to inspect.</param>
/// <param name="prefix">Optional text expected directly after the marker; <c>null</c> is treated as empty.</param>
/// <returns>
/// The text following the marker and prefix, or <paramref name="key"/> unchanged when the
/// marker is not present.
/// </returns>
public static string extract_local_name(string key, string? prefix = null)
{
    var search_key = OBJECT_ATTRIBUTES_NAME + "/" + (prefix ?? "");

    // Checkpoint keys are identifiers, so match ordinally rather than by culture.
    var marker_start = key.IndexOf(search_key, StringComparison.Ordinal);
    if (marker_start < 0)
    {
        // IndexOf reports "not found" with -1 instead of throwing, so this case must be
        // handled explicitly; slicing at (-1 + length) would return an unrelated substring.
        return key;
    }

    return key.Substring(marker_start + search_key.Length);
}
/// <summary>
/// Variables of this structure that are not reported as trainable.
/// </summary>
/// <remarks>
/// When the structure itself is marked non-trainable, the result also contains the variables
/// that would otherwise be trainable, ordered as: nested trainable, own trainable extras,
/// nested non-trainable, own non-trainable extras.
/// </remarks>
public IEnumerable<IVariableV1> NonTrainableWeights
{
    get
    {
        // Partition the extra variables by their flag. A filter is required here: a
        // prefix-based cut would drop every variable after the first non-matching one.
        var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList();
        var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList();

        List<IVariableV1> non_trainable_variables = new();
        foreach (var obj in Values)
        {
            // skip the process of `module.Module`.
            if (obj is TrackableDataStructure nested)
            {
                non_trainable_variables.AddRange(nested.NonTrainableVariables);
            }
        }

        if (!_self_trainable)
        {
            // Return order is all trainable vars, then all non-trainable vars.
            List<IVariableV1> trainable_variables = new();
            foreach (var obj in Values)
            {
                // skip the process of `module.Module`.
                if (obj is TrackableDataStructure nested)
                {
                    trainable_variables.AddRange(nested.TrainableVariables);
                }
            }
            return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables);
        }
        else
        {
            return non_trainable_variables.concat(non_trainable_extra_variables);
        }
    }
}
/// <summary>
/// Add a dependency on <paramref name="value"/>.
/// </summary>
/// <param name="value">The object to track.</param>
/// <param name="name">Name under which the dependency is registered.</param>
/// <returns>The object returned by the attribute assignment for <paramref name="value"/>.</returns>
protected virtual Trackable _track_value(Trackable value, string name)
{
    var tracked = sticky_attribute_assignment(this, name, value);

    // Variables are additionally remembered so the weight properties can report them.
    if (tracked is IVariableV1)
    {
        _self_extra_variables.Add(tracked as IVariableV1);
    }

    // skip the left process (need to be done in the future).
    return tracked;
}
/// <summary>
/// Returns the children of this list wrapper that should be saved.
/// </summary>
/// <param name="save_type">Kind of save being performed.</param>
/// <param name="cache">Passed through unchanged to the base implementation.</param>
/// <returns>
/// The children reported by the base class. For <c>SaveType.SAVEDMODEL</c>, every list element
/// that is a <c>Function</c> or <c>ConcreteFunction</c> is added as well, keyed by its
/// position in the list.
/// </returns>
/// <exception cref="ValueError">
/// Thrown when the list had a non-append mutation or was modified outside the wrapper.
/// </exception>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{
    check_external_modification();
    if (_non_append_mutation_value)
    {
        throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced"
            + $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures."
            + $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.");
    }
    if (_external_modification_value)
    {
        throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified "
            + $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks "
            + $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.");
    }

    var children = base._trackable_children(save_type, cache);

    if (save_type == SaveType.SAVEDMODEL)
    {
        // Copy first so the dictionary handed back by the base class is never mutated.
        var merged = new Dictionary<string, Trackable>(children);

        // Visit every element: functions may appear anywhere in the list, and each one is
        // keyed by its real position. Indexer assignment lets a function replace an entry
        // that already uses the same key instead of throwing on the duplicate.
        int position = 0;
        foreach (var element in this)
        {
            if (element is Function or ConcreteFunction)
            {
                merged[position.ToString()] = element;
            }
            position++;
        }

        children = merged;
    }

    return children;
}
/// <summary>
/// Removes every element from the wrapped list.
/// </summary>
/// <remarks>
/// Follows the same bookkeeping as <c>Remove</c> and <c>RemoveAt</c>: external changes are
/// detected before the list is touched and the snapshot is refreshed afterwards, so clearing
/// through the wrapper is not later mistaken for a modification made outside it.
/// </remarks>
public void Clear()
{
    check_external_modification();
    if (has_mutation_or_trackable())
    {
        _non_append_mutation_value = true;
    }
    _storage.Clear();
    update_snapshot();
}
BaseResourceVariable : DisposableObject + public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; + public virtual string SharedName => _name; protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; @@ -19,9 +25,10 @@ public class BaseResourceVariable : DisposableObject public string UniqueId => _unique_id; protected bool _in_graph_mode; + internal bool InGraphMode => _in_graph_mode; protected bool _trainable; - public bool trainable => _trainable; + public bool Trainable => _trainable; protected Tensor _initial_value; @@ -46,6 +53,7 @@ public class BaseResourceVariable : DisposableObject public Graph Graph => handle.graph; public string Device => handle.Device; EagerResourceDeleter eager_resource_deleter; + public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; public BaseResourceVariable() { @@ -73,6 +81,11 @@ public void __init__(bool trainable = true, _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); } + else if(handle is null) + { + // TODO: fix this dangerous change. + _handle = IntPtr.Zero; + } else { _handle = handle.Handle == null ? 
IntPtr.Zero : handle.Handle.DangerousGetHandle(); @@ -165,7 +178,7 @@ IVariableV1 _lazy_read(Operation op, Tensor value) /// void variable_accessed(BaseResourceVariable variable) { - if (variable.trainable) + if (variable.Trainable) { foreach (var tape in tf.GetTapeSet()) tape.VariableAccessed(variable as ResourceVariable); @@ -243,5 +256,60 @@ public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = else return value(); } + + public override (IDictionary, IDictionary) map_resources(SaveOptions save_options) + { + BaseResourceVariable new_variable; + if (save_options.experimental_variable_policy.save_variable_devices()) + { + tf.device(this.Device); + Debug.Assert(this is ResourceVariable); + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + else + { + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + Dictionary obj_map = new(); + Dictionary resource_map = new(); + obj_map[this] = new_variable; + resource_map[this.handle] = new_variable.handle; + return (obj_map, resource_map); + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// ubclasses of ResourceVariables could choose to override this method to + /// customize extra information to provide when saving a SavedModel. + /// + /// + /// + public virtual void write_object_proto(SavedObject proto, SaveOptions options) + { + resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); + } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } + + public Tensor is_initialized(string name = null) + { + return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); + } + + public Tensor read_value_no_copy() + { + Tensor value = null; + tf_with(ops.name_scope("Read"), _ => + { + // TODO: `no_copy = true`. 
+ value = _read_variable_op(); + }); + return array_ops.identity(value); + } } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index f4f716c3c..3eb78153a 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -46,6 +46,7 @@ public interface IVariableV1 Graph Graph { get; } TF_DataType dtype { get; } Shape shape { get; } + bool Trainable { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign_sub(T delta, bool use_locking = false, string name = null, bool read_value = true); IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 67c12c427..7b08f3ea4 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -20,11 +20,12 @@ limitations under the License. using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; +using Tensorflow.Train; namespace Tensorflow { [Obsolete] - public partial class RefVariable : IVariableV1, IProtoBuf + public partial class RefVariable: Trackable, IVariableV1, IProtoBuf { protected string _name; public string UniqueId => _name; @@ -56,6 +57,7 @@ public partial class RefVariable : IVariableV1, IProtoBuf _variable.name; public Tensor eval() => _variable; + public bool Trainable => _trainable; public RefVariable(object initial_value = null, bool trainable = true, diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b31960c73..1645d7130 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -17,7 +17,9 @@ limitations under the License. 
/// <summary>
/// A variable with no initializer.
/// </summary>
public sealed class UninitializedVariable : BaseResourceVariable
{
    // TODO: complete the arg list.
    /// <summary>
    /// Creates the resource handle for a variable of the given shape and dtype without
    /// assigning an initial value.
    /// </summary>
    /// <param name="trainable">Forwarded to the base initializer.</param>
    /// <param name="caching_device">Not used by the current implementation.</param>
    /// <param name="name">Name used for the enclosing name scope.</param>
    /// <param name="dtype">Element type of the variable.</param>
    /// <param name="aggregation">Stored in <c>Aggregation</c>.</param>
    /// <param name="shape">Shape of the variable.</param>
    /// <param name="extra_handle_data">Forwarded to the handle creation helper.</param>
    public UninitializedVariable(
        bool trainable = true,
        string caching_device = "",
        string name = null,
        TF_DataType dtype = TF_DataType.DtInvalid,
        VariableAggregation aggregation = VariableAggregation.None,
        Shape shape = null,
        Tensor extra_handle_data = null)
    {
        // Record the requested aggregation, as ResourceVariable does, instead of dropping it.
        Aggregation = aggregation;

        string unique_id = "";
        string handle_name = "";

        // The handle is created inside the nested scopes but consumed after them, so it is
        // captured in a local declared out here. Declaring it inside the lambda would hide it
        // from the `base.__init__` call below, which would then receive the unassigned member.
        Tensor created_handle = null;

        tf_with(ops.init_scope(), (x) =>
        {
            _in_graph_mode = !tf.Context.executing_eagerly();
            tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name =>
            {
                handle_name = ops.name_from_scope_name(name);
                string? shared_name;
                if (_in_graph_mode)
                {
                    shared_name = handle_name;
                    unique_id = shared_name;
                }
                else
                {
                    unique_id = $"{handle_name}-{ops.uid()}";
                    shared_name = null;
                }
                created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
                    shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data);
                // skip the assignment of `handle._parent_trackable` because of lack of API.
                // skip the assignment of `handle._name` and `handle._unique_id` because of accessibility.

                if (_in_graph_mode)
                {
                    tf_with(ops.name_scope("Read"), _ =>
                    {
                        tf.device(created_handle.Device);
                        var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
                        // _maybe_set_handle_data(dtype, handle, value)
                        _graph_element = value;
                    });
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
                }
                else
                {
                    _graph_element = null;
                }
            });
        });
        _shape = shape;
        _dtype = dtype;
        base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
    }
}
/// <summary>
/// Looks up a registered activation by its name.
/// </summary>
/// <param name="name">Name the activation was registered under.</param>
/// <returns>The activation registered for <paramref name="name"/>.</returns>
/// <exception cref="Exception">Thrown when no activation is registered under <paramref name="name"/>.</exception>
public static Activation GetActivationByName(string name)
{
    if (_nameActivationMap.TryGetValue(name, out var activation))
    {
        return activation;
    }

    throw new Exception($"Activation {name} not found");
}
- /// - public Activation Linear = (features, name) => features; - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs deleted file mode 100644 index dfebfb297..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs +++ /dev/null @@ -1,10 +0,0 @@ -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Relu = (features, name) - => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs deleted file mode 100644 index ad900bdef..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Sigmoid = (features, name) - => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs b/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs deleted file mode 100644 index 02d86acea..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial class Activations - { - public Activation Softmax = (features, name) - => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs deleted file mode 100644 index 33dc5ba62..000000000 --- a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras -{ - public partial 
class Activations - { - public Activation Tanh = (features, name) - => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features)); - } -} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 23c40fbff..3aeb3200d 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine { public partial class Functional { - public ModelConfig get_config() + public override IKerasConfig get_config() { return get_network_config(); } @@ -25,7 +25,7 @@ ModelConfig get_network_config() { Name = name }; - + var node_conversion_map = new Dictionary(); foreach (var layer in _self_tracked_trackables) { @@ -42,23 +42,26 @@ ModelConfig get_network_config() } var layer_configs = new List(); - foreach (var layer in _self_tracked_trackables) + using (SharedObjectSavingScope.Enter()) { - var filtered_inbound_nodes = new List(); - foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + foreach (var layer in _self_tracked_trackables) { - var node_key = _make_node_key(layer.Name, original_node_index); - if (NetworkNodes.Contains(node_key) && !node.is_input) + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) { - var node_data = node.serialize(_make_node_key, node_conversion_map); - filtered_inbound_nodes.append(node_data); + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + filtered_inbound_nodes.append(node_data); + } } - } - var layer_config = generic_utils.serialize_keras_object(layer); - layer_config.Name = layer.Name; - layer_config.InboundNodes = filtered_inbound_nodes; - layer_configs.Add(layer_config); + var layer_config = 
generic_utils.serialize_layer_to_config(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } } config.Layers = layer_configs; diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 09a31b948..44eaef534 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -2,7 +2,9 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -20,6 +22,30 @@ public partial class Functional : Model Dictionary tensor_usage_count; + /// + /// Dictionary of layer dependencies to be included in the checkpoint. + /// + public IDictionary LayerCheckpointDependencies + { + get + { + int weight_layer_index = 0; + Dictionary dependencies = new(); + for(int i = 0; i < Layers.Count; i++) + { + var layer = Layers[i]; + var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); + if(weights.Count > 0) + { + dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; + weight_layer_index++; + } + dependencies[$"layer-{i}"] = layer; + } + return dependencies; + } + } + public Functional(Tensors inputs, Tensors outputs, string name = null) : base(new ModelArgs { @@ -44,6 +70,7 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs) this.inputs = inputs; this.outputs = outputs; built = true; + _buildInputShape = inputs.shape; if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); @@ -325,5 +352,28 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train return output_tensors; } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? 
cache = null) + { + return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) + .ToDictionary(x => x.Key, x => x.Value); + } + + protected override void _init_set_name(string name, bool zero_based = true) + { + if (string.IsNullOrEmpty(name)) + { + string class_name = GetType().Name; + if (this.GetType() == typeof(Functional)) + { + class_name = "Model"; + } + this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); + } + else + { + this.name = name; + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs new file mode 100644 index 000000000..fc405d872 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -0,0 +1,32 @@ +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Engine; + +public abstract partial class Layer +{ + public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + + public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + + public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? 
cache = null) + { + IDictionary children; + if (save_type == SaveType.SAVEDMODEL) + { + Debug.Assert(cache is not null); + children = TrackableSavedModelSaver.trackable_children(cache); + } + else + { + children = new Dictionary(); + } + + return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index ba40b1a22..31b37d681 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -49,6 +49,8 @@ public abstract partial class Layer : AutoTrackable, ILayer public bool Built => built; public bool Trainable => args.Trainable; public TF_DataType DType => args.DType; + public bool AutoCast => args.Autocast; + public IRegularizer ActivityRegularizer => args.ActivityRegularizer; /// /// A stateful layer is a layer whose updates are run during inference too, @@ -59,6 +61,7 @@ public abstract partial class Layer : AutoTrackable, ILayer /// Provides information about which inputs are compatible with the layer. 
/// protected InputSpec inputSpec; + public InputSpec InputSpec => inputSpec; bool dynamic = true; public bool SupportsMasking { get; set; } protected List _trainable_weights; @@ -77,6 +80,8 @@ public abstract partial class Layer : AutoTrackable, ILayer protected bool computePreviousMask; protected List updates; public Shape BatchInputShape => args.BatchInputShape; + protected TensorShapeConfig _buildInputShape = null; + public TensorShapeConfig BuildInputShape => _buildInputShape; List inboundNodes; public List InboundNodes => inboundNodes; @@ -86,9 +91,29 @@ public abstract partial class Layer : AutoTrackable, ILayer ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; - public Tensor[] input => inboundNodes[0].input_tensors; + public Tensor[] input + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].input_tensors; + } + return null; + } + } public Dictionary> NodesByDepth { get; set; } - public Shape OutputShape => inboundNodes[0].Outputs.shape; + public Shape OutputShape + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].Outputs.shape; + } + return null; + } + } protected List _self_tracked_trackables; public Layer(LayerArgs args) @@ -162,7 +187,7 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null) /// /// /// - /// + /// /// protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null) { @@ -201,6 +226,7 @@ protected void MaybeBuild(Tensors inputs) public virtual void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } @@ -286,7 +312,9 @@ public List weights } } - public virtual LayerArgs get_config() + public List Variables => weights; + + public virtual IKerasConfig get_config() => args; } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index bc2c2cea6..966853809 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -33,6 +33,11 @@ public History fit(NDArray x, NDArray y, int workers = 1, bool use_multiprocessing = false) { + if (x.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); + } int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); var train_x = x[new Slice(0, train_count)]; var train_y = y[new Slice(0, train_count)]; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index c287309d4..a1e891f98 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,5 +1,8 @@ using System.Collections.Generic; +using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; namespace Tensorflow.Keras.Engine @@ -18,9 +21,21 @@ public void save(string filepath, bool overwrite = true, bool include_optimizer = true, string save_format = "tf", - SaveOptions options = null) + SaveOptions? options = null, + ConcreteFunction? 
signatures = null, + bool save_traces = true) { - saver.save(this, filepath); + if (save_format != "tf") + { + saver.save(this, filepath); + } + else + { + using (SharedObjectSavingScope.Enter()) + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } + } } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 9bab9bd2f..dfe5b05f3 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,8 @@ using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -34,6 +36,13 @@ public partial class Model : Layer, IModel IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; + DataHandler data_handler; + + public OptimizerV2 Optimizer + { + get => optimizer; + set => optimizer = value; + } public Model(ModelArgs args) : base(args) @@ -101,5 +110,15 @@ public override List TrainableVariables return variables; } } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + if(save_type == SaveType.SAVEDMODEL) + { + //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. 
+ } + var children = base._trackable_children(save_type, cache); + return children; + } } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 6e790a26f..45f64720f 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,6 +25,7 @@ public override void build(Shape input_shape) { throw new ValueError("Alpha must be a number greater than 0."); } + _buildInputShape = input_shape; built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index aba175de9..2fd2caee1 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,6 +14,7 @@ public Exponential(LayerArgs args) : base(args) } public override void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index b12d7deec..1ef8d0e58 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -16,10 +16,11 @@ public SELU ( LayerArgs args ) : base(args) { // SELU has no arguments } public override void build(Shape input_shape) { - if ( alpha < 0f ) { - throw new ValueError("Alpha must be a number greater than 0."); - } - built = true; + if ( alpha < 0f ) { + throw new ValueError("Alpha must be a number greater than 0."); + } + _buildInputShape = input_shape; + built = true; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index 6f6dd7e85..c51316308 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -146,7 +147,7 @@ public override Tensor _calculate_scores(Tensor query, Tensor key) return scores; } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; //var config = new Dictionary { // { // "use_scale", diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 3f618b5db..1348e19cf 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; /// /// Base class for attention layers that can be used in sequence DNN/CNN models. 
@@ -252,6 +253,6 @@ public static Tensor _merge_masks(Tensor x, Tensor y) return tf.logical_and(x, y); } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; } } diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs index 1b82e0a96..701724d5b 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs @@ -1,4 +1,5 @@ using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Core; using Tensorflow.Keras.Engine; using Tensorflow.NumPy; using static Tensorflow.Binding; diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index e0a337caa..b8286be67 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -49,6 +49,7 @@ public override void build(Shape input_shape) initializer: bias_initializer, trainable: true); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 912a429b7..933aa9cf1 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -98,6 +98,7 @@ public override void build(Shape input_shape) name: tf_op_name); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = false) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index e4c227456..ca8007d09 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -43,6 +43,7 @@ public Dense(DenseArgs args) : public override void build(Shape input_shape) { + _buildInputShape = input_shape; var last_dim = input_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index 0f387570b..af71ddf9f 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -4,8 +4,8 @@ using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; -using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.ArgsDefinition.Core; namespace Tensorflow.Keras.Layers { diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 79f4e5ce9..606f387bb 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -62,6 +62,7 @@ public override void build(Shape input_shape) name: "embeddings"); tf.Context.graph_mode(); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 6b064716f..03b4b742a 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -18,6 +18,7 @@ limitations under the License. 
using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -105,5 +106,7 @@ public static InputLayer from_config(LayerArgs args) { return new InputLayer(args as InputLayerArgs); } + + public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs deleted file mode 100644 index 6cb03e1e0..000000000 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ /dev/null @@ -1,113 +0,0 @@ -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; - -namespace Tensorflow.Keras.Layers { - /// - /// Crop the input along axis 1 and 2. - /// For example: - /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) - /// - public class Cropping2D : Layer { - Cropping2DArgs args; - public Cropping2D ( Cropping2DArgs args ) : base(args) { - this.args = args; - } - public override void build(Shape input_shape) { - built = true; - } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? 
training = null ) { - Tensor output = inputs; - if ( output.rank != 4 ) { - // throw an ValueError exception - throw new ValueError("Expected dim=4, found dim=" + output.rank); - } - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop, ( int ) output.shape[1] - crop), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop)]; - } - } - // a tuple of 2 integers - else if ( args.cropping.shape == new Shape(2) ) { - int crop_1 = args.cropping[0]; - int crop_2 = args.cropping[1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop_1, ( int ) output.shape[1] - crop_1), - new Slice(crop_2, ( int ) output.shape[2] - crop_2), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop_1, ( int ) output.shape[2] - crop_1), - new Slice(crop_2, ( int ) output.shape[3] - crop_2)]; - } - } - else if ( args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2 ) { - int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(x_start, ( int ) output.shape[1] - x_end), - new Slice(y_start, ( int ) output.shape[2] - y_end), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(x_start, ( int ) output.shape[2] - x_end), - new Slice(y_start, ( int ) output.shape[3] - y_end) - ]; - } - } - return output; - } - - public override Shape ComputeOutputShape ( Shape input_shape ) { - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( 
args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2); - } - } - // a tuple of 2 integers - else if ( args.cropping.shape == new Shape(2) ) { - int crop_1 = args.cropping[0], crop_2 = args.cropping[1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1 * 2, ( int ) input_shape[2] - crop_2 * 2, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_1 * 2, ( int ) input_shape[3] - crop_2 * 2); - } - } - else if ( args.cropping.shape == new Shape(2, 2) ) { - int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1]; - int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1]; - if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1_start - crop_1_end, - ( int ) input_shape[2] - crop_2_start - crop_2_end, ( int ) input_shape[3]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], - ( int ) input_shape[2] - crop_1_start - crop_1_end, ( int ) input_shape[3] - crop_2_start - crop_2_end); - } - } - else { - throw new ValueError(); - } - } - } -} diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs deleted file mode 100644 index 2d6751bf9..000000000 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ /dev/null @@ -1,123 +0,0 @@ -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.Engine; - -namespace Tensorflow.Keras.Layers { - /// - /// Similar to copping 2D - /// - public 
class Cropping3D : Layer { - Cropping3DArgs args; - public Cropping3D ( Cropping3DArgs args ) : base(args) { - this.args = args; - } - - public override void build(Shape input_shape) { - built = true; - } - - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - if ( output.rank != 5 ) { - // throw an ValueError exception - throw new ValueError("Expected dim=5, found dim=" + output.rank); - } - - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop, ( int ) output.shape[1] - crop), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop, ( int ) output.shape[2] - crop), - new Slice(crop, ( int ) output.shape[3] - crop), - new Slice(crop, ( int ) output.shape[4] - crop)]; - } - - } - // int[1][3] equivalent to a tuple of 3 integers - else if ( args.cropping.shape == new Shape(3) ) { - var crop_1 = args.cropping[0]; - var crop_2 = args.cropping[1]; - var crop_3 = args.cropping[2]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(crop_1, ( int ) output.shape[1] - crop_1), - new Slice(crop_2, ( int ) output.shape[2] - crop_2), - new Slice(crop_3, ( int ) output.shape[3] - crop_3), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(crop_1, ( int ) output.shape[2] - crop_1), - new Slice(crop_2, ( int ) output.shape[3] - crop_2), - new Slice(crop_3, ( int ) output.shape[4] - crop_3)]; - } - } - else if ( args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2 ) { - int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; - int z = args.cropping[2, 0], z_end = 
args.cropping[2, 1]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - output = output[new Slice(), - new Slice(x, ( int ) output.shape[1] - x_end), - new Slice(y, ( int ) output.shape[2] - y_end), - new Slice(z, ( int ) output.shape[3] - z_end), - new Slice()]; - } - else { - output = output[new Slice(), - new Slice(), - new Slice(x, ( int ) output.shape[2] - x_end), - new Slice(y, ( int ) output.shape[3] - y_end), - new Slice(z, ( int ) output.shape[4] - z_end) - ]; - } - } - return output; - } - public override Shape ComputeOutputShape ( Shape input_shape ) { - if ( args.cropping.shape == new Shape(1) ) { - int crop = args.cropping[0]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4] - crop * 2); - } - } - // int[1][3] equivalent to a tuple of 3 integers - else if ( args.cropping.shape == new Shape(3) ) { - var crop_start_1 = args.cropping[0]; - var crop_start_2 = args.cropping[1]; - var crop_start_3 = args.cropping[2]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_start_1 * 2, ( int ) input_shape[2] - crop_start_2 * 2, ( int ) input_shape[3] - crop_start_3 * 2, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_start_1 * 2, ( int ) input_shape[3] - crop_start_2 * 2, ( int ) input_shape[4] - crop_start_3 * 2); - } - } - else if ( args.cropping.shape == new Shape(3, 2) ) { - int x = args.cropping[0, 0], x_end = args.cropping[0, 1]; - int y = args.cropping[1, 0], y_end = args.cropping[1, 1]; - int z = 
args.cropping[2, 0], z_end = args.cropping[2, 1]; - if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - x - x_end, ( int ) input_shape[2] - y - y_end, ( int ) input_shape[3] - z - z_end, ( int ) input_shape[4]); - } - else { - return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - x - x_end, ( int ) input_shape[3] - y - y_end, ( int ) input_shape[4] - z - z_end); - } - } - else { - throw new ValueError(); - } - } - } -} diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs index 339ddb85b..3e3442f25 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs @@ -2,16 +2,18 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers.Reshaping; +using Tensorflow.Keras.ArgsDefinition.Reshaping; -namespace Tensorflow.Keras.Layers { - public partial class LayersApi { +namespace Tensorflow.Keras.Layers +{ + public partial class LayersApi { /// /// Cropping layer for 1D input /// /// cropping size public ILayer Cropping1D ( NDArray cropping ) - => new Cropping1D(new CroppingArgs { + => new Cropping1D(new Cropping1DArgs { cropping = cropping }); diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 5c1c8995d..76634918d 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -1,9 +1,8 @@ using System; using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Core; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Lstm; using Tensorflow.Keras.Layers.Rnn; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ 
-108,7 +107,7 @@ public ILayer Conv1D(int filters, DilationRate = dilation_rate, Groups = groups, UseBias = use_bias, - Activation = GetActivationByName(activation), + Activation = Activations.GetActivationByName(activation), KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer) }); @@ -163,7 +162,7 @@ public ILayer Conv2D(int filters, BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, BiasRegularizer = bias_regularizer, ActivityRegularizer = activity_regularizer, - Activation = activation ?? keras.activations.Linear + Activation = activation ?? keras.activations.Linear, }); /// @@ -210,7 +209,8 @@ public ILayer Conv2D(int filters, UseBias = use_bias, KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer), - Activation = GetActivationByName(activation) + Activation = Activations.GetActivationByName(activation), + ActivationName = activation }); /// @@ -255,7 +255,7 @@ public ILayer Conv2DTranspose(int filters, UseBias = use_bias, KernelInitializer = GetInitializerByName(kernel_initializer), BiasInitializer = GetInitializerByName(bias_initializer), - Activation = GetActivationByName(activation) + Activation = Activations.GetActivationByName(activation) }); /// @@ -300,7 +300,8 @@ public ILayer Dense(int units) => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName("linear") + Activation = Activations.GetActivationByName("linear"), + ActivationName = "linear" }); /// @@ -320,7 +321,8 @@ public ILayer Dense(int units, => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName(activation), + Activation = Activations.GetActivationByName(activation), + ActivationName = activation, InputShape = input_shape }); @@ -664,7 +666,7 @@ public ILayer SimpleRNN(int units, => new SimpleRNN(new SimpleRNNArgs { Units = units, - Activation = GetActivationByName(activation), + 
Activation = Activations.GetActivationByName(activation), KernelInitializer = GetInitializerByName(kernel_initializer), RecurrentInitializer = GetInitializerByName(recurrent_initializer), BiasInitializer = GetInitializerByName(bias_initializer), @@ -812,24 +814,7 @@ public ILayer GlobalMaxPooling1D(string data_format = "channels_last") public ILayer GlobalMaxPooling2D(string data_format = "channels_last") => new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); - - /// - /// Get an activation function layer from its name. - /// - /// The name of the activation function. One of linear, relu, sigmoid, and tanh. - /// - - Activation GetActivationByName(string name) - => name switch - { - "linear" => keras.activations.Linear, - "relu" => keras.activations.Relu, - "sigmoid" => keras.activations.Sigmoid, - "tanh" => keras.activations.Tanh, - "softmax" => keras.activations.Softmax, - _ => throw new Exception($"Activation {name} not found") - }; - + Activation GetActivationByName(string name) => Activations.GetActivationByName(name); /// /// Get an weights initializer from its name. 
/// diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index 5f8217604..da7e857a2 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -37,6 +37,7 @@ public override void build(Shape input_shape) }).ToArray(); shape_set.Add(shape); }*/ + _buildInputShape = input_shape; } protected override Tensors _merge_function(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 0363d58f4..3cd43af92 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -17,6 +17,7 @@ public Merge(MergeArgs args) : base(args) public override void build(Shape input_shape) { // output_shape = input_shape.dims[1^]; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index dac92f812..3b8e1ee8d 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -58,7 +58,7 @@ public override void build(Shape input_shape) var ndims = input_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) - axis[idx] = ndims + x; + args.Axis.dims[idx] = axis[idx] = ndims + x; fused = ndims == 4; @@ -118,6 +118,7 @@ public override void build(Shape input_shape) throw new NotImplementedException("build when renorm is true"); built = true; + _buildInputShape = input_shape; } public override Shape ComputeOutputShape(Shape input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index 5eebd7350..e19b9c30e 100644 
--- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -81,6 +81,7 @@ public override void build(Shape input_shape) _fused = _fused_can_be_used(ndims); built = true; + _buildInputShape = input_shape; } bool _fused_can_be_used(int ndims) diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs similarity index 100% rename from src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs rename to src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs similarity index 77% rename from src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs rename to src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs index 45f5bf0f6..10c15b698 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs @@ -1,11 +1,12 @@ -using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.ArgsDefinition.Reshaping; using Tensorflow.Keras.Engine; -namespace Tensorflow.Keras.Layers { +namespace Tensorflow.Keras.Layers.Reshaping +{ public class Cropping1D : Layer { - CroppingArgs args; - public Cropping1D(CroppingArgs args) : base(args) + Cropping1DArgs args; + public Cropping1D(Cropping1DArgs args) : base(args) { this.args = args; } @@ -22,6 +23,7 @@ public override void build(Shape input_shape) throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); } built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) @@ -40,7 +42,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
train else { int crop_start = args.cropping[0], crop_end = args.cropping[1]; - output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()]; + output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_end), new Slice()]; } return output; } @@ -50,12 +52,12 @@ public override Shape ComputeOutputShape(Shape input_shape) if (args.cropping.shape[0] == 1) { int crop = args.cropping[0]; - return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2])); + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop * 2), (int)input_shape[2]); } else { int crop_start = args.cropping[0], crop_end = args.cropping[1]; - return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2])); + return new Shape((int)input_shape[0], (int)(input_shape[1] - crop_start - crop_end), (int)input_shape[2]); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs new file mode 100644 index 000000000..a8d7043ed --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs @@ -0,0 +1,140 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Crop the input along axis 1 and 2. + /// For example: + /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) + /// + public class Cropping2D : Layer + { + Cropping2DArgs args; + public Cropping2D(Cropping2DArgs args) : base(args) + { + this.args = args; + } + public override void build(Shape input_shape) + { + built = true; + _buildInputShape = input_shape; + } + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? 
training = null)
+        {
+            Tensor output = inputs;
+            // Only rank-4 input is accepted.
+            if (output.rank != 4)
+            {
+                throw new ValueError("Expected dim=4, found dim=" + output.rank);
+            }
+            // In every branch below, channels_last slices axes 1 and 2; any other data format slices axes 2 and 3.
+            // a single integer: the same amount is removed from both ends of both cropped axes
+            if (args.cropping.shape == new Shape(1))
+            {
+                int crop = args.cropping[0];
+                if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(crop, (int)output.shape[1] - crop),
+                        new Slice(crop, (int)output.shape[2] - crop),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(crop, (int)output.shape[2] - crop),
+                        new Slice(crop, (int)output.shape[3] - crop)];
+                }
+            }
+            // a tuple of 2 integers: one symmetric amount per cropped axis
+            else if (args.cropping.shape == new Shape(2))
+            {
+                int crop_1 = args.cropping[0];
+                int crop_2 = args.cropping[1];
+                if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(crop_1, (int)output.shape[1] - crop_1),
+                        new Slice(crop_2, (int)output.shape[2] - crop_2),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(crop_1, (int)output.shape[2] - crop_1),
+                        new Slice(crop_2, (int)output.shape[3] - crop_2)];
+                }
+            }
+            // a 2x2 array: one (start, end) pair per cropped axis.
+            // The rank is checked first so that a rank-1 `cropping` of unsupported length never reaches `shape[1]`.
+            else if (args.cropping.shape.ndim == 2 && args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2)
+            {
+                int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1];
+                int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1];
+                if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(x_start, (int)output.shape[1] - x_end),
+                        new Slice(y_start, (int)output.shape[2] - y_end),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(x_start, (int)output.shape[2] - x_end),
+                        new Slice(y_start, (int)output.shape[3] - y_end)
+                        ];
+                }
+            }
+            else
+            {
+                // ComputeOutputShape rejects unsupported `cropping` shapes; do the same here
+                // instead of silently returning the input uncropped.
+                throw new ValueError("`cropping` must have shape (1), (2) or (2, 2), found " + args.cropping.shape);
+            }
+            return output;
+        }
+
+        public override Shape ComputeOutputShape(Shape input_shape)
+        {
+            if (args.cropping.shape == new Shape(1))
+            {
+                int crop = args.cropping[0];
+                if (args.data_format ==
Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2); + } + } + // a tuple of 2 integers + else if (args.cropping.shape == new Shape(2)) + { + int crop_1 = args.cropping[0], crop_2 = args.cropping[1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1 * 2, (int)input_shape[2] - crop_2 * 2, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_1 * 2, (int)input_shape[3] - crop_2 * 2); + } + } + else if (args.cropping.shape == new Shape(2, 2)) + { + int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1]; + int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1]; + if (args.data_format == Cropping2DArgs.DataFormat.channels_last) + { + return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1_start - crop_1_end, + (int)input_shape[2] - crop_2_start - crop_2_end, (int)input_shape[3]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], + (int)input_shape[2] - crop_1_start - crop_1_end, (int)input_shape[3] - crop_2_start - crop_2_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs new file mode 100644 index 000000000..796c2dd33 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs @@ -0,0 +1,150 @@ +using Tensorflow.Keras.ArgsDefinition.Reshaping; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers.Reshaping +{ + /// + /// Similar to copping 2D + /// + public class Cropping3D : Layer + { + Cropping3DArgs args; + public 
Cropping3D(Cropping3DArgs args) : base(args)
+        {
+            this.args = args;
+        }
+
+        public override void build(Shape input_shape)
+        {
+            built = true;
+            _buildInputShape = input_shape;
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+        {
+            Tensor output = inputs;
+            // Only rank-5 input is accepted.
+            if (output.rank != 5)
+            {
+                throw new ValueError("Expected dim=5, found dim=" + output.rank);
+            }
+
+            // In every branch below, channels_last slices axes 1, 2 and 3; any other data format slices axes 2, 3 and 4.
+            // a single integer: the same amount is removed from both ends of all three cropped axes
+            if (args.cropping.shape == new Shape(1))
+            {
+                int crop = args.cropping[0];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(crop, (int)output.shape[1] - crop),
+                        new Slice(crop, (int)output.shape[2] - crop),
+                        new Slice(crop, (int)output.shape[3] - crop),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(crop, (int)output.shape[2] - crop),
+                        new Slice(crop, (int)output.shape[3] - crop),
+                        new Slice(crop, (int)output.shape[4] - crop)];
+                }
+
+            }
+            // int[1][3] equivalent to a tuple of 3 integers
+            else if (args.cropping.shape == new Shape(3))
+            {
+                var crop_1 = args.cropping[0];
+                var crop_2 = args.cropping[1];
+                var crop_3 = args.cropping[2];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(crop_1, (int)output.shape[1] - crop_1),
+                        new Slice(crop_2, (int)output.shape[2] - crop_2),
+                        new Slice(crop_3, (int)output.shape[3] - crop_3),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(crop_1, (int)output.shape[2] - crop_1),
+                        new Slice(crop_2, (int)output.shape[3] - crop_2),
+                        new Slice(crop_3, (int)output.shape[4] - crop_3)];
+                }
+            }
+            // a 3x2 array: one (start, end) pair per cropped axis.
+            // The rank is checked first so that a rank-1 `cropping` of unsupported length never reaches `shape[1]`.
+            else if (args.cropping.shape.ndim == 2 && args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2)
+            {
+                int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
+                int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
+                int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+                    output = output[new Slice(),
+                        new Slice(x, (int)output.shape[1] - x_end),
+                        new Slice(y, (int)output.shape[2] - y_end),
+                        new Slice(z, (int)output.shape[3] - z_end),
+                        new Slice()];
+                }
+                else
+                {
+                    output = output[new Slice(),
+                        new Slice(),
+                        new Slice(x, (int)output.shape[2] - x_end),
+                        new Slice(y, (int)output.shape[3] - y_end),
+                        new Slice(z, (int)output.shape[4] - z_end)
+                        ];
+                }
+            }
+            else
+            {
+                // ComputeOutputShape rejects unsupported `cropping` shapes; do the same here
+                // instead of silently returning the input uncropped.
+                throw new ValueError("`cropping` must have shape (1), (3) or (3, 2), found " + args.cropping.shape);
+            }
+            return output;
+        }
+        public override Shape ComputeOutputShape(Shape input_shape)
+        {
+            if (args.cropping.shape == new Shape(1))
+            {
+                int crop = args.cropping[0];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+                    return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4]);
+                }
+                else
+                {
+                    return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4] - crop * 2);
+                }
+            }
+            // int[1][3] equivalent to a tuple of 3 integers
+            else if (args.cropping.shape == new Shape(3))
+            {
+                var crop_start_1 = args.cropping[0];
+                var crop_start_2 = args.cropping[1];
+                var crop_start_3 = args.cropping[2];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+                    return new Shape((int)input_shape[0], (int)input_shape[1] - crop_start_1 * 2, (int)input_shape[2] - crop_start_2 * 2, (int)input_shape[3] - crop_start_3 * 2, (int)input_shape[4]);
+                }
+                else
+                {
+                    return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_start_1 * 2, (int)input_shape[3] - crop_start_2 * 2, (int)input_shape[4] - crop_start_3 * 2);
+                }
+            }
+            else if (args.cropping.shape == new Shape(3, 2))
+            {
+                int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
+                int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
+                int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
+                if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+                {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - x - x_end, (int)input_shape[2] - y - y_end, (int)input_shape[3] - z - z_end, (int)input_shape[4]); + } + else + { + return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - x - x_end, (int)input_shape[3] - y - y_end, (int)input_shape[4] - z - z_end); + } + } + else + { + throw new ValueError(); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 868506b6b..8e7a19a9a 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -24,6 +24,7 @@ public override void build(Shape input_shape) permute = new int[input_shape.rank]; dims.CopyTo(permute, 1); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs similarity index 87% rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index b7d973847..59555e62b 100644 --- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -1,9 +1,8 @@ using System.Linq; -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Rnn; -namespace Tensorflow.Keras.Layers.Lstm +namespace Tensorflow.Keras.Layers.Rnn { /// /// Long Short-Term Memory layer - Hochreiter 1997. 
diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs similarity index 72% rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs index 3cd35a091..a622c91a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs @@ -1,7 +1,7 @@ -using Tensorflow.Keras.ArgsDefinition.Lstm; +using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -namespace Tensorflow.Keras.Layers.Lstm +namespace Tensorflow.Keras.Layers.Rnn { public class LSTMCell : Layer { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 877c35994..6b755ecee 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -3,7 +3,6 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Layers.Lstm; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index a3cd002d9..19669b4b9 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -12,7 +12,19 @@ public class SimpleRNN : RNN public SimpleRNN(SimpleRNNArgs args) : base(args) { this.args = args; - cell = new SimpleRNNCell(args); + } + + public override void build(Shape input_shape) + { + var input_dim = input_shape[-1]; + _buildInputShape = input_shape; + + kernel = add_weight("kernel", (input_shape[-1], args.Units), + initializer: args.KernelInitializer + //regularizer = self.kernel_regularizer, + //constraint = self.kernel_constraint, + //caching_device = default_caching_device, + ); } } } \ No newline at end of file diff --git 
a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index eead274a1..20962df1f 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Rnn { @@ -136,7 +137,7 @@ public void build() // self.built = True } - public override LayerArgs get_config() + public override IKerasConfig get_config() { throw new NotImplementedException(); //def get_config(self): diff --git a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs index 61cec6468..f29f2dec3 100644 --- a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs +++ b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs @@ -194,6 +194,18 @@ public SavedObject() { OnConstruction(); } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(int nodeId, string nodePath, + global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) + { + OnConstruction(); + nodeId_ = nodeId; + nodePath_ = nodePath; + identifier_ = identifier; + metadata_ = metadata; + version_ = version; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Keras/Protobuf/Versions.cs b/src/TensorFlowNET.Keras/Protobuf/Versions.cs index 40405a5a6..ff9a23c62 100644 --- a/src/TensorFlowNET.Keras/Protobuf/Versions.cs +++ b/src/TensorFlowNET.Keras/Protobuf/Versions.cs @@ -74,6 +74,13 @@ public sealed partial class VersionDef : pb::IMessage { public VersionDef() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public VersionDef(int producer, int minConsumer) { + OnConstruction(); + producer_ = producer; + minConsumer_ = 
minConsumer; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs new file mode 100644 index 000000000..3ea4f067e --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public static class Constants +{ + /// + /// Namespace used to store all attributes added during serialization. + /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an + /// object loaded from `tf.saved_model.load()`. + /// + public static readonly string KERAS_ATTR = "keras_api"; + /// + /// Keys for the serialization cache. + /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} + /// + public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; + /// + /// Name of Keras metadata file stored in the SavedModel. + /// + public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; + + public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; + public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; + public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; + public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; + public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; + public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; + public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; + + public static readonly IList KERAS_OBJECT_IDENTIFIERS = new List() + { + INPUT_LAYER_IDENTIFIER, + LAYER_IDENTIFIER, + METRIC_IDENTIFIER, + MODEL_IDENTIFIER, + NETWORK_IDENTIFIER, + RNN_LAYER_IDENTIFIER, + SEQUENTIAL_IDENTIFIER + }; +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs new file mode 100644 index 
000000000..c7b7e52f4
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
@@ -0,0 +1,187 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Google.Protobuf;
+using Tensorflow.Functions;
+using Tensorflow.Keras.Engine;
+using Tensorflow.ModelSaving;
+using Tensorflow.Train;
+using Tensorflow.Keras.Optimizers;
+using ThirdParty.Tensorflow.Python.Keras.Protobuf;
+using static Tensorflow.Binding;
+using Tensorflow.Training;
+
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    /// <summary>
+    /// Saves <paramref name="model"/> under <paramref name="filepath"/> and writes the Keras metadata file into the same location.
+    /// </summary>
+    /// <param name="model">The model to save.</param>
+    /// <param name="filepath">Target location; the metadata file is written to <c>Path.Combine(filepath, Constants.SAVED_METADATA_PATH)</c>.</param>
+    /// <param name="overwrite">When false, saving is refused if <paramref name="filepath"/> already exists.</param>
+    /// <param name="include_optimizer">When false, the model's optimizer is detached while saving and put back afterwards.</param>
+    /// <param name="signatures">Forwarded to <c>SavedModelUtils.save_and_return_nodes</c>.</param>
+    /// <param name="options">Forwarded to <c>SavedModelUtils.save_and_return_nodes</c>.</param>
+    /// <param name="save_traces">Forwarded to <c>keras_option_scope</c>.</param>
+    /// <exception cref="Exception">Thrown when <paramref name="overwrite"/> is false and <paramref name="filepath"/> already exists.</exception>
+    public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
+        SaveOptions? options, bool save_traces = true)
+    {
+        // The metadata file below is written *inside* `filepath`, i.e. the target is a directory,
+        // so an existing directory has to be detected as well as an existing file.
+        if (!overwrite && (File.Exists(filepath) || Directory.Exists(filepath)))
+        {
+            throw new Exception("The file already exists but is not allowed to overwrite it.");
+        }
+
+        if (save_traces)
+        {
+            if(should_skip_serialization(model))
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        OptimizerV2? orig_optimizer = null;
+        if (!include_optimizer)
+        {
+            orig_optimizer = model.Optimizer;
+            model.Optimizer = null;
+            model._delete_tracking("optimizer");
+        }
+
+        try
+        {
+            IList saved_nodes;
+            IDictionary> node_paths;
+            // skip two scopes of python
+            using (KerasSavedModelUtils.keras_option_scope(save_traces))
+            {
+                (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
+            }
+
+            var metadata = generate_keras_metadata(saved_nodes, node_paths);
+            File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray());
+            //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString());
+        }
+        finally
+        {
+            // Put the optimizer back even when saving throws, so a failed save
+            // does not leave the caller's model without its optimizer.
+            if (!include_optimizer)
+            {
+                model.Optimizer = orig_optimizer!;
+            }
+        }
+    }
+
+    /// <summary>
+    /// Builds the Keras metadata message: one <c>SavedObject</c> per entry of <paramref name="saved_nodes"/> that is a <c>Layer</c>.
+    /// </summary>
+    public static SavedMetadata generate_keras_metadata(IList saved_nodes,
+        IDictionary> node_paths)
+    {
+        var metadata = new SavedMetadata();
+        for (int i = 0; i < saved_nodes.Count; i++)
+        {
+            var node = saved_nodes[i];
+            // Only layers contribute metadata; `i` is still recorded as the node id of the layers that do.
+            if (node is not Layer layer)
+            {
+                continue;
+            }
+
+            var path = node_paths[node];
+            string node_path;
+            if (path is null || !path.Any())
+            {
+                node_path = "root";
+            }
+            else
+            {
+                node_path = $"root.{string.Join(".", path.Select(x => x.Name))}";
+            }
+
+            ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
+            {
+                NodeId = i,
+                NodePath = node_path,
+                Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
+                {
+                    Producer = 2,
+                    MinConsumer = 1,
+                    BadConsumers = { }
+                },
+                Identifier = layer.ObjectIdentifier,
+                Metadata = layer.TrackingMetadata
+            };
+
+            metadata.Nodes.Add(saved_object);
+        }
+
+        return metadata;
+    }
+
+    /// <summary>
+    /// Currently always returns false, i.e. no object is skipped.
+    /// </summary>
+    public static bool should_skip_serialization(object layer)
+    {
+        return false;
+    }
+
+    ///
+    /// Returns extra trackable objects to attach to the serialized layer.
+ /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. + + // TODO: change the inherits of `Variable` and revise the implmentation. + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + + Dictionary res = new(); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. + /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. 
+ + return new Dictionary(); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs new file mode 100644 index 000000000..eb88c8953 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -0,0 +1,37 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Engine; +using Newtonsoft.Json; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public abstract class SavedModelSaver +{ + protected Trackable _obj; + public SavedModelSaver(Trackable obj) + { + _obj = obj; + } + + public abstract string ObjectIdentifier { get; } + public abstract string TrackingMetadata { get; } + + public abstract IDictionary objects_to_serialize( + IDictionary> serialization_cache); + + public abstract IDictionary functions_to_serialize( + IDictionary> serialization_cache); + + public IDictionary trackable_children(IDictionary> serialization_cache) + { + if (!KerasSavedModelUtils.ShouldHaveTraces) + { + return new Dictionary(); + } + + var children = objects_to_serialize(serialization_cache); + return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + .ToDictionary(x => x.Key, x => x.Value); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs new file mode 100644 index 000000000..03693cb57 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -0,0 +1,165 @@ +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class LayerSavedModelSaver: SavedModelSaver +{ + private Layer 
_layer;
+    public LayerSavedModelSaver(Layer obj): base(obj)
+    {
+        _obj = obj;
+        _layer = obj;
+    }
+    public override string ObjectIdentifier
+    {
+        get => Constants.LAYER_IDENTIFIER;
+    }
+
+    public override IDictionary objects_to_serialize(IDictionary> serialization_cache)
+    {
+        return get_serialized_attributes(serialization_cache).ObjectsToSerialize;
+    }
+
+    public override IDictionary functions_to_serialize(IDictionary> serialization_cache)
+    {
+        return get_serialized_attributes(serialization_cache).FunctionsToSerialize;
+    }
+
+    /// <summary>
+    /// Generates or retrieves serialized attributes from cache.
+    /// </summary>
+    /// <param name="serialization_cache">
+    /// Cache shared across the save; the per-object entries live under <c>Constants.KERAS_CACHE_KEY</c>,
+    /// which is created here when missing. Must not be null because it is written to.
+    /// </param>
+    /// <returns>The cached attributes for the wrapped object, or freshly created ones (which are also stored in the cache).</returns>
+    /// <exception cref="global::System.ArgumentNullException"><paramref name="serialization_cache"/> is null.</exception>
+    protected ISerializedAttributes get_serialized_attributes(IDictionary> serialization_cache)
+    {
+        // The cache is written to below, so a null cache cannot be served; fail with a
+        // descriptive exception instead of a NullReferenceException.
+        if (serialization_cache is null)
+        {
+            throw new global::System.ArgumentNullException(nameof(serialization_cache));
+        }
+
+        // TODO: deal with cache.
+        IDictionary keras_cache;
+        if (serialization_cache.TryGetValue(Constants.KERAS_CACHE_KEY, out var existing_cache))
+        {
+            keras_cache = existing_cache;
+        }
+        else
+        {
+            serialization_cache[Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary();
+        }
+        if (keras_cache.TryGetValue(_obj, out var cached_attr)) return cached_attr;
+
+        // Register the entry before it is filled in, so the cache already holds it
+        // while the objects and functions below are being collected.
+        var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj);
+
+        // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`.
+        if (KerasSavedModelUtils.should_skip_serialization(_obj))
+        {
+            return serialized_attr;
+        }
+
+        var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache);
+
+        serialized_attr.set_and_validate_objects(object_dict);
+        serialized_attr.set_and_validate_functions(function_dict);
+        return serialized_attr;
+    }
+
+    ///
+    /// Returns dictionary of serialized attributes.
+ /// + /// + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) + { + var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); + + functions["_default_save_signature"] = null; + + return (objects, functions); + } + + public override string TrackingMetadata + { + get + { + JObject metadata = new JObject(); + metadata["name"] = _layer.Name; + metadata["trainable"] = _layer.Trainable; + // TODO: implement `expects_training_arg`. + metadata["expects_training_arg"] = false; + metadata["dtype"] = _layer.DType.as_python_name(); + metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); + // metadata["stateful"] = _obj.stateful; + // metadata["must_restore_from_config"] = _obj.must_restore_from_config; + // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; + metadata["autocast"] = _layer.AutoCast; + + if(_layer.InputSpec is not null) + { + metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); + } + + metadata.Merge(get_serialized(_layer), new JsonMergeSettings + { + // Handle conflicts by using values from obj2 + MergeArrayHandling = MergeArrayHandling.Merge + }); + // skip the check of `input_spec` and `build_input_shape` for the lack of members. + // skip the check of `activity_regularizer` for the type problem. 
+ if(_layer.BuildInputShape is not null) + { + metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); + } + return metadata.ToString(); + } + } + + public static JObject get_serialized(Layer obj) + { + return generic_utils.serialize_keras_object(obj); + } +} + +public class InputLayerSavedModelSaver: SavedModelSaver +{ + public InputLayerSavedModelSaver(Layer obj) : base(obj) + { + + } + public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override string TrackingMetadata + { + get + { + if(_obj is not InputLayer) + { + throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); + } + var layer = (InputLayer)_obj; + var config = (layer.get_config() as InputLayerArgs)!; + var info = new + { + class_name = layer.GetType().Name, + name = layer.Name, + dtype = layer.DType, + sparse = config.Sparse, + ragged = config.Ragged, + batch_input_shape = layer.BatchInputShape, + config = layer.get_config() + }; + return JsonConvert.SerializeObject(info); + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs new file mode 100644 index 000000000..ac194c00f --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -0,0 +1,282 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Metrics; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + // TODO: revise the name of these "Attributes". 
Since "Attribute" is a significant feature of C#,
+    // Using the name "Attributes" may be quite confusing.
+    /// <summary>
+    /// Class that tracks and validates all serialization attributes.
+    /// </summary>
+    public abstract class SerializedAttributes: ISerializedAttributes
+    {
+        protected IDictionary _object_dict; // objects recorded by set_and_validate_objects, keyed by name
+        protected IDictionary _function_dict; // functions recorded by set_and_validate_functions, keyed by name
+        protected AutoTrackable _keras_trackable; // receives validated values through same-named properties
+        protected HashSet _all_functions; // function names this attribute set requires
+        protected HashSet _all_checkpointable_objects; // object names this attribute set requires
+
+        private SerializedAttributes()
+        {
+            _object_dict = new Dictionary();
+            _function_dict = new Dictionary();
+            _keras_trackable = new AutoTrackable();
+            _all_functions = new HashSet();
+            _all_checkpointable_objects = new HashSet();
+        }
+
+        protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions)
+        {
+            _object_dict = new Dictionary();
+            _function_dict = new Dictionary();
+            _keras_trackable = new AutoTrackable();
+
+            _all_checkpointable_objects = new HashSet(checkpointable_objects);
+            _all_functions = new HashSet(functions);
+        }
+
+        protected SerializedAttributes((IEnumerable, IEnumerable) objects_and_functions)
+        {
+            _object_dict = new Dictionary();
+            _function_dict = new Dictionary();
+            _keras_trackable = new AutoTrackable();
+
+            _all_checkpointable_objects = new HashSet(objects_and_functions.Item1);
+            _all_functions = new HashSet(objects_and_functions.Item2);
+        }
+
+        public IDictionary Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); // Where, not TakeWhile: a null entry is skipped, it does not end the sequence
+
+        public IDictionary CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); // same filtering rule as Functions
+
+        /// <summary>
+        /// Returns functions to attach to the root object during serialization.
+        /// </summary>
+        public IDictionary FunctionsToSerialize
+        {
+            get
+            {
+                Dictionary functions = new();
+                foreach(var pair in Functions)
+                {
+                    if (_all_functions.Contains(pair.Key))
+                    {
+                        // TODO: deal with `LayerCall`.
+                        functions[pair.Key] = pair.Value;
+                    }
+                }
+                return functions;
+            }
+        }
+
+        /// <summary>
+        /// Returns objects to attach to the root object during serialization.
+        /// </summary>
+        public IDictionary ObjectsToSerialize
+        {
+            get
+            {
+                var objects = CheckpointableObjects.Where(x => _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); // keep every required object, wherever it sits in the dictionary
+                objects[Constants.KERAS_ATTR] = _keras_trackable; // the backing trackable is always attached under KERAS_ATTR
+                return objects;
+            }
+        }
+
+        /// <summary>
+        /// Saves function dictionary, and validates dictionary values.
+        /// </summary>
+        /// <param name="function_dict">Must contain every name in _all_functions (ValueError otherwise); a non-null value must be a Function.</param>
+        public IDictionary set_and_validate_functions(IDictionary function_dict)
+        {
+            foreach(var key in _all_functions)
+            {
+                if (function_dict.ContainsKey(key))
+                {
+                    // TODO: deal with type `LayerCall`.
+                    var fn = function_dict[key];
+                    if (fn is not null && (fn is not Function))
+                    {
+                        throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key}).");
+                    }
+                    _function_dict[key] = fn;
+
+                    var tf_fn = fn; // TODO: deal with type `LayerCall`.
+
+                    // Warning: this implementation should be considered again.
+                    var properties = _keras_trackable.GetType().GetProperties();
+                    foreach (var property in properties)
+                    {
+                        if(property.Name == key) // mirror the value onto the same-named property, if one exists
+                        {
+                            property.SetValue(_keras_trackable, tf_fn);
+                            break;
+                        }
+                    }
+                }
+                else
+                {
+                    throw new ValueError($"Function {key} missing from serialized function dict.");
+                }
+            }
+            return Functions;
+        }
+
+        /// <summary>
+        /// Saves objects to a dictionary, and validates the values.
+        /// </summary>
+        /// <param name="object_dict">Must contain every name in _all_checkpointable_objects (ValueError otherwise).</param>
+        public IDictionary set_and_validate_objects(IDictionary object_dict)
+        {
+            foreach(var key in _all_checkpointable_objects)
+            {
+                if (object_dict.ContainsKey(key))
+                {
+                    _object_dict[key] = object_dict[key];
+                    // Warning: this implementation should be considered again.
+                    var properties = _keras_trackable.GetType().GetProperties();
+                    foreach (var property in properties)
+                    {
+                        if (property.Name == key) // mirror the value onto the same-named property, if one exists
+                        {
+                            property.SetValue(_keras_trackable, object_dict[key]);
+                            break;
+                        }
+                    }
+                }
+                else
+                {
+                    throw new ValueError($"Object {key} missing from serialized object dict.");
+                }
+            }
+            return CheckpointableObjects;
+        }
+
+        /// <summary>
+        /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python).
+        /// </summary>
+        /// <param name="obj">Checked against Model, Metric, RNN and Layer in that order; any other type throws TypeError.</param>
+        public static SerializedAttributes Create(Trackable obj)
+        {
+            if(obj is Model)
+            {
+                return new ModelAttributes();
+            }
+            else if(obj is Metric)
+            {
+                return new MetricAttributes();
+            }
+            else if(obj is RNN)
+            {
+                return new RNNAttributes();
+            }
+            else if(obj is Layer)
+            {
+                return new LayerAttributes();
+            }
+            else
+            {
+                throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}");
+            }
+        }
+
+        protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+        {
+            return (checkpointable_objects ?? (new List()), functions ?? (new List())); // null arguments become empty lists
+        }
+    }
+
+    // Note that the current implementation still has some potential risks.
+    // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras".
+    // However, currently it's just a normal class.
+    public class CommonEndPoints: SerializedAttributes
+    {
+        public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) :
+            //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }),
+            //    functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }))
+            base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), // caller's names plus the two required here
+                functions.Concat(new string[] { })) // no function names are added for now (see the commented-out list above)
+        {
+
+        }
+
+        public CommonEndPoints() :
+            //base(new string[] { "variables", "trainable_variables", "regularization_losses" },
+            //    new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })
+            base(new string[] { "variables", "trainable_variables"},
+                new string[] {})
+        {
+
+        }
+    }
+
+    public class LayerAttributes: CommonEndPoints
+    {
+        public LayerAttributes(IEnumerable checkpointable_objects, IEnumerable functions) :
+            //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
+            //    functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
+            base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), // the base constructor appends its own names on top of these
+                functions.Concat(new string[] { }))
+        {
+
+        }
+
+        public LayerAttributes() :
+            //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" },
+            //    new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
+            base(new string[] { "non_trainable_variables", "layers" },
+                new string[] { })
+        {
+
+        }
+    }
+
+    public class ModelAttributes: LayerAttributes
+    {
+        public ModelAttributes(IEnumerable checkpointable_objects, IEnumerable functions):
+            base(checkpointable_objects, functions) // adds nothing of its own
+        {
+
+        }
+
+        public ModelAttributes(): base()
+        {
+
+        }
+    }
+
+    public class MetricAttributes : SerializedAttributes
+    {
+        public MetricAttributes(IEnumerable checkpointable_objects, IEnumerable functions) :
+            base(checkpointable_objects.Concat(new string[] { "variables" }), functions) // derives from SerializedAttributes directly, so only "variables" is added
+        {
+
+        }
+
+        public MetricAttributes() :
+            base(new string[] { "variables" }, new string[] {})
+        {
+
+        }
+    }
+
+    public class RNNAttributes: LayerAttributes
+    {
+        public RNNAttributes(IEnumerable checkpointable_objects, IEnumerable functions) :
+            base(checkpointable_objects, functions.Concat(new string[] {"states"})) // NOTE(review): "states" is registered as a function name; presumably it should be a checkpointable object - confirm against Keras
+        {
+
+        }
+
+        public RNNAttributes() :
+            base(new string[] { }, new string[] { "states" }) // NOTE(review): same question as the constructor above
+        {
+
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs
new file mode 100644
index 000000000..51f8d2c91
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs
@@ -0,0 +1,47 @@
+using System;
+using System.Collections.Generic;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    public static bool ShouldHaveTraces { get; internal set; } = true; // process-wide flag; swapped temporarily by keras_option_scope
+
+    public static SaveOptionsContext keras_option_scope(bool save_traces)
+    {
+        var res = new SaveOptionsContext(ShouldHaveTraces); // remember the current value so Dispose can restore it
+        ShouldHaveTraces = save_traces;
+        return res;
+    }
+
+    public static IEnumerable list_all_layers(Layer layer)
+    {
+        if(layer is Model) // a Model reports its own Layers collection
+        {
+            return (layer as Model).Layers;
+        }
+        else
+        {
+            return new List(layer._flatten_layers(false, false)); // any other layer: copy of _flatten_layers(false, false)
+        }
+    }
+}
+
+///
+/// The implementation of this class differs from that of python.
+/// But it can be used with `using` in the same way as `with` in python.
+///
+public class SaveOptionsContext: IDisposable
+{
+    public bool _old_value; // value of KerasSavedModelUtils.ShouldHaveTraces to restore on Dispose
+    public SaveOptionsContext(bool old_value)
+    {
+        _old_value = old_value;
+    }
+
+    public void Dispose()
+    {
+        KerasSavedModelUtils.ShouldHaveTraces = _old_value;
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs
deleted file mode 100644
index 4c2ecc0d8..000000000
--- a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs
+++ /dev/null
@@ -1,15 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-
-namespace Tensorflow.Keras.Saving
-{
-    public class TensorShapeConfig
-    {
-        public string ClassName { get; set; }
-        public int?[] Items { get; set; }
-
-        public static implicit operator Shape(TensorShapeConfig shape)
-            => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
-    }
-}
diff --git a/src/TensorFlowNET.Keras/Saving/serialization.cs b/src/TensorFlowNET.Keras/Saving/serialization.cs
new file mode 100644
index 000000000..d5e46d11c
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/serialization.cs
@@ -0,0 +1,125 @@
+using Newtonsoft.Json.Linq;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text;
+using Tensorflow.Keras.Saving.SavedModel;
+
+namespace Tensorflow.Keras.Saving
+{
+    // TODO: make it thread safe.
+    public class SharedObjectSavingScope: IDisposable
+    {
+        private class WeakReferenceEqualityComparer: IEqualityComparer>
+        {
+            public bool Equals(WeakReference x, WeakReference y)
+            {
+                if(!x.TryGetTarget(out var tx)) // a reference whose target is gone never equals anything
+                {
+                    return false;
+                }
+                if(!y.TryGetTarget(out var ty))
+                {
+                    return false;
+                }
+                return tx.Equals(ty);
+            }
+            public int GetHashCode(WeakReference obj)
+            {
+                if (!obj.TryGetTarget(out var w)) // a reference whose target is gone hashes to 0
+                {
+                    return 0;
+                }
+                return w.GetHashCode();
+            }
+        }
+        private static SharedObjectSavingScope? _instance = null; // the single active scope, or null when none is open
+        private readonly Dictionary, int> _shared_object_ids= new Dictionary, int>();
+        private int _currentId = 0; // next id to hand out
+        /// <summary>
+        /// record how many times the scope is nested.
+        /// </summary>
+        private int _nestedDepth = 0;
+        private SharedObjectSavingScope()
+        {
+
+        }
+
+        public static SharedObjectSavingScope Enter()
+        {
+            if(_instance is not null) // nested entry: reuse the active scope and count the extra level
+            {
+                _instance._nestedDepth++;
+                return _instance;
+            }
+            else
+            {
+                _instance = new SharedObjectSavingScope();
+                _instance._nestedDepth++;
+                return _instance;
+            }
+        }
+
+        public static SharedObjectSavingScope GetScope()
+        {
+            return _instance; // null when no scope is active; callers must check
+        }
+
+        public int GetId(object? obj)
+        {
+            if(obj is null) // null is never recorded: every call hands out a fresh id
+            {
+                return _currentId++;
+            }
+            var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference(obj)));
+            if (maybe_key is not null) // already seen: return the id recorded the first time
+            {
+                return _shared_object_ids[maybe_key];
+            }
+            _shared_object_ids[new WeakReference(obj)] = _currentId; // record the id that is about to be returned
+            return _currentId++; // return the recorded id, then advance, so repeated calls for obj agree
+        }
+
+        public void Dispose()
+        {
+            _nestedDepth--;
+            if(_nestedDepth== 0) // only the outermost Dispose clears the active scope
+            {
+                _instance = null;
+            }
+        }
+    }
+
+    public static class serialize_utils
+    {
+        public static readonly string SHARED_OBJECT_KEY = "shared_object_id";
+        /// <summary>
+        /// Returns the serialization of the class with the given config.
+        /// </summary>
+        /// <param name="class_name">Value stored under "class_name".</param>
+        /// <param name="config">Value stored under "config".</param>
+        /// <param name="obj">Optional object; when it is non-null and a saving scope is active, its scope id is stored under SHARED_OBJECT_KEY.</param>
+        /// <param name="shared_object_id">Optional explicit id stored under SHARED_OBJECT_KEY; replaced by the scope id when one is assigned.</param>
+        /// <returns>A new JObject holding the entries described above.</returns>
+        public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? shared_object_id = null)
+        {
+            JObject res = new JObject();
+            res["class_name"] = class_name;
+            res["config"] = config;
+
+            if(shared_object_id is not null)
+            {
+                res[SHARED_OBJECT_KEY] = shared_object_id!;
+            }
+
+            var scope = SharedObjectSavingScope.GetScope();
+            if(scope is not null && obj is not null) // runs after the explicit id, so the scope id wins when both apply
+            {
+                res[SHARED_OBJECT_KEY] = scope.GetId(obj);
+            }
+
+            return res;
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
index 1e6ce4091..d845f3ca9 100644
--- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
@@ -53,7 +53,7 @@ public static IVariableV1 make_variable(VariableArgs args)
         }
 
         ///
-        /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
+        /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (corresponding to `backend.unique_object_name` of python.)
         ///
         ///
         ///
diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
index c2839cdc7..730a33e3e 100644
--- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
@@ -14,24 +14,43 @@ You may obtain a copy of the License at
    limitations under the License.
 ******************************************************************************/
 
+using Newtonsoft.Json;
+using Newtonsoft.Json.Linq;
 using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
+using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Saving;
 
 namespace Tensorflow.Keras.Utils
 {
     public class generic_utils
     {
-        public static LayerConfig serialize_keras_object(ILayer instance)
+        ///
+        /// This method does not have corresponding method in python. It's close to `serialize_keras_object`.
+        ///
+        ///
+        ///
+        public static LayerConfig serialize_layer_to_config(ILayer instance)
         {
             var config = instance.get_config();
+            Debug.Assert(config is LayerArgs); // debug-only check; the `as` cast below yields null if it does not hold
             return new LayerConfig
             {
-                Config = config,
+                Config = config as LayerArgs,
                 ClassName = instance.GetType().Name
             };
         }
 
+        public static JObject serialize_keras_object(IKerasConfigable instance)
+        {
+            var config = JToken.FromObject(instance.get_config());
+            // TODO: change the class_name to registered name, instead of system class name.
+            return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); // passing `instance` lets an active saving scope attach a shared object id
+        }
+
         public static string to_snake_case(string name)
         {
             return string.Concat(name.Select((x, i) =>
diff --git a/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs
index c811b5643..6950e65fc 100644
--- a/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs
@@ -6,7 +6,7 @@
 using TensorFlowNET.Keras.UnitTest;
 using static Tensorflow.Binding;
 
-namespace Tensorflow.Keras.UnitTest;
+namespace TensorFlowNET.Keras.UnitTest;
 
 [TestClass]
 public class InitializerTest : EagerModeTestBase
@@ -15,6 +15,6 @@ public class InitializerTest : EagerModeTestBase
     public void Orthogonal()
     {
         var initializer = tf.keras.initializers.Orthogonal();
-        var values = initializer.Apply(new InitializerArgs((2, 2)));
+        var values = initializer.Apply(new Tensorflow.InitializerArgs((2, 2)));
     }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
index 0a1098af7..67e8ff797 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
@@ -1,6 +1,8 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using Tensorflow.Keras.Engine;
+using System.Diagnostics;
 using static Tensorflow.KerasApi;
+using Tensorflow.Keras.Saving;
 
 namespace TensorFlowNET.Keras.UnitTest
 {
@@ -15,7 +17,8 @@ public void GetAndFromConfig()
         {
             var model = GetFunctionalModel();
             var config = model.get_config();
-            var new_model = keras.models.from_config(config);
+            Debug.Assert(config is ModelConfig);
+            var new_model = keras.models.from_config(config as ModelConfig);
             Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
         }
 
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs
new file mode 100644
index 000000000..269b9c058
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs
@@ -0,0 +1,202 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.Keras;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
+using Tensorflow.Keras.Optimizers;
+using Tensorflow.Operations;
+using System.Diagnostics;
+
+namespace TensorFlowNET.Keras.UnitTest.SaveModel;
+
+[TestClass]
+public class SequentialModelTest
+{
+    [TestMethod]
+    public void SimpleModelFromAutoCompile() // builds a functional model, fits it for one epoch and saves it; no assertions are made
+    {
+        var inputs = new KerasInterface().Input((28, 28, 1));
+        var x = new Flatten(new FlattenArgs()).Apply(inputs);
+        x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
+        x = new LayersApi().Dense(units: 10).Apply(x);
+        var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
+        var model = new KerasInterface().Model(inputs, outputs);
+
+        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+
+        var data_loader = new MnistModelLoader();
+        var num_epochs = 1;
+        var batch_size = 50;
+
+        var dataset = data_loader.LoadAsync(new ModelLoadSetting
+        {
+            TrainDir = "mnist",
+            OneHot = false,
+            ValidationSize = 10000,
+        }).Result; // blocks synchronously on the async loader
+
+        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
+
+        model.save("./pb_simple_compile", save_format: "tf");
+    }
+
+    [TestMethod]
+    public void SimpleModelFromSequential() // same flow as above, with the model built through keras.Sequential
+    {
+        Model model = KerasApi.keras.Sequential(new List()
+        {
+            keras.layers.InputLayer((28, 28, 1)),
+            keras.layers.Flatten(),
+            keras.layers.Dense(100, "relu"),
+            keras.layers.Dense(10),
+            keras.layers.Softmax(1)
+        });
+
+        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+
+        var data_loader = new MnistModelLoader();
+        var num_epochs = 1;
+        var batch_size = 50;
+
+        var dataset = data_loader.LoadAsync(new ModelLoadSetting
+        {
+            TrainDir = "mnist",
+            OneHot = false,
+            ValidationSize = 50000,
+        }).Result; // blocks synchronously on the async loader
+
+        model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
+
+        model.save("./pb_simple_sequential", save_format: "tf");
+    }
+
+    [TestMethod]
+    public void AlexModelFromSequential() // fits an AlexNet-style stack on random data for one epoch and saves it
+    {
+        Model model = KerasApi.keras.Sequential(new List()
+        {
+            keras.layers.InputLayer((227, 227, 3)),
+            keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
+            keras.layers.BatchNormalization(),
+            keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),
+
+            keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
+            keras.layers.BatchNormalization(),
+            keras.layers.MaxPooling2D((3, 3), (2, 2)),
+
+            keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
+            keras.layers.BatchNormalization(),
+
+            keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
+            keras.layers.BatchNormalization(),
+
+            keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
+            keras.layers.BatchNormalization(),
+            keras.layers.MaxPooling2D((3, 3), (2, 2)),
+
+            keras.layers.Flatten(),
+            keras.layers.Dense(4096, activation: "relu"),
+            keras.layers.Dropout(0.5f),
+
+            keras.layers.Dense(4096, activation: "relu"),
+            keras.layers.Dropout(0.5f),
+
+            keras.layers.Dense(1000, activation: "linear"),
+            keras.layers.Softmax(1)
+        });
+
+        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); // NOTE(review): the model ends in Softmax while the loss uses from_logits:true - confirm this is intended
+
+        var num_epochs = 1;
+        var batch_size = 8;
+
+        var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
+
+        model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);
+
+        model.save("./pb_alex_sequential", save_format: "tf");
+
+        // The saved model can be tested with the following python code:
+        #region alexnet_python_code
+        //import pathlib
+        //import tensorflow as tf
+
+        //def func(a):
+        //    return -a
+
+        //if __name__ == '__main__':
+        //    model = tf.keras.models.load_model("./pb_alex_sequential")
+        //    model.summary()
+
+        //    num_classes = 5
+        //    batch_size = 128
+        //    img_height = 227
+        //    img_width = 227
+        //    epochs = 100
+
+        //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
+        //    data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
+        //    data_dir = pathlib.Path(data_dir)
+
+        //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
+        //        data_dir,
+        //        validation_split = 0.2,
+        //        subset = "training",
+        //        seed = 123,
+        //        image_size = (img_height, img_width),
+        //        batch_size = batch_size)
+
+        //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
+        //        data_dir,
+        //        validation_split = 0.2,
+        //        subset = "validation",
+        //        seed = 123,
+        //        image_size = (img_height, img_width),
+        //        batch_size = batch_size)
+
+
+        //    model.compile(optimizer = 'adam',
+        //        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
+        //        metrics =['accuracy'])
+
+        //    model.build((None, img_height, img_width, 3))
+
+        //    history = model.fit(
+        //        train_ds,
+        //        validation_data = val_ds,
+        //        epochs = epochs
+        //    )
+        #endregion
+    }
+
+    
public class RandomDataSet : DataSetBase
+    {
+        private Shape _shape;
+
+        // Fills Data via np.random.normal(0, 2, ...) over (count, d0, d1, d2) and Labels via np.random.uniform(0, 1, ...) over (count, 1).
+        public RandomDataSet(Shape shape, int count)
+        {
+            _shape = shape;
+            Debug.Assert(_shape.ndim == 3);
+            var batchDims = new long[4];
+            batchDims[0] = count;
+            for (int axis = 0; axis < 3; axis++)
+            {
+                batchDims[axis + 1] = _shape[axis];
+            }
+            Data = np.random.normal(0, 2, new Shape(batchDims));
+            Labels = np.random.uniform(0, 1, (count, 1));
+        }
+    }
+}
\ No newline at end of file