1
- using System ;
2
- using System . Collections . Generic ;
3
- using System . IO ;
4
- using System . Linq ;
5
- using System . Runtime . InteropServices ;
6
- using System . Text ;
7
- using Tensorflow . Util ;
1
+ using Tensorflow . Util ;
8
2
9
3
namespace Tensorflow . Checkpoint
10
4
{
11
- public class CheckpointReader : SafeTensorflowHandle
5
+ sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
12
6
{
7
+ public SafeCheckpointReaderHandle ( ) : base ( )
8
+ {
9
+
10
+ }
11
+ public SafeCheckpointReaderHandle ( IntPtr handle ) : base ( handle )
12
+ {
13
+
14
+ }
15
+
16
+ protected override bool ReleaseHandle ( )
17
+ {
18
+ c_api . TF_DeleteCheckpointReader ( handle ) ;
19
+ SetHandle ( IntPtr . Zero ) ;
20
+ return true ;
21
+ }
22
+ }
23
+ public class CheckpointReader
24
+ {
25
+ private SafeCheckpointReaderHandle _handle ;
13
26
public Dictionary < string , TF_DataType > VariableToDataTypeMap { get ; set ; }
14
27
public Dictionary < string , Shape > VariableToShapeMap { get ; set ; }
15
28
16
29
public CheckpointReader ( string filename )
17
30
{
18
31
Status status = new Status ( ) ;
19
- handle = c_api . TF_NewCheckpointReader ( filename , status . Handle ) ;
32
+ _handle = c_api . TF_NewCheckpointReader ( filename , status . Handle ) ;
20
33
status . Check ( true ) ;
21
34
ReadAllShapeAndType ( ) ;
22
35
}
23
36
24
37
public int HasTensor ( string name )
25
38
{
26
- return c_api . TF_CheckpointReaderHasTensor ( handle , name ) ;
39
+ return c_api . TF_CheckpointReaderHasTensor ( _handle , name ) ;
27
40
}
28
41
29
42
/// <summary>
@@ -33,45 +46,39 @@ public int HasTensor(string name)
33
46
/// <returns></returns>
34
47
public string GetVariable ( int index )
35
48
{
36
- return c_api . TF_CheckpointReaderGetVariable ( handle , index ) ;
49
+ return c_api . StringPiece ( c_api . TF_CheckpointReaderGetVariable ( _handle , index ) ) ;
37
50
}
38
51
39
52
public int Size ( )
40
53
{
41
- return c_api . TF_CheckpointReaderSize ( handle ) ;
54
+ return c_api . TF_CheckpointReaderSize ( _handle ) ;
42
55
}
43
56
44
57
public TF_DataType GetVariableDataType ( string name )
45
58
{
46
- return c_api . TF_CheckpointReaderGetVariableDataType ( handle , name ) ;
59
+ return c_api . TF_CheckpointReaderGetVariableDataType ( _handle , name ) ;
47
60
}
48
61
49
62
public Shape GetVariableShape ( string name )
50
63
{
51
- // TODO(Rinne): Change it to a constant.
52
64
int num_dims = GetVariableNumDims ( name ) ;
53
65
long [ ] dims = new long [ num_dims ] ;
54
66
Status status = new Status ( ) ;
55
- c_api . TF_CheckpointReaderGetVariableShape ( handle , name , dims , num_dims , status . Handle ) ;
67
+ c_api . TF_CheckpointReaderGetVariableShape ( _handle , name , dims , num_dims , status . Handle ) ;
56
68
status . Check ( true ) ;
57
69
return new Shape ( dims ) ;
58
70
}
59
71
60
72
public int GetVariableNumDims ( string name )
61
73
{
62
- return c_api . TF_CheckpointReaderGetVariableNumDims ( handle , name ) ;
74
+ return c_api . TF_CheckpointReaderGetVariableNumDims ( _handle , name ) ;
63
75
}
64
76
65
77
public unsafe Tensor GetTensor ( string name , TF_DataType dtype = TF_DataType . DtInvalid )
66
78
{
67
79
Status status = new Status ( ) ;
68
- var tensor = c_api . TF_CheckpointReaderGetTensor ( handle , name , status . Handle ) ;
80
+ var tensor = c_api . TF_CheckpointReaderGetTensor ( _handle , name , status . Handle ) ;
69
81
status . Check ( true ) ;
70
- var shape = GetVariableShape ( name ) ;
71
- if ( dtype == TF_DataType . DtInvalid )
72
- {
73
- dtype = GetVariableDataType ( name ) ;
74
- }
75
82
return new Tensor ( tensor ) ;
76
83
}
77
84
@@ -89,16 +96,5 @@ private void ReadAllShapeAndType()
89
96
VariableToShapeMap [ name ] = shape ;
90
97
}
91
98
}
92
-
93
- protected override bool ReleaseHandle ( )
94
- {
95
- c_api . TF_DeleteCheckpointReader ( handle ) ;
96
- return true ;
97
- }
98
-
99
- public void Dispose ( )
100
- {
101
- c_api . TF_DeleteCheckpointReader ( handle ) ;
102
- }
103
99
}
104
100
}
0 commit comments