Skip to content

Commit 4aab86b

Browse files
committed
Fix the error when using layers.Input with unknown batch size.
1 parent 8da573c commit 4aab86b

File tree

2 files changed

+146
-8
lines changed

2 files changed

+146
-8
lines changed

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -905,13 +905,29 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
905905
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
906906
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
907907
var ab_matmul = matmul(a_reshape, b_reshape);
908-
var dims = new List<int>();
909-
dims.AddRange(a_free_dims);
910-
dims.AddRange(b_free_dims);
911-
if (ab_matmul.shape.Equals(dims))
912-
return ab_matmul;
908+
if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
909+
{
910+
var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
911+
if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
912+
{
913+
return ab_matmul;
914+
}
915+
else
916+
{
917+
return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
918+
}
919+
}
913920
else
914-
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
921+
{
922+
var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
923+
var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
924+
var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
925+
if(a_free_dims_static is not null && b_free_dims_static is not null)
926+
{
927+
product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
928+
}
929+
return product;
930+
}
915931
});
916932
}
917933

@@ -927,14 +943,42 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
927943
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
928944
Binding.range(0, axe).ToArray());
929945
}
930-
else
946+
else if(axes.rank == 1)
931947
{
948+
if (axes.shape[0] != 2)
949+
{
950+
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
951+
}
932952
(int a_axe, int b_axe) = (axes[0], axes[1]);
933953
return (new[] { a_axe }, new[] { b_axe });
934954
}
955+
else if(axes.rank == 2)
956+
{
957+
if (axes.shape[0] != 2)
958+
{
959+
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
960+
}
961+
int[] a_axes = new int[axes.shape[1]];
962+
int[] b_axes = new int[axes.shape[1]];
963+
for(int i = 0; i < a_axes.Length; i++)
964+
{
965+
a_axes[i] = axes[0, i];
966+
b_axes[i] = axes[1, i];
967+
if (a_axes[i] == -1 || b_axes[i] == -1)
968+
{
969+
throw new ValueError($"Different number of contraction axes `a` and `b`," +
970+
$"{len(a_axes)} != {len(b_axes)}.");
971+
}
972+
}
973+
return (a_axes, b_axes);
974+
}
975+
else
976+
{
977+
throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
978+
}
935979
}
936980

937-
static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
981+
static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
938982
{
939983
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
940984
{
@@ -977,6 +1021,58 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
9771021
var reshaped_a = array_ops.reshape(a_trans, new_shape);
9781022
return (reshaped_a, free_dims, free_dims);
9791023
}
1024+
else
1025+
{
1026+
int[] free_dims_static;
1027+
Tensor converted_shape_a, converted_axes, converted_free;
1028+
if (a.shape.ndim != -1)
1029+
{
1030+
var shape_a = a.shape.as_int_list();
1031+
for(int i = 0; i < axes.Length; i++)
1032+
{
1033+
if (axes[i] < 0)
1034+
{
1035+
axes[i] += shape_a.Length;
1036+
}
1037+
}
1038+
var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();
1039+
1040+
var axes_dims = axes.Select(i => shape_a[i]);
1041+
var free_dims = free.Select(i => shape_a[i]).ToArray();
1042+
free_dims_static = free_dims;
1043+
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
1044+
converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
1045+
converted_shape_a = array_ops.shape(a);
1046+
}
1047+
else
1048+
{
1049+
free_dims_static = null;
1050+
converted_shape_a = array_ops.shape(a);
1051+
var rank_a = array_ops.rank(a);
1052+
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
1053+
converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
1054+
(converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
1055+
converted_axes, dtypes.int32);
1056+
}
1057+
var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
1058+
var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
1059+
var prod_free_dims = reduce_prod(converted_free_dims);
1060+
var prod_axes_dims = reduce_prod(converted_axes_dims);
1061+
Tensor reshaped_a;
1062+
if (flipped)
1063+
{
1064+
var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
1065+
var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
1066+
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
1067+
}
1068+
else
1069+
{
1070+
var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
1071+
var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
1072+
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
1073+
}
1074+
return (reshaped_a, converted_free_dims, free_dims_static);
1075+
}
9801076

9811077
throw new NotImplementedException("_tensordot_reshape");
9821078
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using System.Threading.Tasks;
7+
using static Tensorflow.Binding;
8+
9+
namespace TensorflowNET.Keras
10+
{
11+
[TestClass]
12+
public class ModelBuildTest
13+
{
14+
[TestMethod]
15+
public void DenseBuild()
16+
{
17+
// two dimensions input with unknown batchsize
18+
var input = tf.keras.layers.Input((17, 60));
19+
var dense = tf.keras.layers.Dense(64);
20+
var output = dense.Apply(input);
21+
var model = tf.keras.Model(input, output);
22+
23+
// one dimensions input with unknown batchsize
24+
var input_2 = tf.keras.layers.Input((60));
25+
var dense_2 = tf.keras.layers.Dense(64);
26+
var output_2 = dense.Apply(input_2);
27+
var model_2 = tf.keras.Model(input_2, output_2);
28+
29+
// two dimensions input with specified batchsize
30+
var input_3 = tf.keras.layers.Input((17, 60), 8);
31+
var dense_3 = tf.keras.layers.Dense(64);
32+
var output_3 = dense.Apply(input_3);
33+
var model_3 = tf.keras.Model(input_3, output_3);
34+
35+
// one dimensions input with specified batchsize
36+
var input_4 = tf.keras.layers.Input((60), 8);
37+
var dense_4 = tf.keras.layers.Dense(64);
38+
var output_4 = dense.Apply(input_4);
39+
var model_4 = tf.keras.Model(input_4, output_4);
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)