Skip to content

Commit 4f29e10

Browse files
authored
Merge pull request #1022 from AsakusaRinne/support_function_load
Partially Support the function loading
2 parents a075bba + a59ebae commit 4f29e10

File tree

189 files changed

+73815
-1950
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

189 files changed

+73815
-1950
lines changed

TensorFlow.NET.sln

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
Microsoft Visual Studio Solution File, Format Version 12.00
3-
# Visual Studio Version 16
4-
VisualStudioVersion = 16.0.31624.102
3+
# Visual Studio Version 17
4+
VisualStudioVersion = 17.4.33213.308
55
MinimumVisualStudioVersion = 10.0.40219.1
66
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
77
EndProject
@@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
2323
EndProject
2424
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
2525
EndProject
26+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}"
27+
EndProject
2628
Global
2729
GlobalSection(SolutionConfigurationPlatforms) = preSolution
2830
Debug|Any CPU = Debug|Any CPU
@@ -153,6 +155,18 @@ Global
153155
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
154156
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
155157
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
158+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
159+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU
160+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU
161+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU
162+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU
163+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU
164+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU
165+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU
166+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU
167+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU
168+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU
169+
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU
156170
EndGlobalSection
157171
GlobalSection(SolutionProperties) = preSolution
158172
HideSolutionNode = FALSE
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class DictionaryExtension
9+
{
10+
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second)
11+
{
12+
first = pair.Key;
13+
second = pair.Value;
14+
}
15+
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other)
16+
{
17+
foreach(var (key, value) in other)
18+
{
19+
dic[key] = value;
20+
}
21+
}
22+
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue)
23+
{
24+
if (dic.ContainsKey(key))
25+
{
26+
return dic[key];
27+
}
28+
return defaultValue;
29+
}
30+
}
31+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using OneOf;
2+
using System;
3+
4+
namespace Tensorflow.Common.Extensions
5+
{
6+
public static class OneofExtension
7+
{
8+
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src)
9+
{
10+
return src.Value is T;
11+
}
12+
}
13+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
</PropertyGroup>
6+
7+
<ItemGroup>
8+
<PackageReference Include="OneOf" Version="3.0.223" />
9+
</ItemGroup>
10+
11+
</Project>

Tensorflow.Common/Types/NamedTuple.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Types
7+
{
8+
public class NamedTuple
9+
{
10+
public string Name { get; set; }
11+
public Dictionary<string, object> ValueDict { get; set; }
12+
}
13+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class c_api
9+
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
12+
[DllImport(TensorFlowLibName)]
13+
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
14+
[DllImport(TensorFlowLibName)]
15+
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
16+
}
17+
}

src/TensorFlowNET.Core/APIs/tf.compat.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Google.Protobuf;
1718
using System.Text;
1819

1920
namespace Tensorflow
@@ -45,6 +46,23 @@ internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
4546
{
4647
return as_text(bytes_or_text, encoding);
4748
}
49+
50+
public ByteString as_bytes(ByteString bytes, Encoding encoding = null)
51+
{
52+
return bytes;
53+
}
54+
public ByteString as_bytes(byte[] bytes, Encoding encoding = null)
55+
{
56+
return ByteString.CopyFrom(bytes);
57+
}
58+
public ByteString as_bytes(string text, Encoding encoding = null)
59+
{
60+
if(encoding is null)
61+
{
62+
encoding = Encoding.UTF8;
63+
}
64+
return ByteString.CopyFrom(encoding.GetBytes(text));
65+
}
4866
}
4967

5068
public bool executing_eagerly()

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
5454
Dictionary<string, Tensor> input_map = null,
5555
string[] return_elements = null,
5656
string name = null,
57-
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
57+
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list);
5858
}
5959
}

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Operations;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
@@ -79,5 +81,10 @@ public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7981
num_split: num_split,
8082
axis: axis,
8183
name: name);
84+
85+
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
86+
{
87+
return gen_ops.ensure_shape(x, shape, name);
88+
}
8289
}
8390
}

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public partial class c_api
6161
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);
6262

6363
[DllImport(TensorFlowLibName)]
64-
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
64+
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);
6565

6666
/// <summary>
6767
/// Set `num_dims` to -1 to represent "unknown rank".

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
using System.Diagnostics;
2323
using System.IO;
2424
using System.Linq;
25+
using Tensorflow.Operations;
2526

2627
namespace Tensorflow
2728
{

src/TensorFlowNET.Core/Buffers/Buffer.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ public unsafe byte[] ToArray()
107107
}
108108
}
109109

