Skip to content

Commit 66bb81b

Browse files
committed
tflite native api.
1 parent 1fa2f1d commit 66bb81b

File tree

10 files changed

+292
-0
lines changed

10 files changed

+292
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
using Tensorflow.Lite;
6+
7+
namespace Tensorflow
8+
{
9+
public class c_api_lite
10+
{
11+
public const string TensorFlowLibName = "tensorflowlite_c";
12+
13+
public static string StringPiece(IntPtr handle)
14+
{
15+
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
16+
}
17+
18+
[DllImport(TensorFlowLibName)]
19+
public static extern IntPtr TfLiteVersion();
20+
21+
[DllImport(TensorFlowLibName)]
22+
public static extern SafeTfLiteModelHandle TfLiteModelCreateFromFile(string model_path);
23+
24+
[DllImport(TensorFlowLibName)]
25+
public static extern void TfLiteModelDelete(IntPtr model);
26+
27+
[DllImport(TensorFlowLibName)]
28+
public static extern SafeTfLiteInterpreterOptionsHandle TfLiteInterpreterOptionsCreate();
29+
30+
[DllImport(TensorFlowLibName)]
31+
public static extern void TfLiteInterpreterOptionsDelete(IntPtr options);
32+
33+
[DllImport(TensorFlowLibName)]
34+
public static extern void TfLiteInterpreterOptionsSetNumThreads(SafeTfLiteInterpreterOptionsHandle options, int num_threads);
35+
36+
[DllImport(TensorFlowLibName)]
37+
public static extern SafeTfLiteInterpreterHandle TfLiteInterpreterCreate(SafeTfLiteModelHandle model, SafeTfLiteInterpreterOptionsHandle optional_options);
38+
39+
[DllImport(TensorFlowLibName)]
40+
public static extern void TfLiteInterpreterDelete(IntPtr interpreter);
41+
42+
[DllImport(TensorFlowLibName)]
43+
public static extern TfLiteStatus TfLiteInterpreterAllocateTensors(SafeTfLiteInterpreterHandle interpreter);
44+
45+
[DllImport(TensorFlowLibName)]
46+
public static extern int TfLiteInterpreterGetInputTensorCount(SafeTfLiteInterpreterHandle interpreter);
47+
48+
[DllImport(TensorFlowLibName)]
49+
public static extern int TfLiteInterpreterGetOutputTensorCount(SafeTfLiteInterpreterHandle interpreter);
50+
51+
[DllImport(TensorFlowLibName)]
52+
public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteInterpreterHandle interpreter,
53+
int input_index, int[] input_dims, int input_dims_size);
54+
55+
[DllImport(TensorFlowLibName)]
56+
public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index);
57+
58+
[DllImport(TensorFlowLibName)]
59+
public static extern TF_DataType TfLiteTensorType(TfLiteTensor tensor);
60+
61+
[DllImport(TensorFlowLibName)]
62+
public static extern int TfLiteTensorNumDims(TfLiteTensor tensor);
63+
64+
[DllImport(TensorFlowLibName)]
65+
public static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);
66+
67+
[DllImport(TensorFlowLibName)]
68+
public static extern int TfLiteTensorByteSize(TfLiteTensor tensor);
69+
70+
[DllImport(TensorFlowLibName)]
71+
public static extern IntPtr TfLiteTensorData(TfLiteTensor tensor);
72+
73+
[DllImport(TensorFlowLibName)]
74+
public static extern IntPtr TfLiteTensorName(TfLiteTensor tensor);
75+
76+
[DllImport(TensorFlowLibName)]
77+
public static extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);
78+
79+
[DllImport(TensorFlowLibName)]
80+
public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size);
81+
}
82+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Util;
5+
6+
namespace Tensorflow.Lite
7+
{
8+
public class SafeTfLiteInterpreterHandle : SafeTensorflowHandle
9+
{
10+
protected SafeTfLiteInterpreterHandle()
11+
{
12+
}
13+
14+
public SafeTfLiteInterpreterHandle(IntPtr handle)
15+
: base(handle)
16+
{
17+
}
18+
19+
protected override bool ReleaseHandle()
20+
{
21+
c_api_lite.TfLiteInterpreterDelete(handle);
22+
SetHandle(IntPtr.Zero);
23+
return true;
24+
}
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Util;
5+
6+
namespace Tensorflow.Lite
7+
{
8+
public class SafeTfLiteInterpreterOptionsHandle : SafeTensorflowHandle
9+
{
10+
protected SafeTfLiteInterpreterOptionsHandle()
11+
{
12+
}
13+
14+
public SafeTfLiteInterpreterOptionsHandle(IntPtr handle)
15+
: base(handle)
16+
{
17+
}
18+
19+
protected override bool ReleaseHandle()
20+
{
21+
c_api_lite.TfLiteInterpreterOptionsDelete(handle);
22+
SetHandle(IntPtr.Zero);
23+
return true;
24+
}
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Util;
5+
6+
namespace Tensorflow.Lite
7+
{
8+
public class SafeTfLiteModelHandle : SafeTensorflowHandle
9+
{
10+
protected SafeTfLiteModelHandle()
11+
{
12+
}
13+
14+
public SafeTfLiteModelHandle(IntPtr handle)
15+
: base(handle)
16+
{
17+
}
18+
19+
protected override bool ReleaseHandle()
20+
{
21+
c_api_lite.TfLiteModelDelete(handle);
22+
SetHandle(IntPtr.Zero);
23+
return true;
24+
}
25+
}
26+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Lite
6+
{
7+
public struct TfLiteQuantizationParams
8+
{
9+
public float scale;
10+
public int zero_point;
11+
}
12+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Lite
6+
{
7+
public enum TfLiteStatus
8+
{
9+
kTfLiteOk = 0,
10+
11+
// Generally referring to an error in the runtime (i.e. interpreter)
12+
kTfLiteError = 1,
13+
14+
// Generally referring to an error from a TfLiteDelegate itself.
15+
kTfLiteDelegateError = 2,
16+
17+
// Generally referring to an error in applying a delegate due to
18+
// incompatibility between runtime and delegate, e.g., this error is returned
19+
// when trying to apply a TfLite delegate onto a model graph that's already
20+
// immutable.
21+
kTfLiteApplicationError = 3,
22+
23+
// Generally referring to serialized delegate data not being found.
24+
// See tflite::delegates::Serialization.
25+
kTfLiteDelegateDataNotFound = 4,
26+
27+
// Generally referring to data-writing issues in delegate serialization.
28+
// See tflite::delegates::Serialization.
29+
kTfLiteDelegateDataWriteError = 5,
30+
}
31+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
3+
namespace Tensorflow.Lite
4+
{
5+
public struct TfLiteTensor
6+
{
7+
IntPtr _handle;
8+
9+
public TfLiteTensor(IntPtr handle)
10+
=> _handle = handle;
11+
12+
public static implicit operator TfLiteTensor(IntPtr handle)
13+
=> new TfLiteTensor(handle);
14+
15+
public static implicit operator IntPtr(TfLiteTensor tensor)
16+
=> tensor._handle;
17+
18+
public override string ToString()
19+
=> $"TfLiteTensor 0x{_handle.ToString("x16")}";
20+
}
21+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Runtime.InteropServices;
6+
using System.Text;
7+
using System.Threading.Tasks;
8+
using Tensorflow.Lite;
9+
10+
namespace Tensorflow.Native.UnitTest
11+
{
12+
[TestClass]
13+
public class TfLiteTest
14+
{
15+
[TestMethod]
16+
public void TfLiteVersion()
17+
{
18+
var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion());
19+
Assert.IsNotNull(ver);
20+
}
21+
22+
[TestMethod]
23+
public void SmokeTest()
24+
{
25+
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
26+
var options = c_api_lite.TfLiteInterpreterOptionsCreate();
27+
c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2);
28+
29+
var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options);
30+
31+
c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle());
32+
c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());
33+
34+
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
35+
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter));
36+
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter));
37+
38+
var input_dims = new int[] { 2 };
39+
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length));
40+
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
41+
42+
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
43+
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(input_tensor));
44+
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
45+
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
46+
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor));
47+
Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor));
48+
Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor)));
49+
50+
var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
51+
Assert.AreEqual(0f, input_params.scale);
52+
Assert.AreEqual(0, input_params.zero_point);
53+
54+
var input = new[] { 1f, 3f };
55+
// c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float));
56+
}
57+
}
58+
}
Binary file not shown.

test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
2525
</PropertyGroup>
2626

27+
<ItemGroup>
28+
<None Remove="Lite\testdata\add.bin" />
29+
</ItemGroup>
30+
31+
<ItemGroup>
32+
<Content Include="Lite\testdata\add.bin">
33+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
34+
</Content>
35+
</ItemGroup>
36+
2737
<ItemGroup>
2838
<PackageReference Include="FluentAssertions" Version="5.10.3" />
2939
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0-release-20210626-04" />

0 commit comments

Comments
 (0)