Skip to content

Commit 179e32a

Browse files
authored
Merge pull request #1044 from AsakusaRinne/add_cv_compatibility
Add the constructor of NDArray which reuses memory
2 parents 34338c7 + 44d203d commit 179e32a

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

src/TensorFlowNET.Core/NumPy/Persistence/NpyFormat.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ private static string GetDtypeName(NDArray array, out Type type, out int bytes)
7070
if (type == typeof(bool))
7171
return "|b1";
7272
else if (type == typeof(byte))
73-
return "|i1";
73+
return "|u1";
7474
else if (type == typeof(short))
7575
return "<i2";
7676
else if (type == typeof(int))

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11+
protected NDArray() { }
1112
public NDArray(bool value) : base(value) => NewEagerTensorHandle();
1213
public NDArray(byte value) : base(value) => NewEagerTensorHandle();
1314
public NDArray(short value) : base(value) => NewEagerTensorHandle();
@@ -57,6 +58,20 @@ public static NDArray Scalar<T>(T value) where T : unmanaged
5758
_ => throw new NotImplementedException("")
5859
};
5960

61+
/// <summary>
62+
/// Reuse the existing memory instead of copying it.
63+
/// </summary>
64+
/// <param name="data_ptr"></param>
65+
/// <param name="shape"></param>
66+
/// <param name="dtype"></param>
67+
/// <param name="deallocator"></param>
68+
protected void InitWithExistingMemory(IntPtr data_ptr, Shape shape, TF_DataType dtype, c_api.DeallocatorV2 deallocator)
69+
{
70+
_handle = c_api.TF_NewTensor(TF_DataType.TF_STRING, shape.dims, shape.ndim, data_ptr, (ulong)(shape.size * dtype.get_datatype_size()), deallocator, IntPtr.Zero);
71+
tensor_util.DangerousManuallySetTensorDType(_handle, dtype);
72+
NewEagerTensorHandle();
73+
}
74+
6075
void NewEagerTensorHandle()
6176
{
6277
if (_handle is not null)

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ public static Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT,
417417
{
418418
TF_DataType.TF_DOUBLE => constant(1.0d),
419419
TF_DataType.TF_FLOAT => constant(1.0f),
420-
_ => constant(1)
420+
_ => constant(1, dtype)
421421
};
422422

423423
if (shape.ndim == 0)

src/TensorFlowNET.Core/Tensors/c_api.tensor.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public partial class c_api
7171
/// <param name="deallocator_arg"></param>
7272
/// <returns></returns>
7373
[DllImport(TensorFlowLibName)]
74-
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, IntPtr deallocator_arg);
74+
public static extern SafeTensorHandle TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, DeallocatorV2 deallocator, IntPtr deallocator_arg);
7575

7676
public static unsafe SafeTensorHandle TF_NewTensor(byte[] data, Shape shape, TF_DataType dtype)
7777
{
@@ -147,6 +147,15 @@ public static unsafe SafeTensorHandle TF_NewTensor<T>(T value)
147147
[DllImport(TensorFlowLibName)]
148148
public static extern TF_DataType TF_TensorType(SafeTensorHandle tensor);
149149

150+
/// <summary>
151+
/// Set a new shape for the Tensor. Note that this API only works after tf2.11.
152+
/// </summary>
153+
/// <param name="tensor"></param>
154+
/// <param name="dims"></param>
155+
/// <param name="num_dims"></param>
156+
[DllImport(TensorFlowLibName)]
157+
public static extern void TF_SetShape(SafeTensorHandle tensor, long[] dims, int num_dims);
158+
150159
/// <summary>
151160
/// Return the size in bytes required to encode a string `len` bytes long into a
152161
/// TF_STRING tensor.

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
using Tensorflow.Eager;
2323
using Tensorflow.Graphs;
2424
using static Tensorflow.Binding;
25+
using System.Diagnostics;
2526

2627
namespace Tensorflow
2728
{
@@ -649,5 +650,24 @@ public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tens
649650
NewAxisMask = new_axis_mask
650651
};
651652
}
653+
654+
/// <summary>
655+
/// Warning: this method is an extremely dangerous method. It directly changes the dtype inside the tensor
656+
/// and security is not guaranteed at all. Currently this method is only used for some conditions to reuse
657+
/// the existing memory. Any other usage should be prevented. If you are sure you want to use it when
658+
/// developing tensorflow.net, please ask @Oceanic2018 or @AsakusaRinne first.
659+
/// </summary>
660+
/// <param name="handle"></param>
661+
/// <param name="dtype"></param>
662+
internal static unsafe void DangerousManuallySetTensorDType(SafeTensorHandle handle, TF_DataType dtype)
663+
{
664+
long tf_tensor_address = handle.DangerousGetHandle().ToInt64();
665+
long interface_address = *(long*)(tf_tensor_address);
666+
long tensor_shape_address = interface_address + 8;
667+
long tensor_dtype_address = tensor_shape_address + 13;
668+
byte* dtype_pointer = (byte*)tensor_dtype_address;
669+
*dtype_pointer = (byte)dtype;
670+
Debug.Assert(c_api.TF_TensorType(handle) == dtype);
671+
}
652672
}
653673
}

0 commit comments

Comments
 (0)