267 lines
6.0 KiB
Python
267 lines
6.0 KiB
Python
# A port of https://github.com/python/cpython/blob/e42b7051/Lib/heapq.py
|
|
|
|
|
|
import heapq as hq
|
|
|
|
from numba.core import types
|
|
from numba.core.errors import TypingError
|
|
from numba.core.extending import overload, register_jitable
|
|
|
|
|
|
@register_jitable
|
|
def _siftdown(heap, startpos, pos):
|
|
newitem = heap[pos]
|
|
|
|
while pos > startpos:
|
|
parentpos = (pos - 1) >> 1
|
|
parent = heap[parentpos]
|
|
if newitem < parent:
|
|
heap[pos] = parent
|
|
pos = parentpos
|
|
continue
|
|
break
|
|
|
|
heap[pos] = newitem
|
|
|
|
|
|
@register_jitable
|
|
def _siftup(heap, pos):
|
|
endpos = len(heap)
|
|
startpos = pos
|
|
newitem = heap[pos]
|
|
|
|
childpos = 2 * pos + 1
|
|
while childpos < endpos:
|
|
|
|
rightpos = childpos + 1
|
|
if rightpos < endpos and not heap[childpos] < heap[rightpos]:
|
|
childpos = rightpos
|
|
|
|
heap[pos] = heap[childpos]
|
|
pos = childpos
|
|
childpos = 2 * pos + 1
|
|
|
|
heap[pos] = newitem
|
|
_siftdown(heap, startpos, pos)
|
|
|
|
|
|
@register_jitable
|
|
def _siftdown_max(heap, startpos, pos):
|
|
newitem = heap[pos]
|
|
|
|
while pos > startpos:
|
|
parentpos = (pos - 1) >> 1
|
|
parent = heap[parentpos]
|
|
if parent < newitem:
|
|
heap[pos] = parent
|
|
pos = parentpos
|
|
continue
|
|
break
|
|
heap[pos] = newitem
|
|
|
|
|
|
@register_jitable
|
|
def _siftup_max(heap, pos):
|
|
endpos = len(heap)
|
|
startpos = pos
|
|
newitem = heap[pos]
|
|
|
|
childpos = 2 * pos + 1
|
|
while childpos < endpos:
|
|
|
|
rightpos = childpos + 1
|
|
if rightpos < endpos and not heap[rightpos] < heap[childpos]:
|
|
childpos = rightpos
|
|
|
|
heap[pos] = heap[childpos]
|
|
pos = childpos
|
|
childpos = 2 * pos + 1
|
|
|
|
heap[pos] = newitem
|
|
_siftdown_max(heap, startpos, pos)
|
|
|
|
|
|
@register_jitable
|
|
def reversed_range(x):
|
|
# analogous to reversed(range(x))
|
|
return range(x - 1, -1, -1)
|
|
|
|
|
|
@register_jitable
|
|
def _heapify_max(x):
|
|
n = len(x)
|
|
|
|
for i in reversed_range(n // 2):
|
|
_siftup_max(x, i)
|
|
|
|
|
|
@register_jitable
|
|
def _heapreplace_max(heap, item):
|
|
returnitem = heap[0]
|
|
heap[0] = item
|
|
_siftup_max(heap, 0)
|
|
return returnitem
|
|
|
|
|
|
def assert_heap_type(heap):
|
|
if not isinstance(heap, (types.List, types.ListType)):
|
|
raise TypingError('heap argument must be a list')
|
|
|
|
dt = heap.dtype
|
|
if isinstance(dt, types.Complex):
|
|
msg = ("'<' not supported between instances "
|
|
"of 'complex' and 'complex'")
|
|
raise TypingError(msg)
|
|
|
|
|
|
def assert_item_type_consistent_with_heap_type(heap, item):
|
|
if not heap.dtype == item:
|
|
raise TypingError('heap type must be the same as item type')
|
|
|
|
|
|
@overload(hq.heapify)
|
|
def hq_heapify(x):
|
|
assert_heap_type(x)
|
|
|
|
def hq_heapify_impl(x):
|
|
n = len(x)
|
|
for i in reversed_range(n // 2):
|
|
_siftup(x, i)
|
|
|
|
return hq_heapify_impl
|
|
|
|
|
|
@overload(hq.heappop)
|
|
def hq_heappop(heap):
|
|
assert_heap_type(heap)
|
|
|
|
def hq_heappop_impl(heap):
|
|
lastelt = heap.pop()
|
|
if heap:
|
|
returnitem = heap[0]
|
|
heap[0] = lastelt
|
|
_siftup(heap, 0)
|
|
return returnitem
|
|
return lastelt
|
|
|
|
return hq_heappop_impl
|
|
|
|
|
|
@overload(hq.heappush)
|
|
def heappush(heap, item):
|
|
assert_heap_type(heap)
|
|
assert_item_type_consistent_with_heap_type(heap, item)
|
|
|
|
def hq_heappush_impl(heap, item):
|
|
heap.append(item)
|
|
_siftdown(heap, 0, len(heap) - 1)
|
|
|
|
return hq_heappush_impl
|
|
|
|
|
|
@overload(hq.heapreplace)
|
|
def heapreplace(heap, item):
|
|
assert_heap_type(heap)
|
|
assert_item_type_consistent_with_heap_type(heap, item)
|
|
|
|
def hq_heapreplace(heap, item):
|
|
returnitem = heap[0]
|
|
heap[0] = item
|
|
_siftup(heap, 0)
|
|
return returnitem
|
|
|
|
return hq_heapreplace
|
|
|
|
|
|
@overload(hq.heappushpop)
|
|
def heappushpop(heap, item):
|
|
assert_heap_type(heap)
|
|
assert_item_type_consistent_with_heap_type(heap, item)
|
|
|
|
def hq_heappushpop_impl(heap, item):
|
|
if heap and heap[0] < item:
|
|
item, heap[0] = heap[0], item
|
|
_siftup(heap, 0)
|
|
return item
|
|
|
|
return hq_heappushpop_impl
|
|
|
|
|
|
def check_input_types(n, iterable):
|
|
|
|
if not isinstance(n, (types.Integer, types.Boolean)):
|
|
raise TypingError("First argument 'n' must be an integer")
|
|
# heapq also accepts 1.0 (but not 0.0, 2.0, 3.0...) but
|
|
# this isn't replicated
|
|
|
|
if not isinstance(iterable, (types.Sequence, types.Array, types.ListType)):
|
|
raise TypingError("Second argument 'iterable' must be iterable")
|
|
|
|
|
|
@overload(hq.nsmallest)
|
|
def nsmallest(n, iterable):
|
|
check_input_types(n, iterable)
|
|
|
|
def hq_nsmallest_impl(n, iterable):
|
|
|
|
if n == 0:
|
|
return [iterable[0] for _ in range(0)]
|
|
elif n == 1:
|
|
out = min(iterable)
|
|
return [out]
|
|
|
|
size = len(iterable)
|
|
if n >= size:
|
|
return sorted(iterable)[:n]
|
|
|
|
it = iter(iterable)
|
|
result = [(elem, i) for i, elem in zip(range(n), it)]
|
|
|
|
_heapify_max(result)
|
|
top = result[0][0]
|
|
order = n
|
|
|
|
for elem in it:
|
|
if elem < top:
|
|
_heapreplace_max(result, (elem, order))
|
|
top, _order = result[0]
|
|
order += 1
|
|
result.sort()
|
|
return [elem for (elem, order) in result]
|
|
|
|
return hq_nsmallest_impl
|
|
|
|
|
|
@overload(hq.nlargest)
|
|
def nlargest(n, iterable):
|
|
check_input_types(n, iterable)
|
|
|
|
def hq_nlargest_impl(n, iterable):
|
|
|
|
if n == 0:
|
|
return [iterable[0] for _ in range(0)]
|
|
elif n == 1:
|
|
out = max(iterable)
|
|
return [out]
|
|
|
|
size = len(iterable)
|
|
if n >= size:
|
|
return sorted(iterable)[::-1][:n]
|
|
|
|
it = iter(iterable)
|
|
result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
|
|
|
|
hq.heapify(result)
|
|
top = result[0][0]
|
|
order = -n
|
|
|
|
for elem in it:
|
|
if top < elem:
|
|
hq.heapreplace(result, (elem, order))
|
|
top, _order = result[0]
|
|
order -= 1
|
|
result.sort(reverse=True)
|
|
return [elem for (elem, order) in result]
|
|
|
|
return hq_nlargest_impl
|