Skip to content

Commit 3dafbf0

Browse files
committed
Fix the stucking of training when loading model.
1 parent 0060039 commit 3dafbf0

File tree

2 files changed

+39
-45
lines changed

2 files changed

+39
-45
lines changed
Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,42 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.IO;
4-
using System.Linq;
5-
using System.Runtime.InteropServices;
6-
using System.Text;
7-
using Tensorflow.Util;
1+
using Tensorflow.Util;
82

93
namespace Tensorflow.Checkpoint
104
{
11-
public class CheckpointReader : SafeTensorflowHandle
5+
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
126
{
7+
public SafeCheckpointReaderHandle(): base()
8+
{
9+
10+
}
11+
public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
12+
{
13+
14+
}
15+
16+
protected override bool ReleaseHandle()
17+
{
18+
c_api.TF_DeleteCheckpointReader(handle);
19+
SetHandle(IntPtr.Zero);
20+
return true;
21+
}
22+
}
23+
public class CheckpointReader
24+
{
25+
private SafeCheckpointReaderHandle _handle;
1326
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
1427
public Dictionary<string, Shape> VariableToShapeMap { get; set; }
1528

1629
public CheckpointReader(string filename)
1730
{
1831
Status status = new Status();
19-
handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
32+
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
2033
status.Check(true);
2134
ReadAllShapeAndType();
2235
}
2336

2437
public int HasTensor(string name)
2538
{
26-
return c_api.TF_CheckpointReaderHasTensor(handle, name);
39+
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
2740
}
2841

2942
/// <summary>
@@ -33,45 +46,39 @@ public int HasTensor(string name)
3346
/// <returns></returns>
3447
public string GetVariable(int index)
3548
{
36-
return c_api.TF_CheckpointReaderGetVariable(handle, index);
49+
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
3750
}
3851

3952
public int Size()
4053
{
41-
return c_api.TF_CheckpointReaderSize(handle);
54+
return c_api.TF_CheckpointReaderSize(_handle);
4255
}
4356

4457
public TF_DataType GetVariableDataType(string name)
4558
{
46-
return c_api.TF_CheckpointReaderGetVariableDataType(handle, name);
59+
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
4760
}
4861

4962
public Shape GetVariableShape(string name)
5063
{
51-
// TODO(Rinne): Change it to a constant.
5264
int num_dims = GetVariableNumDims(name);
5365
long[] dims = new long[num_dims];
5466
Status status = new Status();
55-
c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle);
67+
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
5668
status.Check(true);
5769
return new Shape(dims);
5870
}
5971

6072
public int GetVariableNumDims(string name)
6173
{
62-
return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name);
74+
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
6375
}
6476

6577
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
6678
{
6779
Status status = new Status();
68-
var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle);
80+
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
6981
status.Check(true);
70-
var shape = GetVariableShape(name);
71-
if(dtype == TF_DataType.DtInvalid)
72-
{
73-
dtype = GetVariableDataType(name);
74-
}
7582
return new Tensor(tensor);
7683
}
7784

@@ -89,16 +96,5 @@ private void ReadAllShapeAndType()
8996
VariableToShapeMap[name] = shape;
9097
}
9198
}
92-
93-
protected override bool ReleaseHandle()
94-
{
95-
c_api.TF_DeleteCheckpointReader(handle);
96-
return true;
97-
}
98-
99-
public void Dispose()
100-
{
101-
c_api.TF_DeleteCheckpointReader(handle);
102-
}
10399
}
104100
}
Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,27 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using System.Runtime.InteropServices;
1+
using System.Runtime.InteropServices;
2+
using Tensorflow.Checkpoint;
53

64
namespace Tensorflow
75
{
86
public unsafe partial class c_api
97
{
108
[DllImport(TensorFlowLibName)]
11-
internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status);
9+
internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
1210
[DllImport(TensorFlowLibName)]
1311
internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
1412
[DllImport(TensorFlowLibName)]
15-
internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name);
13+
internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
1614
[DllImport(TensorFlowLibName)]
17-
internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index);
15+
internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
1816
[DllImport(TensorFlowLibName)]
19-
internal static extern int TF_CheckpointReaderSize(IntPtr reader);
17+
internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
2018
[DllImport(TensorFlowLibName)]
21-
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name);
19+
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
2220
[DllImport(TensorFlowLibName)]
23-
internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
21+
internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
2422
[DllImport(TensorFlowLibName)]
25-
internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name);
23+
internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
2624
[DllImport(TensorFlowLibName)]
27-
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
25+
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
2826
}
2927
}

0 commit comments

Comments
 (0)