Skip to content

Commit bdf229a

Browse files
committed
Renmae to AssetResource.
1 parent bd154e8 commit bdf229a

File tree

7 files changed

+73
-30
lines changed

7 files changed

+73
-30
lines changed

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ public static NDArray MakeNdarray(TensorProto tensor)
8080
{
8181
return np.array(tensor.IntVal.ToArray()).reshape(shape);
8282
}
83+
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
84+
{
85+
return np.array(tensor.Int64Val.ToArray()).reshape(shape);
86+
}
87+
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
88+
{
89+
return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
90+
}
8391
else if (tensor.Dtype == DataType.DtBool)
8492
{
8593
return np.array(tensor.BoolVal.ToArray()).reshape(shape);

src/TensorFlowNET.Core/Trackables/Asset.cs

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using Google.Protobuf.Collections;
2+
using System.IO;
3+
using Tensorflow.Train;
4+
5+
namespace Tensorflow.Trackables;
6+
7+
public class AssetResource : Trackable
8+
{
9+
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
10+
string export_dir,
11+
RepeatedField<AssetFileDef> asset_file_def,
12+
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
13+
{
14+
var proto = object_proto.Asset;
15+
var filename = Path.Combine(export_dir, asset_file_def[proto.AssetFileDefIndex].Filename);
16+
return (new AssetResource(), null);
17+
}
18+
}
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
using System.Runtime.CompilerServices;
1+
using Google.Protobuf.Collections;
22
using Tensorflow.Train;
33

44
namespace Tensorflow.Trackables;
55

66
public class RestoredResource : TrackableResource
77
{
8-
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
8+
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
9+
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
910
{
10-
return (null, null);
11+
return (new RestoredResource(), null);
1112
}
1213
}
Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1-
using Tensorflow.Train;
1+
using Google.Protobuf.Collections;
2+
using Tensorflow.Train;
23

34
namespace Tensorflow.Trackables;
45

56
public class TrackableConstant : Trackable
67
{
7-
public static (Trackable, Action<object, object, object>) deserialize_from_proto()
8+
Tensor _constant;
9+
public TrackableConstant(Tensor constant)
810
{
9-
return (null, null);
11+
_constant = constant;
12+
}
13+
14+
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
15+
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
16+
{
17+
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
18+
var ndarray = tensor_util.MakeNdarray(tensor_proto);
19+
var imported_constant = constant_op.constant(ndarray);
20+
return (new TrackableConstant(imported_constant), null);
1021
}
1122
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@ namespace Tensorflow.Training.Saving.SavedModel
99
{
1010
public static class function_deserialization
1111
{
12+
/// <summary>
13+
/// Creates a `Function` from a `SavedFunction`.
14+
/// </summary>
15+
/// <param name="saved_concrete_function"></param>
16+
/// <param name="concrete_functions"></param>
17+
/// <returns></returns>
18+
public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function,
19+
IDictionary<string, ConcreteFunction> concrete_functions)
20+
{
21+
var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec);
22+
return null;
23+
}
24+
1225
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
1326
IDictionary<string, ConcreteFunction> concrete_functions)
1427
{

src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,6 @@ private void _load_nodes()
387387
}
388388
else
389389
{
390-
// skip the function and concrete function.
391-
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
392-
{
393-
nodes[node_id] = null;
394-
node_setters[node_id] = null;
395-
continue;
396-
}
397390
var (node, setter) = _recreate(proto, node_id, nodes);
398391
nodes[node_id] = node;
399392
node_setters[node_id] = setter;
@@ -471,6 +464,11 @@ private void _load_edges()
471464
}
472465
}
473466

467+
private void _setup_function_captures()
468+
{
469+
// TODO: implement it with concrete functions.
470+
}
471+
474472
private void _setup_remaining_functions()
475473
{
476474
// TODO: implement it with concrete functions.
@@ -542,9 +540,9 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
542540

543541
return proto.KindCase switch
544542
{
545-
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(),
546-
SavedObject.KindOneofCase.Asset => Asset.deserialize_from_proto(),
547-
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(),
543+
SavedObject.KindOneofCase.Resource => RestoredResource.deserialize_from_proto(proto, _operation_attributes),
544+
SavedObject.KindOneofCase.Asset => AssetResource.deserialize_from_proto(proto, _export_dir, _asset_file_def, _operation_attributes),
545+
SavedObject.KindOneofCase.Constant => TrackableConstant.deserialize_from_proto(proto, _operation_attributes),
548546
_ => _recreate_default(proto, node_id, dependencies)
549547
};
550548
}
@@ -563,7 +561,8 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
563561
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
564562
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
565563
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
566-
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException()
564+
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
565+
_ => throw new NotImplementedException()
567566
};
568567
}
569568

@@ -623,8 +622,12 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
623622
private (ConcreteFunction, Action<object, object, object>) _recreate_function(SavedFunction proto,
624623
Dictionary<Maybe<string, int>, Trackable> dependencies)
625624
{
626-
throw new NotImplementedException();
627-
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
625+
var fn = function_deserialization.recreate_function(proto, null);
626+
foreach (var name in proto.ConcreteFunctions)
627+
{
628+
_setup_function_captures();
629+
}
630+
return (fn, setattr);
628631
}
629632

630633
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,

0 commit comments

Comments
 (0)