Skip to content

Commit 67a7fc5

Browse files
dssOceania2018
dss
authored andcommitted
fix Shape.Equals
1 parent 80e5e18 commit 67a7fc5

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ public static bool Equals(Shape shape, object target)
100100
if (shape.ndim != shape2.Length)
101101
return false;
102102
return Enumerable.SequenceEqual(shape.dims, shape2);
103+
case int[] shape3:
104+
if (shape.ndim != shape3.Length)
105+
return false;
106+
return Enumerable.SequenceEqual(shape.as_int_list(), shape3);
103107
default:
104108
return false;
105109
}

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,19 +347,21 @@ public Tensor conv2d_transpose(Tensor x,
347347
string data_format = null,
348348
Shape dilation_rate = null)
349349
{
350+
/*
350351
var force_transpose = false;
351352
if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
352353
force_transpose = true;
353-
// x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
354+
x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
355+
*/
354356
var tf_data_format = "NHWC";
355357
padding = padding.ToUpper();
356358
strides = new Shape(1, strides[0], strides[1], 1);
357-
if (dilation_rate.Equals(new long[] { 1, 1 }))
359+
if (dilation_rate.Equals(new[] { 1, 1 }))
358360
x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
359361
padding: padding,
360362
data_format: tf_data_format);
361363
else
362-
throw new NotImplementedException("");
364+
throw new NotImplementedException("dilation_rate other than [1,1] is not yet supported");
363365

364366
return x;
365367
}

0 commit comments

Comments
 (0)