
Commit 824308a

tf.tensordot #898
1 parent 0440282 commit 824308a

File tree (5 files changed: +96 -112 lines)

src/TensorFlowNET.Core/APIs/tf.linalg.cs
src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
src/TensorFlowNET.Core/Operations/math_ops.cs
test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 3 additions & 0 deletions

@@ -57,6 +57,9 @@ public Tensor global_norm(Tensor[] t_list, string name = null)
         public Tensor lstsq(Tensor matrix, Tensor rhs,
             NDArray l2_regularizer = null, bool fast = true, string name = null)
             => ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
+
+        public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null)
+            => math_ops.tensordot(x, y, axes, name: name);
     }
 
     public Tensor diag(Tensor diagonal, string name = null)
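A minimal usage sketch of the new overload follows (illustrative only: the console scaffolding is assumed, the matrix case extrapolates from the vector cases exercised by the unit test added in this commit, and the implicit int/int[]-to-NDArray conversions are taken from that test):

    using System;
    using static Tensorflow.Binding;

    var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });   // shape (2, 2)
    var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });   // shape (2, 2)

    // Pair form: contract axis 1 of `a` against axis 0 of `b`,
    // which for two matrices is the ordinary matrix product.
    var c = tf.linalg.tensordot(a, b, new[] { 1, 0 });
    Console.WriteLine(c.shape);    // expected (2, 2)

    // Scalar form: contract that many trailing axes of the first operand
    // against leading axes of the second; 0 means no contraction (outer product).
    var v = tf.constant(new[] { 1, 2 });
    var w = tf.constant(new[] { 2, 3 });
    Console.WriteLine(tf.linalg.tensordot(v, w, 0).shape);   // expected (2, 2)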

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Statistics.cs

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ public NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool r
             if(a.rank != weights.rank)
             {
                 var weights_sum = math_ops.reduce_sum(tensorW);
-                var axes = ops.convert_to_tensor(new[,] { { axis }, { 0 } });
+                var axes = np.array(new[,] { { axis }, { 0 } });
                 var avg = math_ops.tensordot(a, weights, axes) / weights_sum;
             }
 
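For reference, a plain-C# sketch of the contraction the `tensordot(a, weights, axes)` call above performs when `axes` pairs axis `axis` of `a` with axis 0 of `weights` (standalone arithmetic on made-up values; nothing here calls TensorFlow.NET):

    using System;
    using System.Linq;

    // Weighted average along axis 0 of a (2, 2) array, written out by hand.
    double[,] a = { { 1, 2 }, { 3, 4 } };   // data, shape (2, 2)
    double[] w = { 0.25, 0.75 };            // weights along axis 0, shape (2,)

    var contracted = new double[2];         // remaining (free) axis of `a`
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 2; i++)
            contracted[j] += a[i, j] * w[i];          // sum over the paired axis

    var avg = contracted.Select(v => v / w.Sum()).ToArray();
    Console.WriteLine(string.Join(", ", avg));        // 2.5, 3.5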

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 8 additions & 0 deletions

@@ -104,6 +104,14 @@ public static bool Equals(Shape shape, object target)
                 if (shape.ndim != shape3.Length)
                     return false;
                 return Enumerable.SequenceEqual(shape.as_int_list(), shape3);
+            case List<long> shape4:
+                if (shape.ndim != shape4.Count)
+                    return false;
+                return Enumerable.SequenceEqual(shape.dims, shape4);
+            case List<int> shape5:
+                if (shape.ndim != shape5.Count)
+                    return false;
+                return Enumerable.SequenceEqual(shape.as_int_list(), shape5);
             default:
                 return false;
         }
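A small hypothetical check of the comparisons these new cases allow, assuming `Shape.Equals(object)` routes through `ShapeHelper.Equals` as the signature above suggests (not part of the commit):

    using System;
    using System.Collections.Generic;
    using Tensorflow;

    var shape = new Shape(2, 3);
    Console.WriteLine(shape.Equals(new List<long> { 2, 3 }));    // expected True via the new List<long> case
    Console.WriteLine(shape.Equals(new List<int> { 2, 3 }));     // expected True via the new List<int> case
    Console.WriteLine(shape.Equals(new List<int> { 2, 3, 1 }));  // expected False: ndim mismatch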

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 70 additions & 111 deletions

@@ -868,133 +868,92 @@ public static Tensor conj(Tensor x, string name = null)
         public static Tensor tanh(Tensor x, string name = null)
             => gen_math_ops.tanh(x, name);
 
-        public static Tensor tensordot(Tensor x, Tensor y, int[] axes, string name = null)
+        public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = null)
         {
-            Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+            return tf_with(ops.name_scope(name, "Tensordot", new { a, b, axes }), scope =>
             {
-                if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
-                {
-                    var shape_a = a.shape.dims;
-
-                    // axes
-                    int iter = 0;
-                    foreach (int i in axes)
-                    {
-                        if (i >= 0)
-                            axes[0 + iter] = i;
-                        else
-                            axes[0 + iter] = i + len(shape_a);
-                        iter++;
-                    }
-
-                    // free
-                    int[] free = { };
-                    iter = 0;
-                    foreach (int i in Enumerable.Range(0, len(axes)))
-                        if (!Array.Exists(axes, i => i == i))
-                            free[free.Length] = i;
-
-                    // free_dims
-                    int[] free_dims = { };
-                    foreach (int i in free)
-                        free_dims[free_dims.Length] = (int)shape_a[i];
-
-                    int prod_free = (int)np.prod(free_dims);
-
-                    // prod_axes
-                    int[] prod_axes_pre = { };
-                    foreach (int i in axes)
-                        prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i];
-                    int prod_axes = (int)np.prod(prod_axes_pre);
-
-                    // perm
-                    Tensor perm;
-                    if (flipped)
-                        perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
-                    else
-                        perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
-                            + ops.convert_to_tensor(list(axes));
-
-                    // new_shape
-                    Shape new_shape;
-                    if (flipped)
-                        new_shape = new Shape(new int[] { prod_axes, prod_free });
-                    else
-                        new_shape = new Shape(new int[] { prod_free, prod_axes });
-                }
+                name = scope;
+                var (a_axes, b_axes) = _tensordot_axes(a, axes);
+                var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
+                var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
+                var ab_matmul = matmul(a_reshape, b_reshape);
+                var dims = new List<int>();
+                dims.AddRange(a_free_dims);
+                dims.AddRange(b_free_dims);
+                if (ab_matmul.shape.Equals(dims))
+                    return ab_matmul;
+                else
+                    return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
+            });
+        }
 
-                throw new NotImplementedException("_tensordot_reshape");
+        static (int[], int[]) _tensordot_axes(Tensor a, NDArray axes)
+        {
+            if (axes.rank == 0)
+            {
+                int axe = axes;
+                if (axe > a.shape.ndim)
+                    throw new ValueError("`axes` must not be larger than the number of " +
+                        $"dimensions of tensor {a}. Received {axes}, vs " +
+                        $"tensor dimensions {a.ndim}.");
+                return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
+                    Binding.range(0, axe).ToArray());
+            }
+            else
+            {
+                (int a_axe, int b_axe) = (axes[0], axes[1]);
+                return (new[] { a_axe }, new[] { b_axe });
             }
-
-            throw new NotImplementedException("tensordot");
         }
 
-        public static Tensor tensordot(Tensor x, Tensor y, Tensor axes, string name = null)
+        static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
         {
-            Tensor _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
+            if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
             {
-                if (a.shape.IsFullyDefined && isinstance(axes, (typeof(List<object>), typeof(Tuple))))
-                {
-                    var shape_a = a.shape.dims;
+                var shape_a = a.shape.as_int_list();
 
-                    // axes
-                    int iter = 0;
-                    foreach (int i in axes)
-                    {
-                        if (i >= 0)
-                            axes[0 + iter] = i;
-                        else
-                            axes[0 + iter] = i + len(shape_a);
-                        iter++;
-                    }
+                // axes
+                axes = axes.Select(i => i >= 0 ? i : i + len(shape_a)).ToArray();
+
+                // free
+                int[] free = Binding.range(a.shape.ndim).Where(i => !axes.Contains(i)).ToArray();
+
+                // free_dims
+                int[] free_dims = free.Select(i => shape_a[i]).ToArray();
 
-                    // free
-                    int[] free = { };
-                    iter = 0;
-                    foreach (int i in Enumerable.Range(0, len(axes)))
-                        if (!Array.Exists(axes, i => i == i))
-                            free[free.Length] = i;
-
-                    // free_dims
-                    int[] free_dims = { };
-                    foreach (int i in free)
-                        free_dims[free_dims.Length] = (int)shape_a[i];
-
-                    int prod_free = (int)np.prod(free_dims);
-
-                    // prod_axes
-                    int[] prod_axes_pre = { };
-                    foreach (int i in axes)
-                        prod_axes_pre[prod_axes_pre.Length] = (int)shape_a[i];
-                    int prod_axes = (int)np.prod(prod_axes_pre);
-
-                    // perm
-                    Tensor perm;
-                    if (flipped)
-                        perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free);
-                    else
-                        perm = ops.convert_to_tensor(list(free)) + ops.convert_to_tensor(free)
-                            + ops.convert_to_tensor(list(axes));
-
-                    // new_shape
-                    Shape new_shape;
-                    if (flipped)
-                        new_shape = new Shape(new int[] { prod_axes, prod_free });
-                    else
-                        new_shape = new Shape(new int[] { prod_free, prod_axes });
+                int prod_free = np.prod(free_dims);
+
+                // prod_axes
+                int prod_axes = np.prod(axes.Select(i => shape_a[i]).ToArray());
+
+                // perm
+                List<int> perm = new List<int>();
+                if (flipped)
+                {
+                    perm.AddRange(axes);
+                    perm.AddRange(free);
+                }
+                else
+                {
+                    perm.AddRange(free);
+                    perm.AddRange(axes);
                 }
 
-                throw new NotImplementedException("_tensordot_reshape");
+                // new_shape
+                Shape new_shape;
+                if (flipped)
+                    new_shape = new Shape(new int[] { prod_axes, prod_free });
+                else
+                    new_shape = new Shape(new int[] { prod_free, prod_axes });
+                var a_trans = a;
+                var reshaped_a = array_ops.reshape(a_trans, new_shape);
+                return (reshaped_a, free_dims, free_dims);
             }
 
-            return tf_with(ops.name_scope(name, "Tensordot", new { x, y, axes }), scope =>
-            {
-                name = scope;
-                var (a_axes, b_axes) = (axes[0], axes[1]);
-                return x;
-            });
+            throw new NotImplementedException("_tensordot_reshape");
         }
 
+
         public static Tensor truediv(Tensor x, Tensor y, string name = null)
             => _truediv_python3(x, y, name);
 
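For orientation, a standalone walk-through of the index bookkeeping in `_tensordot_reshape`, using plain arrays and example values in place of Tensors (nothing below calls TensorFlow.NET):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    int[] shape_a = { 2, 3, 4 };          // a fully defined input shape
    int[] axes = { 2 };                   // contract over the last axis
    bool flipped = false;                 // false for the left operand, true for the right

    // Normalize negative axes, then split the dimensions into free and contracted ones.
    axes = axes.Select(i => i >= 0 ? i : i + shape_a.Length).ToArray();
    int[] free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();
    int[] free_dims = free.Select(i => shape_a[i]).ToArray();

    int prod_free = free_dims.Aggregate(1, (x, y) => x * y);                     // 2 * 3 = 6
    int prod_axes = axes.Select(i => shape_a[i]).Aggregate(1, (x, y) => x * y);  // 4

    // perm: free axes first (contracted axes first when flipped) ...
    var perm = new List<int>();
    if (flipped) { perm.AddRange(axes); perm.AddRange(free); }
    else         { perm.AddRange(free); perm.AddRange(axes); }

    // ... and the 2-D shape the operand is collapsed to before matmul.
    int[] new_shape = flipped ? new[] { prod_axes, prod_free } : new[] { prod_free, prod_axes };

    Console.WriteLine($"perm = [{string.Join(", ", perm)}], new_shape = ({string.Join(", ", new_shape)})");
    // perm = [0, 1, 2], new_shape = (6, 4)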

test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

Lines changed: 14 additions & 0 deletions

@@ -63,5 +63,19 @@ public void GlobalNorm()
             var norm = tf.linalg.global_norm(t_list);
             Assert.AreEqual(norm.numpy(), 14.282857f);
         }
+
+        [TestMethod]
+        public void Tensordot()
+        {
+            var a = tf.constant(new[] { 1, 2 });
+            var b = tf.constant(new[] { 2, 3 });
+            var c = tf.linalg.tensordot(a, b, 0);
+            Assert.AreEqual(c.shape, (2, 2));
+            AssetSequenceEqual(c.ToArray<int>(), new[] { 2, 3, 4, 6 });
+
+            c = tf.linalg.tensordot(a, b, new[] { 0, 0 });
+            Assert.AreEqual(c.shape.ndim, 0);
+            Assert.AreEqual(c.numpy(), 8);
+        }
     }
 }
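As a check on the expected values: with axes = 0 the call is an outer product, [1, 2] ⊗ [2, 3] = [[1·2, 1·3], [2·2, 2·3]] = [[2, 3], [4, 6]], which has shape (2, 2) and flattens to {2, 3, 4, 6}; with axes = [0, 0], axis 0 of `a` is contracted against axis 0 of `b`, giving the rank-0 result 1·2 + 2·3 = 8.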
