Skip to content

Commit 135562e

Browse files
committed
argsort fix.
1 parent 43e59ca commit 135562e

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/TensorFlowNET.Core/NumPy/Axis.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ public static implicit operator Axis(Shape axis)
5555
public static implicit operator Tensor(Axis axis)
5656
=> constant_op.constant(axis);
5757

58+
public static bool operator ==(Axis left, int right)
59+
=> left.IsScalar && left[0] == right;
60+
61+
public static bool operator !=(Axis left, int right)
62+
=> !(left == right);
63+
5864
public override string ToString()
5965
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
6066
}

src/TensorFlowNET.Core/Operations/sort_ops.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17-
using Tensorflow.Operations;
17+
using System;
1818
using static Tensorflow.Binding;
1919

2020
namespace Tensorflow
@@ -26,8 +26,21 @@ public static Tensor argsort(Tensor values, Axis axis = null, string direction =
2626
axis = axis ?? new Axis(-1);
2727
var k = array_ops.shape(values)[axis];
2828
values = -values;
29+
var static_rank = values.shape.ndim;
30+
var top_k_input = values;
31+
if (axis == -1 || axis + 1 == values.shape.ndim)
32+
{
33+
}
34+
else
35+
{
36+
if (axis == 0 && static_rank == 2)
37+
top_k_input = array_ops.transpose(values, new[] { 1, 0 });
38+
else
39+
throw new NotImplementedException("");
40+
}
41+
2942
var (_, indices) = tf.Context.ExecuteOp("TopKV2", name,
30-
new ExecuteOpArgs(values, k).SetAttributes(new
43+
new ExecuteOpArgs(top_k_input, k).SetAttributes(new
3144
{
3245
sorted = true
3346
}));

0 commit comments

Comments
 (0)