110+
public void Release()
111+
{
112+
_handle.Dispose();
113+
_handle = null;
114+
}
115+
110116
public override string ToString()
111117
=> $"0x{_handle.DangerousGetHandle():x16}";
112118

src/TensorFlowNET.Core/Buffers/TF_Buffer.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,32 @@ public struct TF_Buffer
2525
public IntPtr data;
2626
public ulong length;
2727
public IntPtr data_deallocator;
28+
29+
public unsafe Span<T> AsSpan<T>() where T: unmanaged
30+
{
31+
if(length > int.MaxValue)
32+
{
33+
throw new ValueError($"The length {length} is too large to use in the span.");
34+
}
35+
return new Span<T>(data.ToPointer(), (int)length);
36+
}
37+
38+
public unsafe byte[] ToByteArray()
39+
{
40+
byte[] res = new byte[length];
41+
if(length > int.MaxValue)
42+
{
43+
byte* root = (byte*)data;
44+
for(ulong i = 0; i < length; i++)
45+
{
46+
res[i] = *(root++);
47+
}
48+
}
49+
else
50+
{
51+
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
52+
}
53+
return res;
54+
}
2855
}
2956
}

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public static IList<Trackable> list_objects(ObjectGraphView graph_view)
161161

162162
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
163163
{
164-
return full_list.TakeWhile(x =>
164+
return full_list.Where(x =>
165165
{
166166
var saveables = x.gather_saveables_for_checkpoint();
167167
return saveables is not null && saveables.Count > 0;

src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
using System;
1+
using OneOf;
2+
using System;
23
using System.Collections.Generic;
34
using System.Diagnostics;
45
using System.Linq;
56
using System.Text;
67
using Tensorflow.Train;
78
using Tensorflow.Training;
9+
using Tensorflow.Common.Extensions;
810
using pbc = global::Google.Protobuf.Collections;
911

1012
namespace Tensorflow.Checkpoint
@@ -28,7 +30,7 @@ Trackable object_to_save
2830
);
2931
public static class SaveUtil
3032
{
31-
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
33+
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
3234
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
3335
{
3436
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -104,7 +106,10 @@ private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData>
104106
{
105107
var td = trackable_data[i];
106108
Debug.Assert(td.node_id == i);
107-
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
109+
TrackableObjectGraph.Types.TrackableObject trackable_object = new();
110+
trackable_object.SlotVariables.AddRange(td.slot_variable_proto);
111+
trackable_object.Children.AddRange(td.children_proto);
112+
object_graph_proto.Nodes.Add(trackable_object);
108113
}
109114
return object_graph_proto;
110115
}
@@ -117,16 +122,16 @@ private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData>
117122
/// <param name="call_with_mapped_captures"></param>
118123
/// <param name="cache"></param>
119124
/// <param name="object_graph_proto"></param>
120-
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
125+
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
121126
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
122127
{
123-
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
128+
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
124129
foreach(var td in tensor_trackables)
125130
{
126131
// TODO: deal with cache.
127132
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
128133
Trackable trackable = null;
129-
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
134+
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict;
130135
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
131136
{
132137
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -148,12 +153,12 @@ private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDiction
148153
return serialized_tensors;
149154
}
150155

151-
private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
156+
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
152157
{
153158
var trackable = trackable_data.object_to_save;
154159

155160
// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
156-
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
161+
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict;
157162
if (call_with_mapped_captures)
158163
{
159164
throw new NotImplementedException();
@@ -163,8 +168,7 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g
163168
ret_tensor_dict = trackable.serialize_to_tensors();
164169
}
165170

166-
// TODO: deal with the type `SaveSpce` (currently it will never be it).
167-
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
171+
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new();
168172
foreach(var pair in ret_tensor_dict)
169173
{
170174
var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -173,10 +177,12 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g
173177

174178
tensor_dict[checkpoint_key] = maybe_tensor;
175179

176-
if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
180+
foreach(var key in maybe_tensor.Keys)
177181
{
178-
throw new NotImplementedException();
179-
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
182+
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>())
183+
{
184+
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name;
185+
}
180186
}
181187

182188
if(object_graph_proto is not null)
@@ -200,7 +206,7 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g
200206
/// <param name="call_with_mapped_captures"></param>
201207
/// <param name="object_graph_proto"></param>
202208
/// <returns></returns>
203-
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
209+
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
204210
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
205211
{
206212
Dictionary<Trackable, string> object_names = new();

0 commit comments

Comments
 (0)