Skip to content

Commit 85139ed

Browse files
committed
Fix Session.LoadFromSavedModel memroy leak.
1 parent b3d0862 commit 85139ed

File tree

7 files changed

+52
-63
lines changed

7 files changed

+52
-63
lines changed

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ public partial class c_api
289289
[DllImport(TensorFlowLibName)]
290290
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
291291
string export_dir, string[] tags, int tags_len,
292-
IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status);
292+
IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status);
293293

294294
[DllImport(TensorFlowLibName)]
295295
public static extern IntPtr TF_NewGraph();

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ public class BaseSession : DisposableObject
3636
protected byte[] _target;
3737
public Graph graph => _graph;
3838

39+
public BaseSession(IntPtr handle, Graph g)
40+
{
41+
_handle = handle;
42+
_graph = g ?? ops.get_default_graph();
43+
}
44+
3945
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
4046
{
4147
_graph = g ?? ops.get_default_graph();
@@ -291,12 +297,8 @@ private void _extend_graph()
291297

292298
protected override void DisposeUnmanagedResources(IntPtr handle)
293299
{
294-
lock (Locks.ProcessWide)
295-
using (var status = new Status())
296-
{
297-
c_api.TF_DeleteSession(handle, status.Handle);
298-
status.Check(true);
299-
}
300+
// c_api.TF_CloseSession(handle, tf.Status.Handle);
301+
c_api.TF_DeleteSession(handle, tf.Status.Handle);
300302
}
301303
}
302304
}

src/TensorFlowNET.Core/Sessions/Session.cs

Lines changed: 23 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@ public class Session : BaseSession, ITensorFlowObject
2626
public Session(string target = "", Graph g = null) : base(target, g, null)
2727
{ }
2828

29-
public Session(IntPtr handle, Graph g = null) : base("", g, null)
30-
{
31-
_handle = handle;
32-
}
29+
public Session(IntPtr handle, Graph g = null) : base(handle, g)
30+
{ }
3331

3432
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
3533
{ }
@@ -39,51 +37,29 @@ public Session as_default()
3937
return ops.set_default_session(this);
4038
}
4139

42-
[MethodImpl(MethodImplOptions.NoOptimization)]
4340
public static Session LoadFromSavedModel(string path)
4441
{
45-
lock (Locks.ProcessWide)
46-
{
47-
var graph = c_api.TF_NewGraph();
48-
using var status = new Status();
49-
var opt = new SessionOptions();
50-
51-
var tags = new string[] { "serve" };
52-
var buffer = new TF_Buffer();
53-
54-
IntPtr sess;
55-
try
56-
{
57-
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
58-
IntPtr.Zero,
59-
path,
60-
tags,
61-
tags.Length,
62-
graph,
63-
ref buffer,
64-
status.Handle);
65-
status.Check(true);
66-
}
67-
catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
68-
{
69-
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
70-
IntPtr.Zero,
71-
Path.GetFullPath(path),
72-
tags,
73-
tags.Length,
74-
graph,
75-
ref buffer,
76-
status.Handle);
77-
status.Check(true);
78-
}
79-
80-
// load graph bytes
81-
// var data = new byte[buffer.length];
82-
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
83-
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
84-
85-
return new Session(sess, g: new Graph(graph)).as_default();
86-
}
42+
using var graph = new Graph();
43+
using var status = new Status();
44+
using var opt = c_api.TF_NewSessionOptions();
45+
46+
var tags = new string[] { "serve" };
47+
48+
var sess = c_api.TF_LoadSessionFromSavedModel(opt,
49+
IntPtr.Zero,
50+
path,
51+
tags,
52+
tags.Length,
53+
graph,
54+
IntPtr.Zero,
55+
status.Handle);
56+
status.Check(true);
57+
58+
// load graph bytes
59+
// var data = new byte[buffer.length];
60+
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
61+
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
62+
return new Session(sess, g: graph);
8763
}
8864

8965
public static implicit operator IntPtr(Session session) => session._handle;

src/TensorFlowNET.Core/Sessions/c_api.session.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ namespace Tensorflow
2121
{
2222
public partial class c_api
2323
{
24+
/// <summary>
25+
/// Close a session.
26+
///
27+
/// Contacts any other processes associated with the session, if applicable.
28+
/// May not be called after TF_DeleteSession().
29+
/// </summary>
30+
/// <param name="s"></param>
31+
/// <param name="status"></param>
32+
33+
[DllImport(TensorFlowLibName)]
34+
public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status);
35+
2436
/// <summary>
2537
/// Destroy a session object.
2638
///

src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Reflection;
77
using System.Text;
88
using System.Threading.Tasks;
9+
using static Tensorflow.Binding;
910

1011
namespace Tensorflow.Benchmark.Leak
1112
{
@@ -18,13 +19,9 @@ public void Run()
1819
var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
1920
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model");
2021

21-
for (var i = 0; i < 50; i++)
22-
{
23-
var session = Session.LoadFromSavedModel(ClassifierModelPath);
24-
25-
session.graph.Exit();
26-
session.graph.Dispose();
27-
session.Dispose();
22+
for (var i = 0; i < 1024; i++)
23+
{
24+
using var sess = Session.LoadFromSavedModel(ClassifierModelPath);
2825
}
2926
}
3027
}

src/TensorFlowNet.Benchmarks/Program.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ class Program
1313
static void Main(string[] args)
1414
{
1515
print(tf.VERSION);
16-
/*new RepeatDataSetCrash().Run();
16+
17+
/*new SavedModelCleanup().Run();
18+
new RepeatDataSetCrash().Run();
1719
new GpuLeakByCNN().Run();*/
1820

1921
if (args?.Length > 0)

src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
<ItemGroup>
3939
<PackageReference Include="BenchmarkDotNet" Version="0.13.0" />
40-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" />
40+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" />
4141
</ItemGroup>
4242

4343
<ItemGroup>

0 commit comments

Comments
 (0)