Skip to content

Commit 6adcfae

Browse files
committed
ndarray string comparison.
1 parent e748a8f commit 6adcfae

File tree

8 files changed

+49
-13
lines changed

8 files changed

+49
-13
lines changed

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ void NewEagerTensorHandle(SafeTensorHandle h)
6060
{
6161
_id = ops.uid();
6262
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
63+
#if TRACK_TENSOR_LIFE
64+
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
65+
#endif
6366
tf.Status.Check(true);
6467
}
6568

src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public override bool Equals(object obj)
1616
long val => GetAtIndex<long>(0) == val,
1717
float val => GetAtIndex<float>(0) == val,
1818
double val => GetAtIndex<double>(0) == val,
19+
string val => StringData(0) == val,
1920
NDArray val => Equals(this, val),
2021
_ => base.Equals(obj)
2122
};

src/TensorFlowNET.Core/NumPy/NDArrayRender.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ static string Render(NDArray array)
9191
.Take(25)
9292
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'";
9393
else
94-
return $"['{string.Join("', '", array.StringData().Take(25))}']";
94+
return $"'{string.Join("', '", array.StringData().Take(25))}'";
9595
}
9696
else if (dtype == TF_DataType.TF_VARIANT)
9797
{

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes
5151

5252
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
5353
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
54-
<DefineConstants>TRACE;DEBUG</DefineConstants>
54+
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
5555
<PlatformTarget>x64</PlatformTarget>
5656
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
5757
</PropertyGroup>

src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow
88
public sealed class SafeStringTensorHandle : SafeTensorHandle
99
{
1010
Shape _shape;
11-
IntPtr _handle;
11+
SafeTensorHandle _tensorHandle;
1212
const int TF_TSRING_SIZE = 24;
1313

1414
protected SafeStringTensorHandle()
@@ -18,23 +18,26 @@ protected SafeStringTensorHandle()
1818
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
1919
: base(handle.DangerousGetHandle())
2020
{
21-
_handle = c_api.TF_TensorData(handle);
21+
_tensorHandle = handle;
2222
_shape = shape;
23+
bool success = false;
24+
_tensorHandle.DangerousAddRef(ref success);
2325
}
2426

2527
protected override bool ReleaseHandle()
2628
{
29+
var _handle = c_api.TF_TensorData(_tensorHandle);
2730
#if TRACK_TENSOR_LIFE
28-
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
31+
Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}");
2932
#endif
30-
3133
for (int i = 0; i < _shape.size; i++)
3234
{
3335
c_api.TF_StringDealloc(_handle);
3436
_handle += TF_TSRING_SIZE;
3537
}
3638

3739
SetHandle(IntPtr.Zero);
40+
_tensorHandle.DangerousRelease();
3841

3942
return true;
4043
}

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape)
2929

3030
var tstr = c_api.TF_TensorData(handle);
3131
#if TRACK_TENSOR_LIFE
32-
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}");
32+
print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}");
3333
#endif
3434
for (int i = 0; i < buffer.Length; i++)
3535
{
3636
c_api.TF_StringInit(tstr);
3737
c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length);
38-
var data = c_api.TF_StringGetDataPointer(tstr);
38+
// var data = c_api.TF_StringGetDataPointer(tstr);
3939
tstr += TF_TSRING_SIZE;
4040
}
4141

@@ -53,6 +53,36 @@ public string[] StringData()
5353
return _str;
5454
}
5555

56+
public string StringData(int index)
57+
{
58+
var bytes = StringBytes(index);
59+
return Encoding.UTF8.GetString(bytes);
60+
}
61+
62+
public byte[] StringBytes(int index)
63+
{
64+
if (dtype != TF_DataType.TF_STRING)
65+
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");
66+
67+
byte[] buffer = new byte[0];
68+
var tstrings = TensorDataPointer;
69+
for (int i = 0; i < shape.size; i++)
70+
{
71+
if(index == i)
72+
{
73+
var data = c_api.TF_StringGetDataPointer(tstrings);
74+
var len = c_api.TF_StringGetSize(tstrings);
75+
buffer = new byte[len];
76+
// var capacity = c_api.TF_StringGetCapacity(tstrings);
77+
// var type = c_api.TF_StringGetType(tstrings);
78+
Marshal.Copy(data, buffer, 0, Convert.ToInt32(len));
79+
break;
80+
}
81+
tstrings += TF_TSRING_SIZE;
82+
}
83+
return buffer;
84+
}
85+
5686
public byte[][] StringBytes()
5787
{
5888
if (dtype != TF_DataType.TF_STRING)

src/TensorFlowNET.Keras/Utils/np_utils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static NDArray to_categorical(NDArray y, int num_classes = -1, TF_DataTyp
2222
// categorical[np.arange(y.size), y] = 1;
2323
for (var i = 0; i < (int)y.size; i++)
2424
{
25-
categorical[i][y1[i]] = 1.0f;
25+
categorical[i, y1[i]] = 1.0f;
2626
}
2727

2828
return categorical;

test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,11 @@ public void StringArray()
5151
{
5252
var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" };
5353
var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations");
54-
var stringData = tensor.StringData();
5554

5655
Assert.AreEqual(3, tensor.shape[0]);
57-
Assert.AreEqual(strings[0], stringData[0]);
58-
Assert.AreEqual(strings[1], stringData[1]);
59-
Assert.AreEqual(strings[2], stringData[2]);
56+
Assert.AreEqual(tensor[0].numpy(), strings[0]);
57+
Assert.AreEqual(tensor[1].numpy(), strings[1]);
58+
Assert.AreEqual(tensor[2].numpy(), strings[2]);
6059
}
6160

6261
[TestMethod]

0 commit comments

Comments
 (0)