@@ -868,133 +868,92 @@ public static Tensor conj(Tensor x, string name = null)
868
868
public static Tensor tanh ( Tensor x , string name = null )
869
869
=> gen_math_ops . tanh ( x , name ) ;
870
870
871
- public static Tensor tensordot ( Tensor x , Tensor y , int [ ] axes , string name = null )
871
+ public static Tensor tensordot ( Tensor a , Tensor b , NDArray axes , string name = null )
872
872
{
873
- Tensor _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
873
+ return tf_with ( ops . name_scope ( name , "Tensordot" , new { a , b , axes } ) , scope =>
874
874
{
875
- if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( List < object > ) , typeof ( Tuple ) ) ) )
876
- {
877
- var shape_a = a . shape . dims ;
878
-
879
- // axes
880
- int iter = 0 ;
881
- foreach ( int i in axes )
882
- {
883
- if ( i >= 0 )
884
- axes [ 0 + iter ] = i ;
885
- else
886
- axes [ 0 + iter ] = i + len ( shape_a ) ;
887
- iter ++ ;
888
- }
889
-
890
- // free
891
- int [ ] free = { } ;
892
- iter = 0 ;
893
- foreach ( int i in Enumerable . Range ( 0 , len ( axes ) ) )
894
- if ( ! Array . Exists ( axes , i => i == i ) )
895
- free [ free . Length ] = i ;
896
-
897
- // free_dims
898
- int [ ] free_dims = { } ;
899
- foreach ( int i in free )
900
- free_dims [ free_dims . Length ] = ( int ) shape_a [ i ] ;
901
-
902
- int prod_free = ( int ) np . prod ( free_dims ) ;
903
-
904
- // prod_axes
905
- int [ ] prod_axes_pre = { } ;
906
- foreach ( int i in axes )
907
- prod_axes_pre [ prod_axes_pre . Length ] = ( int ) shape_a [ i ] ;
908
- int prod_axes = ( int ) np . prod ( prod_axes_pre ) ;
909
-
910
- // perm
911
- Tensor perm ;
912
- if ( flipped )
913
- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free ) ;
914
- else
915
- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free )
916
- + ops . convert_to_tensor ( list ( axes ) ) ;
917
-
918
- // new_shape
919
- Shape new_shape ;
920
- if ( flipped )
921
- new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
922
- else
923
- new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
924
- }
875
+ name = scope ;
876
+ var ( a_axes , b_axes ) = _tensordot_axes ( a , axes ) ;
877
+ var ( a_reshape , a_free_dims , a_free_dims_static ) = _tensordot_reshape ( a , a_axes ) ;
878
+ var ( b_reshape , b_free_dims , b_free_dims_static ) = _tensordot_reshape ( b , b_axes , true ) ;
879
+ var ab_matmul = matmul ( a_reshape , b_reshape ) ;
880
+ var dims = new List < int > ( ) ;
881
+ dims . AddRange ( a_free_dims ) ;
882
+ dims . AddRange ( b_free_dims ) ;
883
+ if ( ab_matmul . shape . Equals ( dims ) )
884
+ return ab_matmul ;
885
+ else
886
+ return array_ops . reshape ( ab_matmul , tf . constant ( dims . ToArray ( ) ) , name : name ) ;
887
+ } ) ;
888
+ }
925
889
926
- throw new NotImplementedException ( "_tensordot_reshape" ) ;
890
+ static ( int [ ] , int [ ] ) _tensordot_axes ( Tensor a , NDArray axes )
891
+ {
892
+ if ( axes . rank == 0 )
893
+ {
894
+ int axe = axes ;
895
+ if ( axe > a . shape . ndim )
896
+ throw new ValueError ( "`axes` must not be larger than the number of " +
897
+ $ "dimensions of tensor { a } . Received { axes } , vs " +
898
+ $ "tensor dimensions { a . ndim } .") ;
899
+ return ( Binding . range ( a . shape . ndim - axe , a . shape . ndim ) . ToArray ( ) ,
900
+ Binding . range ( 0 , axe ) . ToArray ( ) ) ;
901
+ }
902
+ else
903
+ {
904
+ ( int a_axe , int b_axe ) = ( axes [ 0 ] , axes [ 1 ] ) ;
905
+ return ( new [ ] { a_axe } , new [ ] { b_axe } ) ;
927
906
}
928
-
929
- throw new NotImplementedException ( "tensordot" ) ;
930
907
}
931
908
932
- public static Tensor tensordot ( Tensor x , Tensor y , Tensor axes , string name = null )
909
+ static ( Tensor , int [ ] , int [ ] ) _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
933
910
{
934
- Tensor _tensordot_reshape ( Tensor a , int [ ] axes , bool flipped = false )
911
+ if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( int [ ] ) , typeof ( Tuple ) ) ) )
935
912
{
936
- if ( a . shape . IsFullyDefined && isinstance ( axes , ( typeof ( List < object > ) , typeof ( Tuple ) ) ) )
937
- {
938
- var shape_a = a . shape . dims ;
913
+ var shape_a = a . shape . as_int_list ( ) ;
939
914
940
- // axes
941
- int iter = 0 ;
942
- foreach ( int i in axes )
943
- {
944
- if ( i >= 0 )
945
- axes [ 0 + iter ] = i ;
946
- else
947
- axes [ 0 + iter ] = i + len ( shape_a ) ;
948
- iter ++ ;
949
- }
915
+ // axes
916
+ axes = axes . Select ( i => i >= 0 ? i : i + len ( shape_a ) ) . ToArray ( ) ;
917
+
918
+ // free
919
+ int [ ] free = Binding . range ( a . shape . ndim ) . Where ( i => ! axes . Contains ( i ) ) . ToArray ( ) ;
920
+
921
+ // free_dims
922
+ int [ ] free_dims = free . Select ( i => shape_a [ i ] ) . ToArray ( ) ;
950
923
951
- // free
952
- int [ ] free = { } ;
953
- iter = 0 ;
954
- foreach ( int i in Enumerable . Range ( 0 , len ( axes ) ) )
955
- if ( ! Array . Exists ( axes , i => i == i ) )
956
- free [ free . Length ] = i ;
957
-
958
- // free_dims
959
- int [ ] free_dims = { } ;
960
- foreach ( int i in free )
961
- free_dims [ free_dims . Length ] = ( int ) shape_a [ i ] ;
962
-
963
- int prod_free = ( int ) np . prod ( free_dims ) ;
964
-
965
- // prod_axes
966
- int [ ] prod_axes_pre = { } ;
967
- foreach ( int i in axes )
968
- prod_axes_pre [ prod_axes_pre . Length ] = ( int ) shape_a [ i ] ;
969
- int prod_axes = ( int ) np . prod ( prod_axes_pre ) ;
970
-
971
- // perm
972
- Tensor perm ;
973
- if ( flipped )
974
- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free ) ;
975
- else
976
- perm = ops . convert_to_tensor ( list ( free ) ) + ops . convert_to_tensor ( free )
977
- + ops . convert_to_tensor ( list ( axes ) ) ;
978
-
979
- // new_shape
980
- Shape new_shape ;
981
- if ( flipped )
982
- new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
983
- else
984
- new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
924
+ int prod_free = np . prod ( free_dims ) ;
925
+
926
+ // prod_axes
927
+ int prod_axes = np . prod ( axes . Select ( i => shape_a [ i ] ) . ToArray ( ) ) ;
928
+
929
+ // perm
930
+ List < int > perm = new List < int > ( ) ;
931
+ if ( flipped )
932
+ {
933
+ perm . AddRange ( axes ) ;
934
+ perm . AddRange ( free ) ;
935
+ }
936
+ else
937
+ {
938
+ perm . AddRange ( free ) ;
939
+ perm . AddRange ( axes ) ;
985
940
}
986
941
987
- throw new NotImplementedException ( "_tensordot_reshape" ) ;
942
+ // new_shape
943
+ Shape new_shape ;
944
+ if ( flipped )
945
+ new_shape = new Shape ( new int [ ] { prod_axes , prod_free } ) ;
946
+ else
947
+ new_shape = new Shape ( new int [ ] { prod_free , prod_axes } ) ;
948
+ var a_trans = a ;
949
+ var reshaped_a = array_ops . reshape ( a_trans , new_shape ) ;
950
+ return ( reshaped_a , free_dims , free_dims ) ;
988
951
}
989
952
990
- return tf_with ( ops . name_scope ( name , "Tensordot" , new { x , y , axes } ) , scope =>
991
- {
992
- name = scope ;
993
- var ( a_axes , b_axes ) = ( axes [ 0 ] , axes [ 1 ] ) ;
994
- return x ;
995
- } ) ;
953
+ throw new NotImplementedException ( "_tensordot_reshape" ) ;
996
954
}
997
955
956
+
998
957
public static Tensor truediv ( Tensor x , Tensor y , string name = null )
999
958
=> _truediv_python3 ( x , y , name ) ;
1000
959
0 commit comments