From 9349ec4829a0b4e891f7d516d67ec7358bbac28b Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 24 Apr 2023 00:04:19 +0800 Subject: [PATCH] Add Tensorflow.NET.Hub and support loading bert. --- TensorFlow.NET.sln | 28 + src/TensorFlowNET.Core/Tensors/Tensors.cs | 21 +- src/TensorFlowNET.Core/Tensors/dtypes.cs | 11 + .../Engine/Layer.AddWeights.cs | 4 +- .../Saving/KerasMetaData.cs | 4 + .../Saving/KerasObjectLoader.cs | 13 +- .../Saving/SavedModel/RevivedInputLayer.cs | 37 +- .../Saving/SavedModel/RevivedLayer.cs | 16 +- .../Saving/SavedModel/RevivedNetwork.cs | 40 ++ .../GcsCompressedFileResolver.cs | 57 ++ .../HttpCompressedFileResolver.cs | 78 +++ .../HttpUncompressedFileResolver.cs | 65 ++ src/TensorflowNET.Hub/KerasLayer.cs | 157 +++++ src/TensorflowNET.Hub/Tensorflow.Hub.csproj | 17 + src/TensorflowNET.Hub/file_utils.cs | 74 +++ src/TensorflowNET.Hub/hub.cs | 17 + src/TensorflowNET.Hub/module_v2.cs | 33 + src/TensorflowNET.Hub/registry.cs | 55 ++ src/TensorflowNET.Hub/resolver.cs | 580 ++++++++++++++++++ src/TensorflowNET.Hub/tf_utils.cs | 80 +++ .../KerasLayerTest.cs | 46 ++ .../Tensorflow.Hub.Unittest.csproj | 23 + test/TensorflowNET.Hub.Unittest/Usings.cs | 1 + 23 files changed, 1433 insertions(+), 24 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs create mode 100644 src/TensorflowNET.Hub/GcsCompressedFileResolver.cs create mode 100644 src/TensorflowNET.Hub/HttpCompressedFileResolver.cs create mode 100644 src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs create mode 100644 src/TensorflowNET.Hub/KerasLayer.cs create mode 100644 src/TensorflowNET.Hub/Tensorflow.Hub.csproj create mode 100644 src/TensorflowNET.Hub/file_utils.cs create mode 100644 src/TensorflowNET.Hub/hub.cs create mode 100644 src/TensorflowNET.Hub/module_v2.cs create mode 100644 src/TensorflowNET.Hub/registry.cs create mode 100644 src/TensorflowNET.Hub/resolver.cs create mode 100644 src/TensorflowNET.Hub/tf_utils.cs create mode 100644 test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs create mode 100644 test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj create mode 100644 test/TensorflowNET.Hub.Unittest/Usings.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 6357ec25e..d7b388769 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -153,6 +157,30 @@ Global {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU + {9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU + {7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 3d734cd15..b98495a32 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -207,9 +207,24 @@ private static void EnsureSingleTensor(Tensors tensors, string methodnName) } public override string ToString() - => items.Count() == 1 - ? items.First().ToString() - : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); + { + if(items.Count == 1) + { + return items[0].ToString(); + } + else + { + StringBuilder sb = new StringBuilder(); + sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); + for(int i = 0; i < items.Count; i++) + { + var tensor = items[i]; + sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); + } + sb.Append("]\n"); + return sb.ToString(); + } + } public void Dispose() { diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 3563f91a0..5b4db53b9 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -301,6 +301,17 @@ public static bool is_integer(this TF_DataType type) type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; } + public static bool is_unsigned(this TF_DataType type) + { + return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || + type == TF_DataType.TF_UINT64; + } + + public static bool is_bool(this TF_DataType type) + { + return type == TF_DataType.TF_BOOL; + } + public static bool is_floating(this TF_DataType type) { return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs b/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs index 703e7f23b..2925739bc 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs @@ -22,9 +22,9 @@ protected virtual IVariableV1 add_weight(string name, // If dtype is DT_FLOAT, provide a uniform unit scaling initializer if (dtype.is_floating()) initializer = tf.glorot_uniform_initializer; - else if (dtype.is_integer()) + else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) initializer = tf.zeros_initializer; - else + else if(getter is null) throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); } diff --git a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs index 044296814..9c82370a9 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasMetaData.cs @@ -36,5 +36,9 @@ public class KerasMetaData public bool? Stateful { get; set; } [JsonProperty("model_config")] public KerasModelConfig? ModelConfig { get; set; } + [JsonProperty("sparse")] + public bool Sparse { get; set; } + [JsonProperty("ragged")] + public bool Ragged { get; set; } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 41d1f0317..fee987294 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -401,13 +401,22 @@ private void _unblock_model_reconstruction(int layer_id, Layer layer) private (Trackable, Action) revive_custom_object(string identifier, KerasMetaData metadata) { - if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) + if (identifier == SavedModel.Constants.LAYER_IDENTIFIER) { return RevivedLayer.init_from_metadata(metadata); } + else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER + || identifier == SavedModel.Constants.NETWORK_IDENTIFIER) + { + return RevivedNetwork.init_from_metadata(metadata); + } + else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER) + { + return RevivedInputLayer.init_from_metadata(metadata); + } else { - throw new NotImplementedException(); + throw new ValueError($"Cannot revive the layer {identifier}."); } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs index 639d3aa06..e2cad8a37 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs @@ -1,15 +1,46 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; namespace Tensorflow.Keras.Saving.SavedModel { - public class RevivedInputLayer: Layer + public class RevivedInputLayer: InputLayer { - private RevivedInputLayer(): base(null) + protected RevivedConfig _config = null; + private RevivedInputLayer(InputLayerArgs args): base(args) { - throw new NotImplementedException(); + + } + + public override IKerasConfig get_config() + { + return _config; + } + + public static (RevivedInputLayer, Action) init_from_metadata(KerasMetaData metadata) + { + InputLayerArgs args = new InputLayerArgs() + { + Name = metadata.Name, + DType = metadata.DType, + Sparse = metadata.Sparse, + Ragged = metadata.Ragged, + BatchInputShape = metadata.BatchInputShape + }; + + RevivedInputLayer revived_obj = new RevivedInputLayer(args); + + revived_obj._config = new RevivedConfig() { Config = metadata.Config }; + + return (revived_obj, Loader.setattr); + } + + public override string ToString() + { + return $"Customized keras input layer: {Name}."; } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs index bca84a861..51e367ce8 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs @@ -53,7 +53,7 @@ public static (RevivedLayer, Action) init_from_metadata( return (revived_obj, ReviveUtils._revive_setter); } - private RevivedConfig _config = null; + protected RevivedConfig _config = null; public object keras_api { @@ -70,7 +70,7 @@ public object keras_api } } - public RevivedLayer(LayerArgs args): base(args) + protected RevivedLayer(LayerArgs args): base(args) { } @@ -84,17 +84,5 @@ public override IKerasConfig get_config() { return _config; } - - //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) - //{ - // if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) - // { - // return base.Call(inputs, state, training); - // } - // else - // { - // return (func as Function).Apply(inputs); - // } - //} } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs new file mode 100644 index 000000000..1860c8c75 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedNetwork: RevivedLayer + { + private RevivedNetwork(LayerArgs args) : base(args) + { + + } + + public static (RevivedNetwork, Action) init_from_metadata(KerasMetaData metadata) + { + RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name }); + + // TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj) + // TODO(Rinne): revived_obj._expects_training_arg + var config = metadata.Config; + if (generic_utils.validate_config(config)) + { + revived_obj._config = new RevivedConfig() { Config = config }; + } + if(metadata.ActivityRegularizer is not null) + { + throw new NotImplementedException(); + } + + return (revived_obj, ReviveUtils._revive_setter); + } + + public override string ToString() + { + return $"Customized keras Network: {Name}."; + } + } +} diff --git a/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs b/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs new file mode 100644 index 000000000..f3e1b9723 --- /dev/null +++ b/src/TensorflowNET.Hub/GcsCompressedFileResolver.cs @@ -0,0 +1,57 @@ +using System.IO; +using System.Threading.Tasks; + +namespace Tensorflow.Hub +{ + public class GcsCompressedFileResolver : IResolver + { + const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; + public string Call(string handle) + { + var module_dir = _module_dir(handle); + + return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC) + .GetAwaiter().GetResult(); + } + public bool IsSupported(string handle) + { + return handle.StartsWith("gs://") && _is_tarfile(handle); + } + + private async Task download(string handle, string tmp_dir) + { + new resolver.DownloadManager(handle).download_and_uncompress( + new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir); + await Task.Run(() => { }); + } + + private static string _module_dir(string handle) + { + var cache_dir = resolver.tfhub_cache_dir(use_temp: true); + var sha1 = ComputeSha1(handle); + return resolver.create_local_module_dir(cache_dir, sha1); + } + + private static bool _is_tarfile(string filename) + { + return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz"); + } + + private static string ComputeSha1(string s) + { + using (var sha = new System.Security.Cryptography.SHA1Managed()) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(s); + var hash = sha.ComputeHash(bytes); + var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); + + foreach (var b in hash) + { + stringBuilder.Append(b.ToString("x2")); + } + + return stringBuilder.ToString(); + } + } + } +} diff --git a/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs b/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs new file mode 100644 index 000000000..a127b28c0 --- /dev/null +++ b/src/TensorflowNET.Hub/HttpCompressedFileResolver.cs @@ -0,0 +1,78 @@ +using System; +using System.Net.Http; +using System.Threading.Tasks; + +namespace Tensorflow.Hub +{ + public class HttpCompressedFileResolver : HttpResolverBase + { + const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes + + private static readonly (string, string) _COMPRESSED_FORMAT_QUERY = + ("tf-hub-format", "compressed"); + + private static string _module_dir(string handle) + { + var cache_dir = resolver.tfhub_cache_dir(use_temp: true); + var sha1 = ComputeSha1(handle); + return resolver.create_local_module_dir(cache_dir, sha1); + } + + public override bool IsSupported(string handle) + { + if (!is_http_protocol(handle)) + { + return false; + } + var load_format = resolver.model_load_format(); + return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED) + || load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO); + } + + public override string Call(string handle) + { + var module_dir = _module_dir(handle); + + return resolver.atomic_download_async( + handle, + download, + module_dir, + LOCK_FILE_TIMEOUT_SEC + ).GetAwaiter().GetResult(); + } + + private async Task download(string handle, string tmp_dir) + { + var client = new HttpClient(); + + var response = await client.GetAsync(_append_compressed_format_query(handle)); + + using (var httpStream = await response.Content.ReadAsStreamAsync()) + { + new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir); + } + } + + private string _append_compressed_format_query(string handle) + { + return append_format_query(handle, _COMPRESSED_FORMAT_QUERY); + } + + private static string ComputeSha1(string s) + { + using (var sha = new System.Security.Cryptography.SHA1Managed()) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(s); + var hash = sha.ComputeHash(bytes); + var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); + + foreach (var b in hash) + { + stringBuilder.Append(b.ToString("x2")); + } + + return stringBuilder.ToString(); + } + } + } +} diff --git a/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs b/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs new file mode 100644 index 000000000..09a497484 --- /dev/null +++ b/src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs @@ -0,0 +1,65 @@ +using System; +using System.Net; + +namespace Tensorflow.Hub +{ + public class HttpUncompressedFileResolver : HttpResolverBase + { + private readonly PathResolver _pathResolver; + + public HttpUncompressedFileResolver() + { + _pathResolver = new PathResolver(); + } + + public override string Call(string handle) + { + handle = AppendUncompressedFormatQuery(handle); + var gsLocation = RequestGcsLocation(handle); + return _pathResolver.Call(gsLocation); + } + + public override bool IsSupported(string handle) + { + if (!is_http_protocol(handle)) + { + return false; + } + + var load_format = resolver.model_load_format(); + return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED); + } + + protected virtual string AppendUncompressedFormatQuery(string handle) + { + return append_format_query(handle, ("tf-hub-format", "uncompressed")); + } + + protected virtual string RequestGcsLocation(string handleWithParams) + { + var request = WebRequest.Create(handleWithParams); + var response = request.GetResponse() as HttpWebResponse; + + if (response == null) + { + throw new Exception("Failed to get a response from the server."); + } + + var statusCode = (int)response.StatusCode; + + if (statusCode != 303) + { + throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}"); + } + + var location = response.Headers["Location"]; + + if (!location.StartsWith("gs://")) + { + throw new Exception($"Expected Location:GS path but received {location}"); + } + + return location; + } + } +} \ No newline at end of file diff --git a/src/TensorflowNET.Hub/KerasLayer.cs b/src/TensorflowNET.Hub/KerasLayer.cs new file mode 100644 index 000000000..b9ca949bc --- /dev/null +++ b/src/TensorflowNET.Hub/KerasLayer.cs @@ -0,0 +1,157 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using static Tensorflow.Binding; + +namespace Tensorflow.Hub +{ + public class KerasLayer : Layer + { + private string _handle; + private LoadOptions? _load_options; + private Trackable _func; + private Func _callable; + + public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) : + base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable }) + { + _handle = handle; + _load_options = load_options; + + _func = load_module(_handle, _load_options); + _track_trackable(_func, "_func"); + // TODO(Rinne): deal with _is_hub_module_v1. + + _callable = _get_callable(); + _setup_layer(trainable); + } + + private void _setup_layer(bool trainable = false) + { + HashSet trainable_variables; + if (_func is Layer layer) + { + foreach (var v in layer.TrainableVariables) + { + _add_existing_weight(v, true); + } + trainable_variables = new HashSet(layer.TrainableVariables.Select(v => v.UniqueId)); + } + else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable trackables) + { + foreach (var trackable in trackables) + { + if (trackable is IVariableV1 v) + { + _add_existing_weight(v, true); + } + } + trainable_variables = new HashSet(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId)); + } + else + { + trainable_variables = new HashSet(); + } + + if (_func is Layer) + { + layer = (Layer)_func; + foreach (var v in layer.Variables) + { + if (!trainable_variables.Contains(v.UniqueId)) + { + _add_existing_weight(v, false); + } + } + } + else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable total_trackables) + { + foreach (var trackable in total_trackables) + { + if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId)) + { + _add_existing_weight(v, false); + } + } + } + + if (_func.CustomizedFields.ContainsKey("regularization_losses")) + { + if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0) + { + throw new NotImplementedException("The regularization_losses loading has not been supported yet, " + + "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature."); + } + } + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + _check_trainability(); + + // TODO(Rinne): deal with training_argument + + var result = _callable(inputs); + + return _apply_output_shape_if_set(inputs, result); + } + + private void _check_trainability() + { + if (!Trainable) return; + + // TODO(Rinne): deal with _is_hub_module_v1 and signature + + if (TrainableWeights is null || TrainableWeights.Count == 0) + { + tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights."); + } + } + + private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result) + { + // TODO(Rinne): implement it. + return result; + } + + private void _add_existing_weight(IVariableV1 weight, bool? trainable = null) + { + bool is_trainable; + if (trainable is null) + { + is_trainable = weight.Trainable; + } + else + { + is_trainable = trainable.Value; + } + add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight); + } + + private Func _get_callable() + { + if (_func is Layer layer) + { + return x => layer.Apply(x); + } + if (_func.CustomizedFields.ContainsKey("__call__")) + { + if (_func.CustomizedFields["__call__"] is RestoredFunction function) + { + return x => function.Apply(x); + } + } + throw new ValueError("Cannot get the callable from the model."); + } + + private static Trackable load_module(string handle, LoadOptions? load_options = null) + { + //var set_load_options = load_options ?? LoadContext.get_load_option(); + return module_v2.load(handle, load_options); + } + } +} diff --git a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj new file mode 100644 index 000000000..e179de69c --- /dev/null +++ b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj @@ -0,0 +1,17 @@ + + + + netstandard2.0;net6;net7 + 11 + enable + + + + + + + + + + + diff --git a/src/TensorflowNET.Hub/file_utils.cs b/src/TensorflowNET.Hub/file_utils.cs new file mode 100644 index 000000000..3e959afef --- /dev/null +++ b/src/TensorflowNET.Hub/file_utils.cs @@ -0,0 +1,74 @@ +using SharpCompress.Common; +using SharpCompress.Readers; +using System; +using System.IO; + +namespace Tensorflow.Hub +{ + internal static class file_utils + { + //public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action logFunction = null) + //{ + // using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null) + // { + // if (src is null) + // { + // return; + // } + + // using (var dst = File.Create(dstPath)) + // { + // var buffer = new byte[bufferSize]; + // int count; + + // while ((count = src.Read(buffer, 0, buffer.Length)) > 0) + // { + // dst.Write(buffer, 0, count); + // logFunction?.Invoke(count); + // } + // } + // } + //} + + public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action logFunction = null) + { + using (IReader reader = ReaderFactory.Open(fileobj)) + { + while (reader.MoveToNextEntry()) + { + if (!reader.Entry.IsDirectory) + { + reader.WriteEntryToDirectory( + dst_path, + new ExtractionOptions() { ExtractFullPath = true, Overwrite = true } + ); + } + } + } + } + + public static string merge_relative_path(string dstPath, string relPath) + { + var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\'); + + if (cleanRelPath == ".") + { + return dstPath; + } + + if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath)) + { + throw new InvalidDataException($"Relative path '{relPath}' is invalid."); + } + + var merged = Path.Combine(dstPath, cleanRelPath); + + if (!merged.StartsWith(dstPath)) + { + throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'."); + } + + return merged; + } + } +} diff --git a/src/TensorflowNET.Hub/hub.cs b/src/TensorflowNET.Hub/hub.cs new file mode 100644 index 000000000..4fefe0cc2 --- /dev/null +++ b/src/TensorflowNET.Hub/hub.cs @@ -0,0 +1,17 @@ +using Tensorflow.Hub; + +namespace Tensorflow +{ + public static class HubAPI + { + public static HubMethods hub { get; } = new HubMethods(); + } + + public class HubMethods + { + public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) + { + return new KerasLayer(handle, trainable, load_options); + } + } +} diff --git a/src/TensorflowNET.Hub/module_v2.cs b/src/TensorflowNET.Hub/module_v2.cs new file mode 100644 index 000000000..a8e67311b --- /dev/null +++ b/src/TensorflowNET.Hub/module_v2.cs @@ -0,0 +1,33 @@ +using System.IO; +using Tensorflow.Train; + +namespace Tensorflow.Hub +{ + internal static class module_v2 + { + public static Trackable load(string handle, LoadOptions? options) + { + var module_path = resolve(handle); + + // TODO(Rinne): deal with is_hub_module_v1 + + var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB); + var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT); + if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path) + && !Directory.Exists(saved_model_pb_txt_path)) + { + throw new ValueError($"Trying to load a model of incompatible/unknown type. " + + $"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " + + $"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'."); + } + + var obj = Loader.load(module_path, options: options); + return obj; + } + + public static string resolve(string handle) + { + return MultiImplRegister.GetResolverRegister().Call(handle); + } + } +} diff --git a/src/TensorflowNET.Hub/registry.cs b/src/TensorflowNET.Hub/registry.cs new file mode 100644 index 000000000..cdc4589b2 --- /dev/null +++ b/src/TensorflowNET.Hub/registry.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Hub +{ + internal class MultiImplRegister + { + private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]); + private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]); + + static MultiImplRegister() + { + resolver.add_implementation(new PathResolver()); + resolver.add_implementation(new HttpUncompressedFileResolver()); + resolver.add_implementation(new GcsCompressedFileResolver()); + resolver.add_implementation(new HttpCompressedFileResolver()); + } + + string _name; + List _impls; + public MultiImplRegister(string name, IEnumerable impls) + { + _name = name; + _impls = impls.ToList(); + } + + public void add_implementation(IResolver resolver) + { + _impls.Add(resolver); + } + + public string Call(string handle) + { + foreach (var impl in _impls.Reverse()) + { + if (impl.IsSupported(handle)) + { + return impl.Call(handle); + } + } + throw new RuntimeError($"Cannot resolve the handle {handle}"); + } + + public static MultiImplRegister GetResolverRegister() + { + return resolver; + } + + public static MultiImplRegister GetLoaderRegister() + { + return loader; + } + } +} diff --git a/src/TensorflowNET.Hub/resolver.cs b/src/TensorflowNET.Hub/resolver.cs new file mode 100644 index 000000000..2f8c45ba6 --- /dev/null +++ b/src/TensorflowNET.Hub/resolver.cs @@ -0,0 +1,580 @@ +using ICSharpCode.SharpZipLib.Tar; +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Net.Security; +using System.Security.Authentication; +using System.Threading.Tasks; +using System.Web; +using static Tensorflow.Binding; + +namespace Tensorflow.Hub +{ + internal static class resolver + { + public enum ModelLoadFormat + { + [Description("COMPRESSED")] + COMPRESSED, + [Description("UNCOMPRESSED")] + UNCOMPRESSED, + [Description("AUTO")] + AUTO + } + public class DownloadManager + { + private readonly string _url; + private double _last_progress_msg_print_time; + private long _total_bytes_downloaded; + private int _max_prog_str; + + private bool _interactive_mode() + { + return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS")); + } + + private void _print_download_progress_msg(string msg, bool flush = false) + { + if (_interactive_mode()) + { + // Print progress message to console overwriting previous progress + // message. + _max_prog_str = Math.Max(_max_prog_str, msg.Length); + Console.Write($"\r{msg.PadRight(_max_prog_str)}"); + Console.Out.Flush(); + + //如果flush参数为true,则输出换行符减少干扰交互式界面。 + if (flush) + Console.WriteLine(); + + } + else + { + // Interactive progress tracking is disabled. Print progress to the + // standard TF log. + tf.Logger.Information(msg); + } + } + + private void _log_progress(long bytes_downloaded) + { + // Logs progress information about ongoing module download. + + _total_bytes_downloaded += bytes_downloaded; + var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; + if (_interactive_mode() || now - _last_progress_msg_print_time > 15) + { + // Print progress message every 15 secs or if interactive progress + // tracking is enabled. + _print_download_progress_msg($"Downloading {_url}:" + + $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}"); + _last_progress_msg_print_time = now; + } + } + + public DownloadManager(string url) + { + _url = url; + _last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; + _total_bytes_downloaded = 0; + _max_prog_str = 0; + } + + public void download_and_uncompress(Stream fileobj, string dst_path) + { + // Streams the content for the 'fileobj' and stores the result in dst_path. + + try + { + file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress); + var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true); + _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true); + } + catch (TarException ex) + { + throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex); + } + } + } + private static Dictionary _flags = new(); + private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"; + private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"; + private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT"; + private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION"; + private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true"; + + static resolver() + { + set_new_flag("tfhub_model_load_format", "AUTO"); + set_new_flag("tfhub_cache_dir", null); + } + + public static string model_load_format() + { + return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format"); + } + + public static string? get_env_setting(string env_var, string flag_name) + { + string value = System.Environment.GetEnvironmentVariable(env_var); + if (string.IsNullOrEmpty(value)) + { + if (_flags.ContainsKey(flag_name)) + { + return _flags[flag_name]; + } + else + { + return null; + } + } + else + { + return value; + } + } + + public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false) + { + var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir; + if (string.IsNullOrWhiteSpace(cache_dir) && use_temp) + { + // Place all TF-Hub modules under /tfhub_modules. + cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules"); + } + if (!string.IsNullOrWhiteSpace(cache_dir)) + { + Console.WriteLine("Using {0} to cache modules.", cache_dir); + } + return cache_dir; + } + + public static string create_local_module_dir(string cache_dir, string module_name) + { + Directory.CreateDirectory(cache_dir); + return Path.Combine(cache_dir, module_name); + } + + public static void set_new_flag(string name, string value) + { + string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION, + _TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT}; + if (!tokens.Contains(name)) + { + tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" + + "may not affect anything in tensorflow.hub."); + } + _flags[name] = value; + } + + public static string _merge_relative_path(string dstPath, string relPath) + { + return file_utils.merge_relative_path(dstPath, relPath); + } + + public static string _module_descriptor_file(string moduleDir) + { + return $"{moduleDir}.descriptor.txt"; + } + + public static void _write_module_descriptor_file(string handle, string moduleDir) + { + var readme = _module_descriptor_file(moduleDir); + var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})"; + tf_utils.atomic_write_string_to_file(readme, content, overwrite: true); + } + + public static string _lock_file_contents(string task_uid) + { + return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}"; + } + + public static string _lock_filename(string moduleDir) + { + return tf_utils.absolute_path(moduleDir) + ".lock"; + } + + private static string _module_dir(string lockFilename) + { + var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename)); + if (!string.IsNullOrEmpty(path)) + { + return Path.Combine(path, "hub_modules"); + } + + throw new Exception("Unable to resolve hub_modules directory from lock file name."); + } + + private static string _task_uid_from_lock_file(string lockFilename) + { + // Returns task UID of the task that created a given lock file. + var lockstring = File.ReadAllText(lockFilename); + return lockstring.Split('.').Last(); + } + + private static string _temp_download_dir(string moduleDir, string taskUid) + { + // Returns the name of a temporary directory to download module to. + return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp"; + } + + private static long _dir_size(string directory) + { + // Returns total size (in bytes) of the given 'directory'. + long size = 0; + foreach (var elem in Directory.EnumerateFileSystemEntries(directory)) + { + var stat = new FileInfo(elem); + size += stat.Length; + if ((stat.Attributes & FileAttributes.Directory) != 0) + size += _dir_size(stat.FullName); + } + return size; + } + + public static long _locked_tmp_dir_size(string lockFilename) + { + //Returns the size of the temp dir pointed to by the given lock file. + var taskUid = _task_uid_from_lock_file(lockFilename); + try + { + return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid)); + } + catch (DirectoryNotFoundException) + { + return 0; + } + } + + private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec) + { + long? lockedTmpDirSize = null; + var lockedTmpDirSizeCheckTime = DateTime.Now; + var lockFileContent = ""; + + while (File.Exists(lockFile)) + { + try + { + Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting."); + + if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec) + { + var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile); + var curLockFileContent = File.ReadAllText(lockFile); + + if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent) + { + Console.WriteLine($"Deleting lock file {lockFile} due to inactivity."); + File.Delete(lockFile); + break; + } + + lockedTmpDirSize = curLockedTmpDirSize; + lockedTmpDirSizeCheckTime = DateTime.Now; + lockFileContent = curLockFileContent; + } + } + catch (FileNotFoundException) + { + // Lock file or temp directory were deleted during check. Continue + // to check whether download succeeded or we need to start our own + // download. + } + + System.Threading.Thread.Sleep(5000); + } + } + + public static async Task atomic_download_async( + string handle, + Func downloadFn, + string moduleDir, + int lock_file_timeout_sec = 10 * 60) + { + var lockFile = _lock_filename(moduleDir); + var taskUid = Guid.NewGuid().ToString("N"); + var lockContents = _lock_file_contents(taskUid); + var tmpDir = _temp_download_dir(moduleDir, taskUid); + + // Function to check whether model has already been downloaded. + Func checkModuleExists = () => + Directory.Exists(moduleDir) && + Directory.EnumerateFileSystemEntries(moduleDir).Any(); + + // Check whether the model has already been downloaded before locking + // the destination path. + if (checkModuleExists()) + { + return moduleDir; + } + + // Attempt to protect against cases of processes being cancelled with + // KeyboardInterrupt by using a try/finally clause to remove the lock + // and tmp_dir. + while (true) + { + try + { + tf_utils.atomic_write_string_to_file(lockFile, lockContents, false); + // Must test condition again, since another process could have created + // the module and deleted the old lock file since last test. + if (checkModuleExists()) + { + // Lock file will be deleted in the finally-clause. + return moduleDir; + } + if (Directory.Exists(moduleDir)) + { + Directory.Delete(moduleDir, true); + } + break; // Proceed to downloading the module. + } + // These errors are believed to be permanent problems with the + // module_dir that justify failing the download. + catch (FileNotFoundException) + { + throw; + } + catch (UnauthorizedAccessException) + { + throw; + } + catch (IOException) + { + throw; + } + // All other errors are retried. + // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write + // should be good enough, but see discussion about misc filesystem types. + // TODO(b/144475403): How atomic is the overwrite=False check? + catch (Exception) + { + } + + // Wait for lock file to disappear. + _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec); + // At this point we either deleted a lock or a lock got removed by the + // owner or another process. Perform one more iteration of the while-loop, + // we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or + // because we would obtain a lock ourselves, or wait again for the lock to + // disappear. + } + + // Lock file acquired. + tf.Logger.Information($"Downloading TF-Hub Module '{handle}'..."); + Directory.CreateDirectory(tmpDir); + await downloadFn(handle, tmpDir); + // Write module descriptor to capture information about which module was + // downloaded by whom and when. The file stored at the same level as a + // directory in order to keep the content of the 'model_dir' exactly as it + // was define by the module publisher. + // + // Note: The descriptor is written purely to help the end-user to identify + // which directory belongs to which module. The descriptor is not part of the + // module caching protocol and no code in the TF-Hub library reads its + // content. + _write_module_descriptor_file(handle, moduleDir); + try + { + Directory.Move(tmpDir, moduleDir); + Console.WriteLine($"Downloaded TF-Hub Module '{handle}'."); + } + catch (IOException e) + { + Console.WriteLine(e.Message); + Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}"); + // Keep the temp directory so we will retry building vocabulary later. + } + + // Temp directory is owned by the current process, remove it. + try + { + Directory.Delete(tmpDir, true); + } + catch (DirectoryNotFoundException) + { + } + + // Lock file exists and is owned by this process. + try + { + var contents = File.ReadAllText(lockFile); + if (contents == lockContents) + { + File.Delete(lockFile); + } + } + catch (Exception) + { + } + + return moduleDir; + } + } + internal interface IResolver + { + string Call(string handle); + bool IsSupported(string handle); + } + + internal class PathResolver : IResolver + { + public string Call(string handle) + { + if (!File.Exists(handle) && !Directory.Exists(handle)) + { + throw new IOException($"{handle} does not exist in file system."); + } + return handle; + } + public bool IsSupported(string handle) + { + return true; + } + } + + public abstract class HttpResolverBase : IResolver + { + private readonly HttpClient httpClient; + private SslProtocol sslProtocol; + private RemoteCertificateValidationCallback certificateValidator; + + protected HttpResolverBase() + { + httpClient = new HttpClient(); + _maybe_disable_cert_validation(); + } + + public abstract string Call(string handle); + public abstract bool IsSupported(string handle); + + protected async Task GetLocalFileStreamAsync(string filePath) + { + try + { + var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); + return await Task.FromResult(fs); + } + catch (Exception ex) + { + Console.WriteLine($"Failed to read file stream: {ex.Message}"); + return null; + } + } + + protected async Task GetFileStreamAsync(string filePath) + { + if (!is_http_protocol(filePath)) + { + // If filePath is not an HTTP(S) URL, delegate to a file resolver. + return await GetLocalFileStreamAsync(filePath); + } + + var request = new HttpRequestMessage(HttpMethod.Get, filePath); + var response = await _call_urlopen(request); + + if (response.IsSuccessStatusCode) + { + return await response.Content.ReadAsStreamAsync(); + } + else + { + Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}"); + return null; + } + } + + protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator) + { + sslProtocol = protocol; + certificateValidator = validator; + } + + public static string append_format_query(string handle, (string, string) formatQuery) + { + var parsed = new Uri(handle); + + var queryBuilder = HttpUtility.ParseQueryString(parsed.Query); + queryBuilder.Add(formatQuery.Item1, formatQuery.Item2); + + parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath, + "?" + queryBuilder.ToString()).Uri; + + return parsed.ToString(); + } + + protected bool is_http_protocol(string handle) + { + return handle.StartsWith("http://") || handle.StartsWith("https://"); + } + + protected async Task _call_urlopen(HttpRequestMessage request) + { + if (sslProtocol != null) + { + var handler = new HttpClientHandler() + { + SslProtocols = sslProtocol.AsEnum(), + }; + if (certificateValidator != null) + { + handler.ServerCertificateCustomValidationCallback = (x, y, z, w) => + { + return certificateValidator(x, y, z, w); + }; + } + + var client = new HttpClient(handler); + return await client.SendAsync(request); + } + else + { + return await httpClient.SendAsync(request); + } + } + + protected void _maybe_disable_cert_validation() + { + if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE") + { + ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true; + Console.WriteLine("Disabled certificate validation for resolving handles."); + } + } + } + + public class SslProtocol + { + private readonly string protocolString; + + public static readonly SslProtocol Tls = new SslProtocol("TLS"); + public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1"); + public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2"); + + private SslProtocol(string protocolString) + { + this.protocolString = protocolString; + } + + public SslProtocols AsEnum() + { + switch (protocolString.ToUpper()) + { + case "TLS": + return SslProtocols.Tls; + case "TLS 1.1": + return SslProtocols.Tls11; + case "TLS 1.2": + return SslProtocols.Tls12; + default: + throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"); + } + } + } +} diff --git a/src/TensorflowNET.Hub/tf_utils.cs b/src/TensorflowNET.Hub/tf_utils.cs new file mode 100644 index 000000000..96d8c92d6 --- /dev/null +++ b/src/TensorflowNET.Hub/tf_utils.cs @@ -0,0 +1,80 @@ +using System; +using System.IO; + +namespace Tensorflow.Hub +{ + internal class tf_utils + { + public static string bytes_to_readable_str(long? numBytes, bool includeB = false) + { + if (numBytes == null) return numBytes.ToString(); + + var num = (double)numBytes; + + if (num < 1024) + { + return $"{(long)num}{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + if (num < 1024) + { + return $"{num:F2}k{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + if (num < 1024) + { + return $"{num:F2}M{(includeB ? "B" : "")}"; + } + + num /= 1 << 10; + return $"{num:F2}G{(includeB ? "B" : "")}"; + } + + public static void atomic_write_string_to_file(string filename, string contents, bool overwrite) + { + var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}"; + + using (var fileStream = new FileStream(tempPath, FileMode.Create)) + { + using (var writer = new StreamWriter(fileStream)) + { + writer.Write(contents); + writer.Flush(); + } + } + + try + { + if (File.Exists(filename)) + { + if (overwrite) + { + File.Delete(filename); + File.Move(tempPath, filename); + } + } + else + { + File.Move(tempPath, filename); + } + } + catch + { + File.Delete(tempPath); + throw; + } + } + + public static string absolute_path(string path) + { + if (path.Contains("://")) + { + return path; + } + + return Path.GetFullPath(path); + } + } +} diff --git a/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs b/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs new file mode 100644 index 000000000..4ee4d54c4 --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs @@ -0,0 +1,46 @@ +using static Tensorflow.Binding; +using static Tensorflow.HubAPI; + +namespace Tensorflow.Hub.Unittest +{ + [TestClass] + public class KerasLayerTest + { + [TestMethod] + public void SmallBert() + { + var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1"); + + var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); + input_type_ids = tf.reshape(input_type_ids, (1, 128)); + var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); + input_word_ids = tf.reshape(input_word_ids, (1, 128)); + var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32); + input_mask = tf.reshape(input_mask, (1, 128)); + + var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask)); + } + + } +} \ No newline at end of file diff --git a/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj new file mode 100644 index 000000000..67c72f54e --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj @@ -0,0 +1,23 @@ + + + + net7 + enable + enable + + false + + + + + + + + + + + + + + + diff --git a/test/TensorflowNET.Hub.Unittest/Usings.cs b/test/TensorflowNET.Hub.Unittest/Usings.cs new file mode 100644 index 000000000..ab67c7ea9 --- /dev/null +++ b/test/TensorflowNET.Hub.Unittest/Usings.cs @@ -0,0 +1 @@ +global using Microsoft.VisualStudio.TestTools.UnitTesting; \ No newline at end of file