Skip to content

Commit ec340ee

Browse files
committed
np.ones_like and np.zeros_like
1 parent 197224f commit ec340ee

File tree

5 files changed

+24
-14
lines changed

5 files changed

+24
-14
lines changed

src/TensorFlowNET.Console/SimpleRnnTest.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ public class SimpleRnnTest
1212
{
1313
public void Run()
1414
{
15-
tf.UseKeras<KerasInterface>();
1615
var inputs = np.random.random((6, 10, 8)).astype(np.float32);
1716
//var simple_rnn = tf.keras.layers.SimpleRNN(4);
1817
//var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.

src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.IO;
5-
using System.Numerics;
65
using System.Text;
76
using static Tensorflow.Binding;
87

@@ -103,11 +102,15 @@ public static NDArray ndarray(Shape shape, TF_DataType dtype = TF_DataType.TF_DO
103102
public static NDArray ones(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
104103
=> new NDArray(tf.ones(shape, dtype: dtype));
105104

106-
public static NDArray ones_like(NDArray a, Type dtype = null)
107-
=> throw new NotImplementedException("");
105+
public static NDArray ones_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid)
106+
=> new NDArray(tf.ones_like(a, dtype: dtype));
108107

109108
[AutoNumPy]
110109
public static NDArray zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
111110
=> new NDArray(tf.zeros(shape, dtype: dtype));
111+
112+
[AutoNumPy]
113+
public static NDArray zeros_like(NDArray a, TF_DataType dtype = TF_DataType.DtInvalid)
114+
=> new NDArray(tf.zeros_like(a, dtype: dtype));
112115
}
113116
}

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,14 @@ private void _extend_graph()
291291
protected override void DisposeUnmanagedResources(IntPtr handle)
292292
{
293293
// c_api.TF_CloseSession(handle, tf.Status.Handle);
294-
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
294+
if (tf.Status == null || tf.Status.Handle.IsInvalid)
295+
{
296+
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
297+
}
298+
else
299+
{
300+
c_api.TF_DeleteSession(handle, tf.Status.Handle);
301+
}
295302
}
296303
}
297304
}

src/python/simple_rnn.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import numpy as np
22
import tensorflow as tf
3+
import tensorflow.experimental.numpy as tnp
34

45
# tf.experimental.numpy
5-
inputs = np.random.random([32, 10, 8]).astype(np.float32)
6-
simple_rnn = tf.keras.layers.SimpleRNN(4)
6+
inputs = np.arange(6 * 10 * 8).reshape([6, 10, 8]).astype(np.float32)
7+
# simple_rnn = tf.keras.layers.SimpleRNN(4)
78

8-
output = simple_rnn(inputs) # The output has shape `[32, 4]`.
9+
# output = simple_rnn(inputs) # The output has shape `[6, 4]`.
910

10-
simple_rnn = tf.keras.layers.SimpleRNN(
11-
4, return_sequences=True, return_state=True)
11+
simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True, return_state=True)
1212

13-
# whole_sequence_output has shape `[32, 10, 4]`.
14-
# final_state has shape `[32, 4]`.
15-
whole_sequence_output, final_state = simple_rnn(inputs)
13+
# whole_sequence_output has shape `[6, 10, 4]`.
14+
# final_state has shape `[6, 4]`.
15+
whole_sequence_output, final_state = simple_rnn(inputs)
16+
print(whole_sequence_output)
17+
print(final_state)

test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
<TargetFramework>net6.0</TargetFramework>
55

66
<IsPackable>false</IsPackable>
7-
<LangVersion>11.0</LangVersion>
87
<Platforms>AnyCPU;x64</Platforms>
98
</PropertyGroup>
109

0 commit comments

Comments
 (0)