259 lines
7.5 KiB
Python
259 lines
7.5 KiB
Python
import collections
|
|
|
|
import numpy as np
|
|
|
|
from numba.core import types
|
|
|
|
|
|
QuicksortImplementation = collections.namedtuple(
|
|
'QuicksortImplementation',
|
|
(# The compile function itself
|
|
'compile',
|
|
# All subroutines exercised by test_sort
|
|
'partition', 'partition3', 'insertion_sort',
|
|
# The top-level function
|
|
'run_quicksort',
|
|
))
|
|
|
|
|
|
Partition = collections.namedtuple('Partition', ('start', 'stop'))
|
|
|
|
# Under this size, switch to a simple insertion sort
|
|
SMALL_QUICKSORT = 15
|
|
|
|
MAX_STACK = 100
|
|
|
|
|
|
def make_quicksort_impl(wrap, lt=None, is_argsort=False, is_list=False, is_np_array=False):
|
|
|
|
intp = types.intp
|
|
zero = intp(0)
|
|
|
|
# Two subroutines to make the core algorithm generic wrt. argsort
|
|
# or normal sorting. Note the genericity may make basic sort()
|
|
# slightly slower (~5%)
|
|
if is_argsort:
|
|
if is_list:
|
|
@wrap
|
|
def make_res(A):
|
|
return [x for x in range(len(A))]
|
|
else:
|
|
@wrap
|
|
def make_res(A):
|
|
return np.arange(A.size)
|
|
|
|
@wrap
|
|
def GET(A, idx_or_val):
|
|
return A[idx_or_val]
|
|
|
|
else:
|
|
@wrap
|
|
def make_res(A):
|
|
return A
|
|
|
|
@wrap
|
|
def GET(A, idx_or_val):
|
|
return idx_or_val
|
|
|
|
def default_lt(a, b):
|
|
"""
|
|
Trivial comparison function between two keys.
|
|
"""
|
|
return a < b
|
|
|
|
LT = wrap(lt if lt is not None else default_lt)
|
|
|
|
@wrap
|
|
def insertion_sort(A, R, low, high):
|
|
"""
|
|
Insertion sort A[low:high + 1]. Note the inclusive bounds.
|
|
"""
|
|
assert low >= 0
|
|
if high <= low:
|
|
return
|
|
|
|
for i in range(low + 1, high + 1):
|
|
k = R[i]
|
|
v = GET(A, k)
|
|
# Insert v into A[low:i]
|
|
j = i
|
|
while j > low and LT(v, GET(A, R[j - 1])):
|
|
# Make place for moving A[i] downwards
|
|
R[j] = R[j - 1]
|
|
j -= 1
|
|
R[j] = k
|
|
|
|
@wrap
|
|
def partition(A, R, low, high):
|
|
"""
|
|
Partition A[low:high + 1] around a chosen pivot. The pivot's index
|
|
is returned.
|
|
"""
|
|
assert low >= 0
|
|
assert high > low
|
|
|
|
mid = (low + high) >> 1
|
|
# NOTE: the pattern of swaps below for the pivot choice and the
|
|
# partitioning gives good results (i.e. regular O(n log n))
|
|
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
|
|
# risk breaking this property.
|
|
|
|
# median of three {low, middle, high}
|
|
if LT(GET(A, R[mid]), GET(A, R[low])):
|
|
R[low], R[mid] = R[mid], R[low]
|
|
if LT(GET(A, R[high]), GET(A, R[mid])):
|
|
R[high], R[mid] = R[mid], R[high]
|
|
if LT(GET(A, R[mid]), GET(A, R[low])):
|
|
R[low], R[mid] = R[mid], R[low]
|
|
pivot = GET(A, R[mid])
|
|
|
|
# Temporarily stash the pivot at the end
|
|
R[high], R[mid] = R[mid], R[high]
|
|
i = low
|
|
j = high - 1
|
|
while True:
|
|
while i < high and LT(GET(A, R[i]), pivot):
|
|
i += 1
|
|
while j >= low and LT(pivot, GET(A, R[j])):
|
|
j -= 1
|
|
if i >= j:
|
|
break
|
|
R[i], R[j] = R[j], R[i]
|
|
i += 1
|
|
j -= 1
|
|
# Put the pivot back in its final place (all items before `i`
|
|
# are smaller than the pivot, all items at/after `i` are larger)
|
|
R[i], R[high] = R[high], R[i]
|
|
return i
|
|
|
|
@wrap
|
|
def partition3(A, low, high):
|
|
"""
|
|
Three-way partition [low, high) around a chosen pivot.
|
|
A tuple (lt, gt) is returned such that:
|
|
- all elements in [low, lt) are < pivot
|
|
- all elements in [lt, gt] are == pivot
|
|
- all elements in (gt, high] are > pivot
|
|
"""
|
|
mid = (low + high) >> 1
|
|
# median of three {low, middle, high}
|
|
if LT(A[mid], A[low]):
|
|
A[low], A[mid] = A[mid], A[low]
|
|
if LT(A[high], A[mid]):
|
|
A[high], A[mid] = A[mid], A[high]
|
|
if LT(A[mid], A[low]):
|
|
A[low], A[mid] = A[mid], A[low]
|
|
pivot = A[mid]
|
|
|
|
A[low], A[mid] = A[mid], A[low]
|
|
lt = low
|
|
gt = high
|
|
i = low + 1
|
|
while i <= gt:
|
|
if LT(A[i], pivot):
|
|
A[lt], A[i] = A[i], A[lt]
|
|
lt += 1
|
|
i += 1
|
|
elif LT(pivot, A[i]):
|
|
A[gt], A[i] = A[i], A[gt]
|
|
gt -= 1
|
|
else:
|
|
i += 1
|
|
return lt, gt
|
|
|
|
@wrap
|
|
def run_quicksort1(A):
|
|
R = make_res(A)
|
|
|
|
if len(A) < 2:
|
|
return R
|
|
|
|
stack = [Partition(zero, zero)] * MAX_STACK
|
|
stack[0] = Partition(zero, len(A) - 1)
|
|
n = 1
|
|
|
|
while n > 0:
|
|
n -= 1
|
|
low, high = stack[n]
|
|
# Partition until it becomes more efficient to do an insertion sort
|
|
while high - low >= SMALL_QUICKSORT:
|
|
assert n < MAX_STACK
|
|
i = partition(A, R, low, high)
|
|
# Push largest partition on the stack
|
|
if high - i > i - low:
|
|
# Right is larger
|
|
if high > i:
|
|
stack[n] = Partition(i + 1, high)
|
|
n += 1
|
|
high = i - 1
|
|
else:
|
|
if i > low:
|
|
stack[n] = Partition(low, i - 1)
|
|
n += 1
|
|
low = i + 1
|
|
|
|
insertion_sort(A, R, low, high)
|
|
|
|
return R
|
|
|
|
if is_np_array:
|
|
@wrap
|
|
def run_quicksort(A):
|
|
if A.ndim == 1:
|
|
return run_quicksort1(A)
|
|
else:
|
|
for idx in np.ndindex(A.shape[:-1]):
|
|
run_quicksort1(A[idx])
|
|
return A
|
|
else:
|
|
@wrap
|
|
def run_quicksort(A):
|
|
return run_quicksort1(A)
|
|
|
|
# Unused quicksort implementation based on 3-way partitioning; the
|
|
# partitioning scheme turns out exhibiting bad behaviour on sorted arrays.
|
|
@wrap
|
|
def _run_quicksort(A):
|
|
stack = [Partition(zero, zero)] * 100
|
|
stack[0] = Partition(zero, len(A) - 1)
|
|
n = 1
|
|
|
|
while n > 0:
|
|
n -= 1
|
|
low, high = stack[n]
|
|
# Partition until it becomes more efficient to do an insertion sort
|
|
while high - low >= SMALL_QUICKSORT:
|
|
assert n < MAX_STACK
|
|
l, r = partition3(A, low, high)
|
|
# One trivial (empty) partition => iterate on the other
|
|
if r == high:
|
|
high = l - 1
|
|
elif l == low:
|
|
low = r + 1
|
|
# Push largest partition on the stack
|
|
elif high - r > l - low:
|
|
# Right is larger
|
|
stack[n] = Partition(r + 1, high)
|
|
n += 1
|
|
high = l - 1
|
|
else:
|
|
stack[n] = Partition(low, l - 1)
|
|
n += 1
|
|
low = r + 1
|
|
|
|
insertion_sort(A, low, high)
|
|
|
|
|
|
return QuicksortImplementation(wrap,
|
|
partition, partition3, insertion_sort,
|
|
run_quicksort)
|
|
|
|
|
|
def make_py_quicksort(*args, **kwargs):
|
|
return make_quicksort_impl((lambda f: f), *args, **kwargs)
|
|
|
|
def make_jit_quicksort(*args, **kwargs):
|
|
from numba.core.extending import register_jitable
|
|
return make_quicksort_impl((lambda f: register_jitable(f)),
|
|
*args, **kwargs)
|