Skip to content

Commit f48ba40

Browse files
committed
Fix MaxPooling1D #969
1 parent 4f88109 commit f48ba40

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

src/TensorFlowNET.Core/Operations/NnOps/MaxPoolFunction.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Linq;
1718
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow.Operations
@@ -24,7 +25,7 @@ namespace Tensorflow.Operations
2425
public class MaxPoolFunction : IPoolFunction
2526
{
2627
public Tensor Apply(Tensor value,
27-
int[] ksize,
28+
int[] pool_size,
2829
int[] strides,
2930
string padding,
3031
string data_format = "NHWC",
@@ -33,10 +34,9 @@ public Tensor Apply(Tensor value,
3334
return tf_with(ops.name_scope(name, "MaxPool", value), scope =>
3435
{
3536
name = scope;
36-
value = ops.convert_to_tensor(value, name: "input");
3737
return gen_nn_ops.max_pool(
3838
value,
39-
ksize: ksize,
39+
ksize: pool_size,
4040
strides: strides,
4141
padding: padding,
4242
data_format: data_format,

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>Tensorflow.Binding</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.10.0</TargetTensorFlow>
8-
<Version>0.100.1</Version>
8+
<Version>0.100.2</Version>
99
<LangVersion>10.0</LangVersion>
1010
<Nullable>enable</Nullable>
1111
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
2020
<Description>Google's TensorFlow full binding in .NET Standard.
2121
Building, training and infering deep learning models.
2222
https://tensorflownet.readthedocs.io</Description>
23-
<AssemblyVersion>0.100.1.0</AssemblyVersion>
23+
<AssemblyVersion>0.100.2.0</AssemblyVersion>
2424
<PackageReleaseNotes>
2525
tf.net 0.100.x and above are based on tensorflow native 2.10.0
2626

@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description>
3838
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
3939
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
4040
</PackageReleaseNotes>
41-
<FileVersion>0.100.1.0</FileVersion>
41+
<FileVersion>0.100.2.0</FileVersion>
4242
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4343
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
4444
<SignAssembly>true</SignAssembly>

src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Linq;
1718
using Tensorflow.Keras.ArgsDefinition;
1819
using Tensorflow.Keras.Engine;
1920
using Tensorflow.Keras.Utils;
21+
using static Tensorflow.Binding;
2022

2123
namespace Tensorflow.Keras.Layers
2224
{
@@ -36,27 +38,31 @@ public Pooling1D(Pooling1DArgs args)
3638

3739
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
3840
{
39-
int[] pool_shape;
40-
int[] strides;
41+
int pad_axis = args.DataFormat == "channels_first" ? 2 : 3;
42+
inputs = tf.expand_dims(inputs, pad_axis);
43+
int[] pool_shape = new int[] { args.PoolSize, 1 };
44+
int[] strides = new int[] { args.Strides, 1 };
45+
var ndim = inputs[0].ndim;
46+
4147
if (args.DataFormat == "channels_last")
4248
{
43-
pool_shape = new int[] { 1, args.PoolSize, 1 };
44-
strides = new int[] { 1, args.Strides, 1 };
49+
pool_shape = new int[] { 1 }.Concat(pool_shape).Concat(new int[] { 1 }).ToArray();
50+
strides = new int[] { 1 }.Concat(strides).Concat(new int[] { 1 }).ToArray();
4551
}
4652
else
4753
{
48-
pool_shape = new int[] { 1, 1, args.PoolSize };
49-
strides = new int[] { 1, 1, args.Strides };
54+
pool_shape = new int[] { 1, 1 }.Concat(pool_shape).ToArray();
55+
strides = new int[] { 1, 1 }.Concat(strides).ToArray();
5056
}
5157

5258
var outputs = args.PoolFunction.Apply(
5359
inputs,
5460
ksize: pool_shape,
5561
strides: strides,
5662
padding: args.Padding.ToUpper(),
57-
data_format: conv_utils.convert_data_format(args.DataFormat, 3));
63+
data_format: conv_utils.convert_data_format(args.DataFormat, ndim));
5864

59-
return outputs;
65+
return tf.squeeze(outputs, pad_axis);
6066
}
6167
}
6268
}

src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
4242
int[] strides;
4343
if (args.DataFormat == "channels_last")
4444
{
45-
pool_shape = new int[] { 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1], 1 };
45+
pool_shape = new int[] { 1, (int)args.PoolSize.dims[0], (int)args.PoolSize.dims[1], 1 };
4646
strides = new int[] { 1, (int)args.Strides.dims[0], (int)args.Strides.dims[1], 1 };
4747
}
4848
else

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<Nullable>enable</Nullable>
88
<RootNamespace>Tensorflow.Keras</RootNamespace>
99
<Platforms>AnyCPU;x64</Platforms>
10-
<Version>0.10.1</Version>
10+
<Version>0.10.2</Version>
1111
<Authors>Haiping Chen</Authors>
1212
<Product>Keras for .NET</Product>
1313
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright>
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
3737
<RepositoryType>Git</RepositoryType>
3838
<SignAssembly>true</SignAssembly>
3939
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
40-
<AssemblyVersion>0.10.1.0</AssemblyVersion>
41-
<FileVersion>0.10.1.0</FileVersion>
40+
<AssemblyVersion>0.10.2.0</AssemblyVersion>
41+
<FileVersion>0.10.2.0</FileVersion>
4242
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4343
<Configurations>Debug;Release;GPU</Configurations>
4444
</PropertyGroup>
@@ -70,7 +70,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
7070
</PropertyGroup>
7171

7272
<ItemGroup>
73-
<PackageReference Include="HDF5-CSharp" Version="1.16.2" />
73+
<PackageReference Include="HDF5-CSharp" Version="1.16.3" />
7474
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
7575
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
7676
<PackageReference Include="SharpZipLib" Version="1.4.1" />

test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Tensorflow;
55
using static Tensorflow.Binding;
66
using static Tensorflow.KerasApi;
7+
using Microsoft.VisualBasic;
78

89
namespace TensorFlowNET.Keras.UnitTest
910
{
@@ -226,7 +227,7 @@ public void GlobalMax2DPoolingChannelsFirst()
226227
Assert.AreEqual(expected, y[0].numpy());
227228
}
228229

229-
[TestMethod, Ignore("There's an error generated from TF complaining about the shape of the pool. Needs further investigation.")]
230+
[TestMethod]
230231
public void Max1DPoolingChannelsLast()
231232
{
232233
var x = input_array_1D;
@@ -239,7 +240,7 @@ public void Max1DPoolingChannelsLast()
239240

240241
var expected = np.array(new float[,,]
241242
{
242-
{{2.0f, 2.0f, 3.0f, 3.0f, 3.0f},
243+
{{1.0f, 2.0f, 3.0f, 3.0f, 3.0f},
243244
{ 1.0f, 2.0f, 3.0f, 3.0f, 3.0f}},
244245

245246
{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},

0 commit comments

Comments
 (0)