Skip to content

[mypyc] Add basic optimization for sorted #18902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypyc/doc/native_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Functions
* ``delattr(obj, name)``
* ``slice(start, stop, step)``
* ``globals()``
* ``sorted(obj)``

Method decorators
-----------------
Expand Down
1 change: 1 addition & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ int CPyList_Insert(PyObject *list, CPyTagged index, PyObject *value);
PyObject *CPyList_Extend(PyObject *o1, PyObject *o2);
int CPyList_Remove(PyObject *list, PyObject *obj);
CPyTagged CPyList_Index(PyObject *list, PyObject *obj);
PyObject *CPySequence_Sort(PyObject *seq);
PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size);
PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq);
PyObject *CPySequence_InPlaceMultiply(PyObject *seq, CPyTagged t_size);
Expand Down
12 changes: 12 additions & 0 deletions mypyc/lib-rt/list_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,18 @@ CPyTagged CPyList_Index(PyObject *list, PyObject *obj) {
return index << 1;
}

PyObject *CPySequence_Sort(PyObject *seq) {
PyObject *newlist = PySequence_List(seq);
if (newlist == NULL)
return NULL;
int res = PyList_Sort(newlist);
if (res < 0) {
Py_DECREF(newlist);
return NULL;
}
return newlist;
}

PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size) {
Py_ssize_t size = CPyTagged_AsSsize_t(t_size);
if (size == -1 && PyErr_Occurred()) {
Expand Down
9 changes: 9 additions & 0 deletions mypyc/primitives/list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@
# Get the 'builtins.list' type object.
load_address_op(name="builtins.list", type=object_rprimitive, src="PyList_Type")

# sorted(obj)
function_op(
name="builtins.sorted",
arg_types=[object_rprimitive],
return_type=list_rprimitive,
c_function_name="CPySequence_Sort",
error_kind=ERR_MAGIC,
)

# list(obj)
to_list = function_op(
name="builtins.list",
Expand Down
1 change: 1 addition & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def pow(base: __SupportsPow2[T_contra, T_co], exp: T_contra, mod: None = None) -
def pow(base: __SupportsPow3NoneOnly[T_contra, T_co], exp: T_contra, mod: None = None) -> T_co: ...
@overload
def pow(base: __SupportsPow3[T_contra, _M, T_co], exp: T_contra, mod: _M) -> T_co: ...
def sorted(iterable: Iterable[_T]) -> list[_T]: ...
def exit() -> None: ...
def min(x: _T, y: _T) -> _T: ...
def max(x: _T, y: _T) -> _T: ...
Expand Down
22 changes: 22 additions & 0 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,25 @@ L3:
goto L1
L4:
return 1

[case testSorted]
from typing import List, Any
def list_sort(a: List[int]) -> None:
a.sort()
def sort_iterable(a: Any) -> None:
sorted(a)
[out]
def list_sort(a):
a :: list
r0 :: i32
r1 :: bit
L0:
r0 = PyList_Sort(a)
r1 = r0 >= 0 :: signed
return 1
def sort_iterable(a):
a :: object
r0 :: list
L0:
r0 = CPySequence_Sort(a)
return 1
22 changes: 22 additions & 0 deletions mypyc/test-data/run-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,25 @@ def test_index_with_literal() -> None:
assert d is d2
d = a[-2].d
assert d is d1

[case testSorted]
from typing import List

def test_list_sort() -> None:
l1 = [2, 1, 3]
id_l1 = id(l1)
l1.sort()
assert l1 == [1, 2, 3]
assert id_l1 == id(l1)

def test_sorted() -> None:
res = [1, 2, 3]
l1 = [2, 1, 3]
id_l1 = id(l1)
s_l1 = sorted(l1)
assert s_l1 == res
assert id_l1 != id(s_l1)
assert l1 == [2, 1, 3]
assert sorted((2, 1, 3)) == res
assert sorted({2, 1, 3}) == res
assert sorted({2: "", 1: "", 3: ""}) == res