Skip to content

Commit b680700

Browse files
committed
fix import_meta_graph without VariableV1 bug. #453
1 parent 252543f commit b680700

File tree

5 files changed

+37
-16
lines changed

5 files changed

+37
-16
lines changed

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,22 @@ private static void add_collection_def(MetaGraphDef meta_graph_def,
268268

269269
switch (graph.get_collection(key))
270270
{
271+
case List<VariableV1> collection_list:
272+
col_def.BytesList = new Types.BytesList();
273+
foreach (var x in collection_list)
274+
{
275+
if(x is RefVariable x_ref_var)
276+
{
277+
var proto = x_ref_var.to_proto(export_scope);
278+
col_def.BytesList.Value.Add(proto.ToByteString());
279+
}
280+
else
281+
{
282+
Console.WriteLine($"Can't find to_proto method for type {x.GetType().Name}");
283+
}
284+
}
285+
286+
break;
271287
case List<RefVariable> collection_list:
272288
col_def.BytesList = new Types.BytesList();
273289
foreach (var x in collection_list)

src/TensorFlowNET.Core/Operations/Operation.Control.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ public void _add_control_inputs(Operation[] ops)
5252

5353
public void _set_control_flow_context(ControlFlowContext ctx)
5454
{
55-
if (name.Contains("gradients/rnn/while/basic_rnn_cell/Tanh_grad/TanhGrad/f_acc"))
56-
{
57-
58-
}
5955
_control_flow_context = ctx;
6056
}
6157

src/TensorFlowNET.Core/Training/Saving/Saver.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using NumSharp;
1718
using System;
1819
using System.Collections.Generic;
1920
using System.IO;
@@ -170,7 +171,7 @@ public string save(Session sess,
170171
{
171172
if (string.IsNullOrEmpty(latest_filename))
172173
latest_filename = "checkpoint";
173-
object model_checkpoint_path = "";
174+
NDArray[] model_checkpoint_path = null;
174175
string checkpoint_file = "";
175176

176177
if (global_step > 0)
@@ -183,15 +184,14 @@ public string save(Session sess,
183184
if (!_is_empty)
184185
{
185186
model_checkpoint_path = sess.run(_saver_def.SaveTensorName,
186-
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
187-
);
187+
(_saver_def.FilenameTensorName, checkpoint_file));
188188

189189
if (write_state)
190190
{
191-
_RecordLastCheckpoint(model_checkpoint_path.ToString());
191+
_RecordLastCheckpoint(model_checkpoint_path[0].ToString());
192192
checkpoint_management.update_checkpoint_state_internal(
193193
save_dir: save_path_parent,
194-
model_checkpoint_path: model_checkpoint_path.ToString(),
194+
model_checkpoint_path: model_checkpoint_path[0].ToString(),
195195
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(),
196196
latest_filename: latest_filename,
197197
save_relative_paths: _save_relative_paths);

src/TensorFlowNET.Core/Training/Saving/checkpoint_management.py.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ public static void update_checkpoint_state_internal(string save_dir,
4444
float? last_preserved_timestamp = null
4545
)
4646
{
47-
CheckpointState ckpt = null;
48-
47+
CheckpointState ckpt = null;
4948
// Writes the "checkpoint" file for the coordinator for later restoration.
5049
string coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename);
5150
if (save_relative_paths)
@@ -65,7 +64,12 @@ public static void update_checkpoint_state_internal(string save_dir,
6564
throw new RuntimeError($"Save path '{model_checkpoint_path}' conflicts with path used for " +
6665
"checkpoint state. Please use a different save path.");
6766

68-
File.WriteAllText(coord_checkpoint_filename, ckpt.ToString());
67+
// File.WriteAllText(coord_checkpoint_filename, ckpt.ToString());
68+
File.WriteAllLines(coord_checkpoint_filename, new[]
69+
{
70+
$"model_checkpoint_path: \"{ckpt.ModelCheckpointPath}\"",
71+
$"all_model_checkpoint_paths: \"{ckpt.AllModelCheckpointPaths[0]}\"",
72+
});
6973
}
7074

7175
/// <summary>
@@ -98,7 +102,14 @@ private static CheckpointState generate_checkpoint_state_proto(string save_dir,
98102
all_model_checkpoint_paths.Add(model_checkpoint_path);
99103

100104
// Relative paths need to be rewritten to be relative to the "save_dir"
101-
// if model_checkpoint_path already contains "save_dir".
105+
if (model_checkpoint_path.StartsWith(save_dir))
106+
{
107+
model_checkpoint_path = model_checkpoint_path.Substring(save_dir.Length + 1);
108+
all_model_checkpoint_paths = all_model_checkpoint_paths
109+
.Select(x => x.Substring(save_dir.Length + 1))
110+
.ToList();
111+
}
112+
102113

103114
var coord_checkpoint_proto = new CheckpointState()
104115
{

src/TensorFlowNET.Core/Training/Saving/saver.py.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,12 @@ public static (Saver, object) _import_meta_graph_with_return_elements(string met
2929
{
3030
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);
3131

32-
var meta = meta_graph.import_scoped_meta_graph_with_return_elements(
32+
var (imported_vars, imported_return_elements) = meta_graph.import_scoped_meta_graph_with_return_elements(
3333
meta_graph_def,
3434
clear_devices: clear_devices,
3535
import_scope: import_scope,
3636
return_elements: return_elements);
3737

38-
var (imported_vars, imported_return_elements) = meta;
39-
4038
var saver = _create_saver_from_imported_meta_graph(
4139
meta_graph_def, import_scope, imported_vars);
4240

0 commit comments

Comments
 (0)