Skip to content

Commit 56a64da

Browse files
committed
np.sort
1 parent bd26bbd commit 56a64da

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections;
33
using System.Collections.Generic;
4+
using System.Globalization;
45
using System.Numerics;
56
using System.Text;
67

@@ -9,11 +10,11 @@ namespace Tensorflow.NumPy
910
public partial class np
1011
{
1112
[AutoNumPy]
12-
public static NDArray argmax(NDArray a, Axis axis = null)
13+
public static NDArray argmax(NDArray a, Axis? axis = null)
1314
=> new NDArray(math_ops.argmax(a, axis ?? 0));
1415

1516
[AutoNumPy]
16-
public static NDArray argsort(NDArray a, Axis axis = null)
17+
public static NDArray argsort(NDArray a, Axis? axis = null)
1718
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));
1819

1920
[AutoNumPy]
@@ -25,5 +26,22 @@ public static (NDArray, NDArray) unique(NDArray a)
2526

2627
[AutoNumPy]
2728
public static void shuffle(NDArray x) => np.random.shuffle(x);
29+
30+
/// <summary>
31+
/// Sorts a ndarray
32+
/// </summary>
33+
/// <param name="values"></param>
34+
/// <param name="axis">
35+
/// The axis along which to sort. The default is -1, which sorts the last axis.
36+
/// </param>
37+
/// <param name="direction">
38+
/// The direction in which to sort the values (`'ASCENDING'` or `'DESCENDING'`)
39+
/// </param>
40+
/// <returns>
41+
/// A `NDArray` with the same dtype and shape as `values`, with the elements sorted along the given `axis`.
42+
/// </returns>
43+
[AutoNumPy]
44+
public static NDArray sort(NDArray values, Axis? axis = null, string direction = "ASCENDING")
45+
=> new NDArray(sort_ops.sort(values, axis: axis ?? -1, direction: direction));
2846
}
2947
}

src/TensorFlowNET.Core/Operations/sort_ops.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ public static Tensor argsort(Tensor values, Axis axis = null, string direction =
4747
return indices;
4848
}
4949

50+
public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null)
51+
{
52+
var k = array_ops.shape(values)[axis];
53+
values = -values;
54+
var static_rank = values.shape.ndim;
55+
var top_k_input = values;
56+
if (axis == -1 || axis + 1 == values.shape.ndim)
57+
{
58+
}
59+
else
60+
{
61+
if (axis == 0 && static_rank == 2)
62+
top_k_input = array_ops.transpose(values, new[] { 1, 0 });
63+
else
64+
throw new NotImplementedException("");
65+
}
66+
67+
(values, _) = tf.Context.ExecuteOp("TopKV2", name,
68+
new ExecuteOpArgs(top_k_input, k).SetAttributes(new
69+
{
70+
sorted = true
71+
}));
72+
return -values;
73+
}
74+
5075
public Tensor matrix_inverse(Tensor input, bool adjoint = false, string name = null)
5176
=> tf.Context.ExecuteOp("MatrixInverse", name,
5277
new ExecuteOpArgs(input).SetAttributes(new

test/TensorFlowNET.UnitTest/NumPy/Array.Sorting.Test.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Text;
66
using Tensorflow;
77
using Tensorflow.NumPy;
8-
using static Tensorflow.Binding;
98

109
namespace TensorFlowNET.UnitTest.NumPy
1110
{
@@ -30,5 +29,16 @@ public void argsort()
3029
Assert.AreEqual(ind[0], new[] { 0, 1 });
3130
Assert.AreEqual(ind[1], new[] { 1, 0 });
3231
}
32+
33+
/// <summary>
34+
/// https://numpy.org/doc/stable/reference/generated/numpy.sort.html
35+
/// </summary>
36+
[TestMethod]
37+
public void sort()
38+
{
39+
var x = np.array(new int[] { 3, 1, 2 });
40+
var sorted = np.sort(x);
41+
Assert.IsTrue(sorted.ToArray<int>() is [1, 2, 3]);
42+
}
3343
}
3444
}

0 commit comments

Comments
 (0)