Skip to content

Commit f19808e

Browse files
committed
TfLiteInterpreterInvoke
1 parent 66bb81b commit f19808e

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

src/TensorFlowNET.Core/APIs/c_api_lite.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,14 @@ public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteI
7878

7979
[DllImport(TensorFlowLibName)]
8080
public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size);
81+
82+
[DllImport(TensorFlowLibName)]
83+
public static extern TfLiteStatus TfLiteInterpreterInvoke(SafeTfLiteInterpreterHandle interpreter);
84+
85+
[DllImport(TensorFlowLibName)]
86+
public static extern IntPtr TfLiteInterpreterGetOutputTensor(SafeTfLiteInterpreterHandle interpreter, int output_index);
87+
88+
[DllImport(TensorFlowLibName)]
89+
public static extern TfLiteStatus TfLiteTensorCopyToBuffer(TfLiteTensor output_tensor, IntPtr output_data, int output_data_size);
8190
}
8291
}

test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void TfLiteVersion()
2020
}
2121

2222
[TestMethod]
23-
public void SmokeTest()
23+
public unsafe void SmokeTest()
2424
{
2525
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
2626
var options = c_api_lite.TfLiteInterpreterOptionsCreate();
@@ -52,7 +52,36 @@ public void SmokeTest()
5252
Assert.AreEqual(0, input_params.zero_point);
5353

5454
var input = new[] { 1f, 3f };
55-
// c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float));
55+
fixed (float* addr = &input[0])
56+
{
57+
Assert.AreEqual(TfLiteStatus.kTfLiteOk,
58+
c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float)));
59+
}
60+
61+
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));
62+
63+
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
64+
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(output_tensor));
65+
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor));
66+
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0));
67+
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor));
68+
Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor));
69+
Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor)));
70+
71+
var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
72+
Assert.AreEqual(0f, output_params.scale);
73+
Assert.AreEqual(0, output_params.zero_point);
74+
75+
var output = new float[2];
76+
fixed (float* addr = &output[0])
77+
{
78+
Assert.AreEqual(TfLiteStatus.kTfLiteOk,
79+
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float)));
80+
}
81+
Assert.AreEqual(3f, output[0]);
82+
Assert.AreEqual(9f, output[1]);
83+
84+
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
5685
}
5786
}
5887
}

0 commit comments

Comments
 (0)