# ai-content-maker/.venv/Lib/site-packages/numba/cpython/setobj.py
#
# 1712 lines
# 56 KiB
# Python

"""
Support for native homogeneous sets.
"""
import collections
import contextlib
import math
import operator
from functools import cached_property
from llvmlite import ir
from numba.core import types, typing, cgutils
from numba.core.imputils import (lower_builtin, lower_cast,
iternext_impl, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked,
for_iter, call_len, RefType)
from numba.misc import quicksort
from numba.cpython import slicing
from numba.core.errors import NumbaValueError, TypingError
from numba.core.extending import overload, overload_method, intrinsic
def get_payload_struct(context, builder, set_type, ptr):
    """
    View the raw data pointer *ptr* as the payload structure of a set of
    type *set_type*.

    The structure is returned by reference, so any field written through
    it is seen by every other view of the same set.
    """
    payload_type = types.SetPayload(set_type)
    payload_ptr = builder.bitcast(
        ptr, context.get_data_type(payload_type).as_pointer())
    return context.make_data_helper(builder, payload_type, ref=payload_ptr)
def get_entry_size(context, set_type):
    """
    Return the ABI size of a single table entry of a set of type
    *set_type*, as computed by the target context.
    """
    entry_type = context.get_data_type(types.SetEntry(set_type))
    return context.get_abi_sizeof(entry_type)
# Sentinel hash values stored in table entries.  These values are special:
# - EMPTY is what an entry's hash reads as after memset(..., 0xFF)
# - viewed as unsigned integers, EMPTY > DELETED > any other hash value
EMPTY = -1
DELETED = -2
# Substitute hash used whenever a computed hash collides with a sentinel
FALLBACK = -43

# Minimal size of the entries table.  Must be a power of 2!
MINSIZE = 16

# How many cache-friendly linear probes to try before switching to
# perturbation-based (non-linear) probing
LINEAR_PROBES = 3

# When True, emit runtime printf() tracing of payload (re)allocations
DEBUG_ALLOCS = False
def get_hash_value(context, builder, typ, value):
    """
    Emit code computing ``hash(value)`` for a value of numba type *typ*.

    A hash that collides with the EMPTY or DELETED sentinels is replaced
    by FALLBACK, so the result can always be stored in a table entry.
    """
    tyctx = context.typing_context
    hash_fnty = tyctx.resolve_value_type(hash)
    hash_sig = hash_fnty.get_call_type(tyctx, (typ,), {})
    hash_impl = context.get_function(hash_fnty, hash_sig)
    raw_hash = hash_impl(builder, (value,))
    # Keep the real hash unless it equals one of the reserved sentinels
    usable = is_hash_used(context, builder, raw_hash)
    return builder.select(usable, raw_hash,
                          ir.Constant(raw_hash.type, FALLBACK))
@intrinsic
def _get_hash_value_intrinsic(typingctx, value):
    """
    Intrinsic exposing get_hash_value() to jitted code.

    The signature is the one ``hash(value)`` resolves to; the generated
    code additionally remaps sentinel-colliding hashes to FALLBACK.
    """
    hash_fnty = typingctx.resolve_value_type(hash)
    signature = hash_fnty.get_call_type(typingctx, (value,), {})

    def codegen(context, builder, sig, args):
        return get_hash_value(context, builder, value, args[0])

    return signature, codegen
def is_hash_empty(context, builder, h):
    """
    Emit a test for whether hash value *h* marks an EMPTY entry.
    """
    return builder.icmp_unsigned('==', h, ir.Constant(h.type, EMPTY))
def is_hash_deleted(context, builder, h):
    """
    Emit a test for whether hash value *h* marks a DELETED entry.
    """
    return builder.icmp_unsigned('==', h, ir.Constant(h.type, DELETED))
def is_hash_used(context, builder, h):
    """
    Emit a test for whether hash value *h* marks an active entry.

    As unsigned integers the two sentinels are the largest values
    (EMPTY > DELETED), so every hash strictly below DELETED is in use.
    """
    return builder.icmp_unsigned('<', h, ir.Constant(h.type, DELETED))
def check_all_set(*args):
    """
    Check that every argument is a ``types.Set`` and that all of them share
    the same dtype.

    Raises:
        TypingError: if an argument is not a Set, or if the dtypes differ.

    Calling it with no arguments is a no-op.
    """
    # Generator expressions let all() stop at the first failing argument
    # instead of materialising a full list of booleans first.
    if not all(isinstance(typ, types.Set) for typ in args):
        raise TypingError(f"All arguments must be Sets, got {args}")
    if not all(args[0].dtype == s.dtype for s in args):
        raise TypingError(f"All Sets must be of the same type, got {args}")
# State yielded by _SetPayload._iterate(): the entry's table index, the entry
# itself, and a callable that breaks out of the enclosing loop.
SetLoop = collections.namedtuple('SetLoop', 'index entry do_break')
class _SetPayload(object):
    """
    Codegen-time accessor for a set's payload.

    The payload is the NRT-allocated area pointed to by *ptr*. It holds the
    header fields (mask, used, fill, finger, dirty) followed by the entries
    array. Every property access and method call emits LLVM IR through the
    builder given at construction.
    """

    def __init__(self, context, builder, set_type, ptr):
        # View the raw data pointer as the SetPayload struct (by reference,
        # so field stores are visible through every other view of the set).
        payload = get_payload_struct(context, builder, set_type, ptr)
        self._context = context
        self._builder = builder
        self._ty = set_type
        self._payload = payload
        # Pointer to the first element of the inline entries array
        self._entries = payload._get_ptr_by_name('entries')
        self._ptr = ptr

    @property
    def mask(self):
        """Number of table entries minus one (used to wrap entry indices)."""
        return self._payload.mask

    @mask.setter
    def mask(self, value):
        # CAUTION: mask must be a power of 2 minus 1
        self._payload.mask = value

    @property
    def used(self):
        """Number of active entries (neither EMPTY nor DELETED)."""
        return self._payload.used

    @used.setter
    def used(self, value):
        self._payload.used = value

    @property
    def fill(self):
        """Number of entries that are not EMPTY (active plus DELETED)."""
        return self._payload.fill

    @fill.setter
    def fill(self, value):
        self._payload.fill = value

    @property
    def finger(self):
        """Index after which _next_entry() resumes its scan."""
        return self._payload.finger

    @finger.setter
    def finger(self, value):
        self._payload.finger = value

    @property
    def dirty(self):
        """Dirty flag; written by SetInstance.set_dirty() for reflected sets."""
        return self._payload.dirty

    @dirty.setter
    def dirty(self, value):
        self._payload.dirty = value

    @property
    def entries(self):
        """
        A pointer to the start of the entries array.
        """
        return self._entries

    @property
    def ptr(self):
        """
        A pointer to the start of the NRT-allocated area.
        """
        return self._ptr

    def get_entry(self, idx):
        """
        Get entry number *idx*, as a by-reference SetEntry helper.
        """
        entry_ptr = cgutils.gep(self._builder, self._entries, idx)
        entry = self._context.make_data_helper(self._builder,
                                               types.SetEntry(self._ty),
                                               ref=entry_ptr)
        return entry

    def _lookup(self, item, h, for_insert=False):
        """
        Lookup the *item* with the given hash values in the entries.

        Return a (found, entry index) tuple:
        - If found is true, <entry index> points to the entry containing
          the item.
        - If found is false, <entry index> points to the empty entry that
          the item can be written to (only if *for_insert* is true). When
          a DELETED entry was met along the probe chain, the first such
          entry is returned instead of the EMPTY one.

        The probe loop only terminates on a match or on an EMPTY entry, so
        the table must always contain at least one EMPTY entry.
        """
        context = self._context
        builder = self._builder

        intp_t = h.type

        mask = self.mask
        dtype = self._ty.dtype
        # Resolve the element type's == implementation once, up front
        tyctx = context.typing_context
        fnty = tyctx.resolve_value_type(operator.eq)
        sig = fnty.get_call_type(tyctx, (dtype, dtype), {})
        eqfn = context.get_function(fnty, sig)

        one = ir.Constant(intp_t, 1)
        five = ir.Constant(intp_t, 5)

        # The perturbation value for probing
        perturb = cgutils.alloca_once_value(builder, h)
        # The index of the entry being considered: start with (hash & mask)
        index = cgutils.alloca_once_value(builder,
                                          builder.and_(h, mask))
        if for_insert:
            # The index of the first deleted entry in the lookup chain
            free_index_sentinel = mask.type(-1)  # highest unsigned index
            free_index = cgutils.alloca_once_value(builder, free_index_sentinel)

        bb_body = builder.append_basic_block("lookup.body")
        bb_found = builder.append_basic_block("lookup.found")
        bb_not_found = builder.append_basic_block("lookup.not_found")
        bb_end = builder.append_basic_block("lookup.end")

        def check_entry(i):
            """
            Check entry *i* against the value being searched for.
            Branches to bb_found on a match and to bb_not_found on an EMPTY
            entry; otherwise falls through to the caller's next probe.
            """
            entry = self.get_entry(i)
            entry_hash = entry.hash

            with builder.if_then(builder.icmp_unsigned('==', h, entry_hash)):
                # Hashes are equal, compare values
                # (note this also ensures the entry is used)
                eq = eqfn(builder, (item, entry.key))
                with builder.if_then(eq):
                    builder.branch(bb_found)

            with builder.if_then(is_hash_empty(context, builder, entry_hash)):
                builder.branch(bb_not_found)

            if for_insert:
                # Memorize the index of the first deleted entry
                with builder.if_then(is_hash_deleted(context, builder, entry_hash)):
                    j = builder.load(free_index)
                    j = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel),
                                       i, j)
                    builder.store(j, free_index)

        # First linear probing.  When the number of collisions is small,
        # the linear probing loop achieves better cache locality and
        # is also slightly cheaper computationally.
        with cgutils.for_range(builder, ir.Constant(intp_t, LINEAR_PROBES)):
            i = builder.load(index)
            check_entry(i)
            i = builder.add(i, one)
            i = builder.and_(i, mask)
            builder.store(i, index)

        # If not found after linear probing, switch to a non-linear
        # perturbation keyed on the unmasked hash value.
        # XXX how to tell LLVM this branch is unlikely?
        builder.branch(bb_body)
        with builder.goto_block(bb_body):
            i = builder.load(index)
            check_entry(i)

            # Perturb to go to next entry:
            #   perturb >>= 5
            #   i = (i * 5 + 1 + perturb) & mask
            p = builder.load(perturb)
            p = builder.lshr(p, five)
            i = builder.add(one, builder.mul(i, five))
            i = builder.and_(mask, builder.add(i, p))
            builder.store(i, index)
            builder.store(p, perturb)
            # Loop
            builder.branch(bb_body)

        with builder.goto_block(bb_not_found):
            if for_insert:
                # Not found => for insertion, return the index of the first
                # deleted entry (if any), to avoid creating an infinite
                # lookup chain (issue #1913).
                i = builder.load(index)
                j = builder.load(free_index)
                i = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel),
                                   i, j)
                builder.store(i, index)
            builder.branch(bb_end)

        with builder.goto_block(bb_found):
            builder.branch(bb_end)

        builder.position_at_end(bb_end)

        # Merge the two exits into a single i1 "found" flag
        found = builder.phi(ir.IntType(1), 'found')
        found.add_incoming(cgutils.true_bit, bb_found)
        found.add_incoming(cgutils.false_bit, bb_not_found)

        return found, builder.load(index)

    @contextlib.contextmanager
    def _iterate(self, start=None):
        """
        Iterate over the payload's entries. Yield a SetLoop.

        The body of the ``with`` block is emitted inside a loop over all
        table indices (beginning at *start* when given) and only runs for
        used entries. Call ``loop.do_break()`` to leave the loop early.
        """
        context = self._context
        builder = self._builder

        intp_t = context.get_value_type(types.intp)
        one = ir.Constant(intp_t, 1)
        size = builder.add(self.mask, one)

        with cgutils.for_range(builder, size, start=start) as range_loop:
            entry = self.get_entry(range_loop.index)
            is_used = is_hash_used(context, builder, entry.hash)
            with builder.if_then(is_used):
                loop = SetLoop(index=range_loop.index, entry=entry,
                               do_break=range_loop.do_break)
                yield loop

    @contextlib.contextmanager
    def _next_entry(self):
        """
        Yield an arbitrary used entry from the payload.

        This is not random: it is the first used entry found scanning
        forward (with wrap-around) from the position just after the search
        finger. The finger is then moved to that entry. Caller must ensure
        the set isn't empty, otherwise the scan never ends.
        """
        context = self._context
        builder = self._builder

        intp_t = context.get_value_type(types.intp)
        zero = ir.Constant(intp_t, 0)
        one = ir.Constant(intp_t, 1)
        mask = self.mask

        # Start walking the entries from the stored "search finger" and
        # break as soon as we find a used entry.

        bb_body = builder.append_basic_block('next_entry_body')
        bb_end = builder.append_basic_block('next_entry_end')

        index = cgutils.alloca_once_value(builder, self.finger)
        builder.branch(bb_body)

        with builder.goto_block(bb_body):
            i = builder.load(index)
            # ANDing with mask ensures we stay inside the table boundaries
            i = builder.and_(mask, builder.add(i, one))
            builder.store(i, index)
            entry = self.get_entry(i)
            is_used = is_hash_used(context, builder, entry.hash)
            builder.cbranch(is_used, bb_end, bb_body)

        builder.position_at_end(bb_end)

        # Update the search finger with the next position. This avoids
        # O(n**2) behaviour when pop() is called in a loop.
        i = builder.load(index)
        self.finger = i
        yield self.get_entry(i)
class SetInstance(object):
    """
    Codegen-time wrapper around a set value.

    Gives access to the set's meminfo, parent object and payload, and emits
    LLVM IR for the set operations (insertion, removal, lookups, set
    algebra, resizing and payload (de)allocation).
    """

    def __init__(self, context, builder, set_type, set_val):
        self._context = context
        self._builder = builder
        self._ty = set_type
        self._entrysize = get_entry_size(context, set_type)
        self._set = context.make_helper(builder, set_type, set_val)

    @property
    def dtype(self):
        """The element type of the set."""
        return self._ty.dtype

    @property
    def payload(self):
        """
        The _SetPayload for this set.
        """
        # This cannot be cached as the pointer can move around!
        context = self._context
        builder = self._builder
        ptr = self._context.nrt.meminfo_data(builder, self.meminfo)
        return _SetPayload(context, builder, self._ty, ptr)

    @property
    def value(self):
        """The LLVM value of the set structure."""
        return self._set._getvalue()

    @property
    def meminfo(self):
        """The NRT meminfo whose data area holds the payload."""
        return self._set.meminfo

    @property
    def parent(self):
        """The parent Python object pointer stored in the set structure."""
        return self._set.parent

    @parent.setter
    def parent(self, value):
        self._set.parent = value

    def get_size(self):
        """
        Return the number of elements in the set.
        """
        return self.payload.used

    def set_dirty(self, val):
        """
        Set the payload's dirty flag to *val* (a Python bool).

        This is a no-op unless the set type is reflected.
        """
        if self._ty.reflected:
            self.payload.dirty = cgutils.true_bit if val else cgutils.false_bit

    def _add_entry(self, payload, entry, item, h, do_resize=True):
        """
        Store *item* (with hash *h*) into *entry*, a slot already known not
        to hold the item, and update the used/fill counters.
        """
        context = self._context
        builder = self._builder

        old_hash = entry.hash
        entry.hash = h
        self.incref_value(item)
        entry.key = item
        # used++
        used = payload.used
        one = ir.Constant(used.type, 1)
        used = payload.used = builder.add(used, one)
        # fill++ if entry wasn't a deleted one
        with builder.if_then(is_hash_empty(context, builder, old_hash),
                             likely=True):
            payload.fill = builder.add(payload.fill, one)
        # Grow (or rehash) table if necessary
        if do_resize:
            self.upsize(used)
        self.set_dirty(True)

    def _add_key(self, payload, item, h, do_resize=True, do_incref=True):
        """
        Insert *item* (with hash *h*) into *payload* unless it is already
        present.  *do_incref* is False when re-inserting items whose
        reference is already owned by the set (e.g. while resizing).
        """
        context = self._context
        builder = self._builder

        found, i = payload._lookup(item, h, for_insert=True)
        not_found = builder.not_(found)

        with builder.if_then(not_found):
            # Not found => add it
            entry = payload.get_entry(i)
            old_hash = entry.hash
            entry.hash = h
            if do_incref:
                self.incref_value(item)
            entry.key = item
            # used++
            used = payload.used
            one = ir.Constant(used.type, 1)
            used = payload.used = builder.add(used, one)
            # fill++ if entry wasn't a deleted one
            with builder.if_then(is_hash_empty(context, builder, old_hash),
                                 likely=True):
                payload.fill = builder.add(payload.fill, one)
            # Grow (or rehash) table if necessary
            if do_resize:
                self.upsize(used)
            self.set_dirty(True)

    def _remove_entry(self, payload, entry, do_resize=True, do_decref=True):
        """
        Mark *entry* as DELETED and decrement ``used``.  ``fill`` is left
        unchanged because the slot is not EMPTY.
        """
        # Mark entry deleted
        entry.hash = ir.Constant(entry.hash.type, DELETED)
        if do_decref:
            self.decref_value(entry.key)
        # used--
        used = payload.used
        one = ir.Constant(used.type, 1)
        used = payload.used = self._builder.sub(used, one)
        # Shrink table if necessary
        if do_resize:
            self.downsize(used)
        self.set_dirty(True)

    def _remove_key(self, payload, item, h, do_resize=True):
        """
        Remove *item* (with hash *h*) if present; return an i1 telling
        whether it was found.
        """
        builder = self._builder

        found, i = payload._lookup(item, h)

        with builder.if_then(found):
            entry = payload.get_entry(i)
            self._remove_entry(payload, entry, do_resize)

        return found

    def add(self, item, do_resize=True):
        """Emit ``set.add(item)``."""
        context = self._context
        builder = self._builder

        payload = self.payload
        h = get_hash_value(context, builder, self._ty.dtype, item)
        self._add_key(payload, item, h, do_resize)

    def add_pyapi(self, pyapi, item, do_resize=True):
        """A version of .add for use inside functions following Python calling
        convention.
        """
        context = self._context
        builder = self._builder

        payload = self.payload
        h = self._pyapi_get_hash_value(pyapi, context, builder, item)
        self._add_key(payload, item, h, do_resize)

    def _pyapi_get_hash_value(self, pyapi, context, builder, item):
        """Python API compatible version of `get_hash_value()`.

        On error, the enclosing function returns a NULL object.
        """
        argtypes = [self._ty.dtype]
        resty = types.intp

        def wrapper(val):
            return _get_hash_value_intrinsic(val)

        args = [item]
        sig = typing.signature(resty, *argtypes)
        is_error, retval = pyapi.call_jit_code(wrapper, sig, args)
        # Handle return status
        with builder.if_then(is_error, likely=False):
            # Raise nopython exception as a Python exception
            builder.ret(pyapi.get_null_object())
        return retval

    def contains(self, item):
        """Emit ``item in set``; return an i1."""
        context = self._context
        builder = self._builder

        payload = self.payload
        h = get_hash_value(context, builder, self._ty.dtype, item)
        found, _ = payload._lookup(item, h)
        return found

    def discard(self, item):
        """Emit ``set.discard(item)``; return an i1 telling if it was found."""
        context = self._context
        builder = self._builder

        payload = self.payload
        h = get_hash_value(context, builder, self._ty.dtype, item)
        found = self._remove_key(payload, item, h)
        return found

    def pop(self):
        """
        Remove and return an arbitrary element.  The caller must ensure the
        set is not empty (see _SetPayload._next_entry()).
        """
        context = self._context
        builder = self._builder

        lty = context.get_value_type(self._ty.dtype)
        key = cgutils.alloca_once(builder, lty)

        payload = self.payload
        with payload._next_entry() as entry:
            builder.store(entry.key, key)
            # since the value is returned don't decref in _remove_entry()
            self._remove_entry(payload, entry, do_decref=False)

        return builder.load(key)

    def clear(self):
        """Emit ``set.clear()``: replace the payload with an empty one."""
        context = self._context
        builder = self._builder

        intp_t = context.get_value_type(types.intp)
        minsize = ir.Constant(intp_t, MINSIZE)
        self._replace_payload(minsize)
        self.set_dirty(True)

    def copy(self):
        """
        Return a copy of this set.
        """
        context = self._context
        builder = self._builder

        payload = self.payload
        used = payload.used
        fill = payload.fill

        other = type(self)(context, builder, self._ty, None)

        no_deleted_entries = builder.icmp_unsigned('==', used, fill)
        with builder.if_else(
            no_deleted_entries, likely=True
        ) as (if_no_deleted, if_deleted):
            with if_no_deleted:
                # No deleted entries => raw copy the payload
                ok = other._copy_payload(payload)
                with builder.if_then(builder.not_(ok), likely=False):
                    context.call_conv.return_user_exc(builder, MemoryError,
                                                      ("cannot copy set",))

            with if_deleted:
                # Deleted entries => re-insert entries one by one
                nentries = self.choose_alloc_size(context, builder, used)
                ok = other._allocate_payload(nentries)
                with builder.if_then(builder.not_(ok), likely=False):
                    context.call_conv.return_user_exc(builder, MemoryError,
                                                      ("cannot copy set",))

                other_payload = other.payload
                with payload._iterate() as loop:
                    entry = loop.entry
                    other._add_key(other_payload, entry.key, entry.hash,
                                   do_resize=False)

        return other

    def intersect(self, other):
        """
        In-place intersection with *other* set.
        """
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        with payload._iterate() as loop:
            entry = loop.entry
            found, _ = other_payload._lookup(entry.key, entry.hash)
            with builder.if_then(builder.not_(found)):
                self._remove_entry(payload, entry, do_resize=False)

        # Final downsize
        self.downsize(payload.used)

    def difference(self, other):
        """
        In-place difference with *other* set.
        """
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        with other_payload._iterate() as loop:
            entry = loop.entry
            self._remove_key(payload, entry.key, entry.hash, do_resize=False)

        # Final downsize
        self.downsize(payload.used)

    def symmetric_difference(self, other):
        """
        In-place symmetric difference with *other* set.
        """
        context = self._context
        builder = self._builder
        other_payload = other.payload

        with other_payload._iterate() as loop:
            key = loop.entry.key
            h = loop.entry.hash
            # We must reload our payload as it may be resized during the loop
            payload = self.payload
            found, i = payload._lookup(key, h, for_insert=True)
            entry = payload.get_entry(i)
            with builder.if_else(found) as (if_common, if_not_common):
                with if_common:
                    self._remove_entry(payload, entry, do_resize=False)
                with if_not_common:
                    self._add_entry(payload, entry, key, h)

        # Final downsize
        self.downsize(self.payload.used)

    def issubset(self, other, strict=False):
        """
        Return an i1 telling whether self <= other (self < other when
        *strict* is true).
        """
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        cmp_op = '<' if strict else '<='

        res = cgutils.alloca_once_value(builder, cgutils.true_bit)
        with builder.if_else(
            builder.icmp_unsigned(cmp_op, payload.used, other_payload.used)
        ) as (if_smaller, if_larger):
            with if_larger:
                # self larger than other => self cannot possibly be a subset
                builder.store(cgutils.false_bit, res)
            with if_smaller:
                # check whether each key of self is in other
                with payload._iterate() as loop:
                    entry = loop.entry
                    found, _ = other_payload._lookup(entry.key, entry.hash)
                    with builder.if_then(builder.not_(found)):
                        builder.store(cgutils.false_bit, res)
                        loop.do_break()

        return builder.load(res)

    def isdisjoint(self, other):
        """Return an i1 telling whether self and *other* share no element."""
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        res = cgutils.alloca_once_value(builder, cgutils.true_bit)

        def check(smaller, larger):
            # Loop over the smaller of the two, and search in the larger
            with smaller._iterate() as loop:
                entry = loop.entry
                found, _ = larger._lookup(entry.key, entry.hash)
                with builder.if_then(found):
                    builder.store(cgutils.false_bit, res)
                    loop.do_break()

        with builder.if_else(
            builder.icmp_unsigned('>', payload.used, other_payload.used)
        ) as (if_larger, otherwise):

            with if_larger:
                # len(self) > len(other)
                check(other_payload, payload)

            with otherwise:
                # len(self) <= len(other)
                check(payload, other_payload)

        return builder.load(res)

    def equals(self, other):
        """Return an i1 telling whether self and *other* are equal."""
        context = self._context
        builder = self._builder
        payload = self.payload
        other_payload = other.payload

        res = cgutils.alloca_once_value(builder, cgutils.true_bit)
        with builder.if_else(
            builder.icmp_unsigned('==', payload.used, other_payload.used)
        ) as (if_same_size, otherwise):
            with if_same_size:
                # same sizes => check whether each key of self is in other
                with payload._iterate() as loop:
                    entry = loop.entry
                    found, _ = other_payload._lookup(entry.key, entry.hash)
                    with builder.if_then(builder.not_(found)):
                        builder.store(cgutils.false_bit, res)
                        loop.do_break()
            with otherwise:
                # different sizes => cannot possibly be equal
                builder.store(cgutils.false_bit, res)

        return builder.load(res)

    @classmethod
    def allocate_ex(cls, context, builder, set_type, nitems=None):
        """
        Allocate a SetInstance with its storage.
        Return a (ok, instance) tuple where *ok* is a LLVM boolean and
        *instance* is a SetInstance object (the object's contents are
        only valid when *ok* is true).
        """
        intp_t = context.get_value_type(types.intp)

        if nitems is None:
            nentries = ir.Constant(intp_t, MINSIZE)
        else:
            if isinstance(nitems, int):
                nitems = ir.Constant(intp_t, nitems)
            nentries = cls.choose_alloc_size(context, builder, nitems)

        self = cls(context, builder, set_type, None)
        ok = self._allocate_payload(nentries)
        return ok, self

    @classmethod
    def allocate(cls, context, builder, set_type, nitems=None):
        """
        Allocate a SetInstance with its storage.  Same as allocate_ex(),
        but return an initialized *instance*.  If allocation failed,
        control is transferred to the caller using the target's current
        call convention.
        """
        ok, self = cls.allocate_ex(context, builder, set_type, nitems)
        with builder.if_then(builder.not_(ok), likely=False):
            context.call_conv.return_user_exc(builder, MemoryError,
                                              ("cannot allocate set",))
        return self

    @classmethod
    def from_meminfo(cls, context, builder, set_type, meminfo):
        """
        Allocate a new set instance pointing to an existing payload
        (a meminfo pointer).
        Note the parent field has to be filled by the caller.
        """
        self = cls(context, builder, set_type, None)
        self._set.meminfo = meminfo
        self._set.parent = context.get_constant_null(types.pyobject)
        context.nrt.incref(builder, set_type, self.value)
        # Payload is part of the meminfo, no need to touch it
        return self

    @classmethod
    def choose_alloc_size(cls, context, builder, nitems):
        """
        Choose a suitable number of entries for the given number of items:
        the smallest power of 2, starting from MINSIZE, that is
        >= 2 * nitems.
        """
        intp_t = nitems.type
        one = ir.Constant(intp_t, 1)
        minsize = ir.Constant(intp_t, MINSIZE)

        # Ensure number of entries >= 2 * used
        min_entries = builder.shl(nitems, one)
        # Find out first suitable power of 2, starting from MINSIZE
        size_p = cgutils.alloca_once_value(builder, minsize)

        bb_body = builder.append_basic_block("calcsize.body")
        bb_end = builder.append_basic_block("calcsize.end")

        builder.branch(bb_body)

        with builder.goto_block(bb_body):
            size = builder.load(size_p)
            is_large_enough = builder.icmp_unsigned('>=', size, min_entries)
            with builder.if_then(is_large_enough, likely=False):
                builder.branch(bb_end)
            next_size = builder.shl(size, one)
            builder.store(next_size, size_p)
            builder.branch(bb_body)

        builder.position_at_end(bb_end)
        return builder.load(size_p)

    def upsize(self, nitems):
        """
        When adding to the set, ensure it is properly sized for the given
        number of used entries.

        The table is rebuilt when either:
        - ``2 * nitems >= size``: it grows by successive factors of 4
          until ``2 * nitems < new_size``; or
        - ``4 * fill >= 3 * size``: too many slots are taken by active or
          DELETED entries.  Rebuilding drops the DELETED entries; if
          *nitems* alone does not require growth, the table is rebuilt at
          its current size.

        The second condition guarantees that some EMPTY slots always
        remain.  _SetPayload._lookup() needs one to terminate on a miss.
        Without it, alternating add()/discard() of distinct keys can turn
        every slot DELETED while ``used`` stays small, and the next lookup
        of a missing key probes forever.
        """
        context = self._context
        builder = self._builder
        intp_t = nitems.type

        one = ir.Constant(intp_t, 1)
        two = ir.Constant(intp_t, 2)
        three = ir.Constant(intp_t, 3)

        payload = self.payload

        # Ensure number of entries >= 2 * used
        min_entries = builder.shl(nitems, one)
        size = builder.add(payload.mask, one)
        need_grow = builder.icmp_unsigned('>=', min_entries, size)
        # Ensure active + deleted entries stay below 3/4 of the table
        need_rehash = builder.icmp_unsigned('>=',
                                            builder.shl(payload.fill, two),
                                            builder.mul(size, three))
        need_resize = builder.or_(need_grow, need_rehash)

        with builder.if_then(need_resize, likely=False):
            # Find out next suitable size (possibly the current one, when
            # only deleted entries need to be purged)
            new_size_p = cgutils.alloca_once_value(builder, size)

            bb_cond = builder.append_basic_block("calcsize.cond")
            bb_body = builder.append_basic_block("calcsize.body")
            bb_end = builder.append_basic_block("calcsize.end")

            builder.branch(bb_cond)

            with builder.goto_block(bb_cond):
                new_size = builder.load(new_size_p)
                is_too_small = builder.icmp_unsigned('>=', min_entries, new_size)
                builder.cbranch(is_too_small, bb_body, bb_end)

            with builder.goto_block(bb_body):
                # Multiply by 4 (ensuring size remains a power of two)
                new_size = builder.load(new_size_p)
                new_size = builder.shl(new_size, two)
                builder.store(new_size, new_size_p)
                builder.branch(bb_cond)

            builder.position_at_end(bb_end)

            new_size = builder.load(new_size_p)
            if DEBUG_ALLOCS:
                context.printf(builder,
                               "upsize to %zd items: current size = %zd, "
                               "min entries = %zd, new size = %zd\n",
                               nitems, size, min_entries, new_size)
            self._resize(payload, new_size, "cannot grow set")

    def downsize(self, nitems):
        """
        When removing from the set, ensure it is properly sized for the given
        number of used entries.
        """
        context = self._context
        builder = self._builder
        intp_t = nitems.type

        one = ir.Constant(intp_t, 1)
        two = ir.Constant(intp_t, 2)
        minsize = ir.Constant(intp_t, MINSIZE)

        payload = self.payload

        # Ensure entries >= max(2 * used, MINSIZE)
        min_entries = builder.shl(nitems, one)
        min_entries = builder.select(builder.icmp_unsigned('>=', min_entries, minsize),
                                     min_entries, minsize)
        # Shrink only if size >= 4 * min_entries && size > MINSIZE
        max_size = builder.shl(min_entries, two)
        size = builder.add(payload.mask, one)
        need_resize = builder.and_(
            builder.icmp_unsigned('<=', max_size, size),
            builder.icmp_unsigned('<', minsize, size))

        with builder.if_then(need_resize, likely=False):
            # Find out next suitable size
            new_size_p = cgutils.alloca_once_value(builder, size)

            bb_body = builder.append_basic_block("calcsize.body")
            bb_end = builder.append_basic_block("calcsize.end")

            builder.branch(bb_body)

            with builder.goto_block(bb_body):
                # Divide by 2 (ensuring size remains a power of two)
                new_size = builder.load(new_size_p)
                new_size = builder.lshr(new_size, one)
                # Keep current size if new size would be < min_entries
                is_too_small = builder.icmp_unsigned('>', min_entries, new_size)
                with builder.if_then(is_too_small):
                    builder.branch(bb_end)
                builder.store(new_size, new_size_p)
                builder.branch(bb_body)

            builder.position_at_end(bb_end)

            # Ensure new_size >= MINSIZE
            new_size = builder.load(new_size_p)
            # At this point, new_size should be < size if the factors
            # above were chosen carefully!

            if DEBUG_ALLOCS:
                context.printf(builder,
                               "downsize to %zd items: current size = %zd, "
                               "min entries = %zd, new size = %zd\n",
                               nitems, size, min_entries, new_size)
            self._resize(payload, new_size, "cannot shrink set")

    def _resize(self, payload, nentries, errmsg):
        """
        Resize the payload to the given number of entries.  Active entries
        are re-inserted; DELETED entries are dropped.

        CAUTION: *nentries* must be a power of 2!
        """
        context = self._context
        builder = self._builder

        # Allocate new entries
        old_payload = payload

        ok = self._allocate_payload(nentries, realloc=True)
        with builder.if_then(builder.not_(ok), likely=False):
            context.call_conv.return_user_exc(builder, MemoryError,
                                              (errmsg,))

        # Re-insert old entries
        # No incref since they already were the first time they were inserted
        payload = self.payload
        with old_payload._iterate() as loop:
            entry = loop.entry
            self._add_key(payload, entry.key, entry.hash,
                          do_resize=False, do_incref=False)

        self._free_payload(old_payload.ptr)

    def _replace_payload(self, nentries):
        """
        Replace the payload with a new empty payload with the given number
        of entries.

        CAUTION: *nentries* must be a power of 2!
        """
        context = self._context
        builder = self._builder

        # decref all of the previous entries
        with self.payload._iterate() as loop:
            entry = loop.entry
            self.decref_value(entry.key)

        # Free old payload
        self._free_payload(self.payload.ptr)

        ok = self._allocate_payload(nentries, realloc=True)
        with builder.if_then(builder.not_(ok), likely=False):
            context.call_conv.return_user_exc(builder, MemoryError,
                                              ("cannot reallocate set",))

    def _allocate_payload(self, nentries, realloc=False):
        """
        Allocate and initialize payload for the given number of entries.
        If *realloc* is True, the existing meminfo is reused.
        Return an i1 that is false on overflow or allocation failure.

        CAUTION: *nentries* must be a power of 2!
        """
        context = self._context
        builder = self._builder

        ok = cgutils.alloca_once_value(builder, cgutils.true_bit)

        intp_t = context.get_value_type(types.intp)
        zero = ir.Constant(intp_t, 0)
        one = ir.Constant(intp_t, 1)

        payload_type = context.get_data_type(types.SetPayload(self._ty))
        payload_size = context.get_abi_sizeof(payload_type)
        entry_size = self._entrysize
        # Account for the fact that the payload struct already contains an entry
        payload_size -= entry_size

        # Total allocation size = <payload header size> + nentries * entry_size
        allocsize, ovf = cgutils.muladd_with_overflow(builder, nentries,
                                                      ir.Constant(intp_t, entry_size),
                                                      ir.Constant(intp_t, payload_size))
        with builder.if_then(ovf, likely=False):
            builder.store(cgutils.false_bit, ok)

        with builder.if_then(builder.load(ok), likely=True):
            if realloc:
                meminfo = self._set.meminfo
                ptr = context.nrt.meminfo_varsize_alloc_unchecked(builder,
                                                                  meminfo,
                                                                  size=allocsize)
                # A null data pointer means the allocation failed
                alloc_failed = cgutils.is_null(builder, ptr)
            else:
                # create destructor to be called upon set destruction
                dtor = self._imp_dtor(context, builder.module)
                meminfo = context.nrt.meminfo_new_varsize_dtor_unchecked(
                    builder, allocsize, builder.bitcast(dtor, cgutils.voidptr_t))
                # A null meminfo means the allocation failed
                alloc_failed = cgutils.is_null(builder, meminfo)

            with builder.if_else(alloc_failed,
                                 likely=False) as (if_error, if_ok):
                with if_error:
                    builder.store(cgutils.false_bit, ok)
                with if_ok:
                    if not realloc:
                        self._set.meminfo = meminfo
                        self._set.parent = context.get_constant_null(types.pyobject)
                    payload = self.payload
                    # Initialize entries to 0xff (EMPTY)
                    cgutils.memset(builder, payload.ptr, allocsize, 0xFF)
                    payload.used = zero
                    payload.fill = zero
                    payload.finger = zero
                    new_mask = builder.sub(nentries, one)
                    payload.mask = new_mask

                    if DEBUG_ALLOCS:
                        context.printf(builder,
                                       "allocated %zd bytes for set at %p: mask = %zd\n",
                                       allocsize, payload.ptr, new_mask)

        return builder.load(ok)

    def _free_payload(self, ptr):
        """
        Free an allocated old payload at *ptr*.
        """
        self._context.nrt.meminfo_varsize_free(self._builder, self.meminfo, ptr)

    def _copy_payload(self, src_payload):
        """
        Raw-copy the given payload into self (allocating a new meminfo).
        Return an i1 that is false on allocation failure.
        """
        context = self._context
        builder = self._builder

        ok = cgutils.alloca_once_value(builder, cgutils.true_bit)

        intp_t = context.get_value_type(types.intp)
        zero = ir.Constant(intp_t, 0)
        one = ir.Constant(intp_t, 1)

        payload_type = context.get_data_type(types.SetPayload(self._ty))
        payload_size = context.get_abi_sizeof(payload_type)
        entry_size = self._entrysize
        # Account for the fact that the payload struct already contains an entry
        payload_size -= entry_size

        mask = src_payload.mask
        nentries = builder.add(one, mask)

        # Total allocation size = <payload header size> + nentries * entry_size
        # (note there can't be any overflow since we're reusing an existing
        #  payload's parameters)
        allocsize = builder.add(ir.Constant(intp_t, payload_size),
                                builder.mul(ir.Constant(intp_t, entry_size),
                                            nentries))

        with builder.if_then(builder.load(ok), likely=True):
            # create destructor for new meminfo
            dtor = self._imp_dtor(context, builder.module)
            meminfo = context.nrt.meminfo_new_varsize_dtor_unchecked(
                builder, allocsize, builder.bitcast(dtor, cgutils.voidptr_t))
            # A null meminfo means the allocation failed
            alloc_failed = cgutils.is_null(builder, meminfo)

            with builder.if_else(alloc_failed, likely=False) as (if_error, if_ok):
                with if_error:
                    builder.store(cgutils.false_bit, ok)
                with if_ok:
                    self._set.meminfo = meminfo
                    payload = self.payload
                    payload.used = src_payload.used
                    payload.fill = src_payload.fill
                    payload.finger = zero
                    payload.mask = mask

                    # instead of using `_add_key` for every entry, since the
                    # size of the new set is the same, we can just copy the
                    # data directly without having to re-compute the hash
                    cgutils.raw_memcpy(builder, payload.entries,
                                       src_payload.entries, nentries,
                                       entry_size)
                    # increment the refcounts to simulate `_add_key` for each
                    # element
                    with src_payload._iterate() as loop:
                        self.incref_value(loop.entry.key)

                    if DEBUG_ALLOCS:
                        context.printf(builder,
                                       "allocated %zd bytes for set at %p: mask = %zd\n",
                                       allocsize, payload.ptr, mask)

        return builder.load(ok)

    def _imp_dtor(self, context, module):
        """Define the dtor for set

        The generated function decrefs the key of every used entry.  It is
        emitted once per element type (linkonce_odr) and reused afterwards.
        """
        llvoidptr = cgutils.voidptr_t
        llsize_t = context.get_value_type(types.size_t)
        # create a dtor function that takes (void* set, size_t size, void* dtor_info)
        fnty = ir.FunctionType(
            ir.VoidType(),
            [llvoidptr, llsize_t, llvoidptr],
        )
        # create type-specific name
        fname = f".dtor.set.{self._ty.dtype}"

        fn = cgutils.get_or_insert_function(module, fnty, name=fname)

        if fn.is_declaration:
            # Set linkage
            fn.linkage = 'linkonce_odr'
            # Define
            builder = ir.IRBuilder(fn.append_basic_block())
            payload = _SetPayload(context, builder, self._ty, fn.args[0])
            with payload._iterate() as loop:
                entry = loop.entry
                context.nrt.decref(builder, self._ty.dtype, entry.key)
            builder.ret_void()

        return fn

    def incref_value(self, val):
        """Incref an element value
        """
        self._context.nrt.incref(self._builder, self._ty.dtype, val)

    def decref_value(self, val):
        """Decref an element value
        """
        self._context.nrt.decref(self._builder, self._ty.dtype, val)
class SetIterInstance(object):
    """
    Codegen-time wrapper around a set iterator.

    The iterator stores the set's meminfo and a pointer to the table index
    where the next scan starts.
    """

    def __init__(self, context, builder, iter_type, iter_val):
        self._context = context
        self._builder = builder
        self._ty = iter_type
        self._iter = context.make_helper(builder, iter_type, iter_val)
        data_ptr = context.nrt.meminfo_data(builder, self.meminfo)
        self._payload = _SetPayload(context, builder, iter_type.container,
                                    data_ptr)

    @classmethod
    def from_set(cls, context, builder, iter_type, set_val):
        """Build a new iterator positioned at the start of *set_val*."""
        source = SetInstance(context, builder, iter_type.container, set_val)
        inst = cls(context, builder, iter_type, None)
        start = context.get_constant(types.intp, 0)
        inst._iter.index = cgutils.alloca_once_value(builder, start)
        inst._iter.meminfo = source.meminfo
        return inst

    @property
    def value(self):
        """The LLVM value of the iterator structure."""
        return self._iter._getvalue()

    @property
    def meminfo(self):
        """The meminfo of the set being iterated."""
        return self._iter.meminfo

    @property
    def index(self):
        """Current scan position (loaded from the iterator's index slot)."""
        return self._builder.load(self._iter.index)

    @index.setter
    def index(self, value):
        self._builder.store(value, self._iter.index)

    def iternext(self, result):
        """
        Emit one iteration step: yield the first used entry at or after the
        current index, or flag the iterator as exhausted.
        """
        start = self.index
        step = ir.Constant(start.type, 1)

        # Assume exhaustion; the loop body overrides this on a hit.
        result.set_exhausted()
        with self._payload._iterate(start=start) as loop:
            result.set_valid()
            result.yield_(loop.entry.key)
            self.index = self._builder.add(loop.index, step)
            loop.do_break()
#-------------------------------------------------------------------------------
# Constructors
def build_set(context, builder, set_type, items):
    """
    Emit code building a new set of *set_type* holding *items*.

    *items* is a Python sequence of LLVM values.  Returns a new reference to
    the constructed set.
    """
    count = len(items)
    new_set = SetInstance.allocate(context, builder, set_type, count)
    if count:
        # Emitting a separate insertion sequence per item would bloat the
        # IR, so the items are packed into a stack array and added in a loop.
        packed = cgutils.pack_array(builder, items)
        stored = cgutils.alloca_once_value(builder, packed)
        trip_count = context.get_constant(types.intp, count)
        with cgutils.for_range(builder, trip_count) as loop:
            elem_ptr = cgutils.gep(builder, stored, 0, loop.index)
            new_set.add(builder.load(elem_ptr))
    return impl_ret_new_ref(context, builder, set_type, new_set.value)
@lower_builtin(set)
def set_empty_constructor(context, builder, sig, args):
    """Lower ``set()``: allocate a new, empty set of the return type."""
    result_type = sig.return_type
    empty = SetInstance.allocate(context, builder, result_type)
    return impl_ret_new_ref(context, builder, result_type, empty.value)
@lower_builtin(set, types.IterableType)
def set_constructor(context, builder, sig, args):
    """
    Lower ``set(iterable)``: build a new set from the items of *iterable*.

    Returns a new reference to the set.
    """
    (iterable_type,) = sig.args
    (iterable,) = args
    result_type = sig.return_type

    # When the iterable has a len(), use it to preallocate and avoid resizes.
    length = call_len(context, builder, iterable_type, iterable)
    new_set = SetInstance.allocate(context, builder, result_type, length)
    with for_iter(context, builder, iterable_type, iterable) as loop:
        new_set.add(loop.value)
        # `for_iter` increfs every item it yields, while `.add` increfs only
        # items not already present; decref unconditionally to balance the
        # former (dropping `.add`'s incref would not free duplicates).
        context.nrt.decref(builder, result_type.dtype, loop.value)
    return impl_ret_new_ref(context, builder, result_type, new_set.value)
#-------------------------------------------------------------------------------
# Various operations
@lower_builtin(len, types.Set)
def set_len(context, builder, sig, args):
    """Lower ``len(set)``."""
    return SetInstance(context, builder, sig.args[0], args[0]).get_size()
@lower_builtin(operator.contains, types.Set, types.Any)
def in_set(context, builder, sig, args):
    """Lower ``item in set``; returns the membership test result."""
    haystack = SetInstance(context, builder, sig.args[0], args[0])
    needle = args[1]
    return haystack.contains(needle)
@lower_builtin('getiter', types.Set)
def getiter_set(context, builder, sig, args):
    """Lower ``iter(set)``: create a set iterator starting at index 0."""
    iter_type = sig.return_type
    iterator = SetIterInstance.from_set(context, builder, iter_type, args[0])
    return impl_ret_borrowed(context, builder, iter_type, iterator.value)
@lower_builtin('iternext', types.SetIter)
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
    """Lower ``next()`` on a set iterator; yielded keys are borrowed."""
    SetIterInstance(context, builder, sig.args[0], args[0]).iternext(result)
#-------------------------------------------------------------------------------
# Methods
# One-item-at-a-time operations
@lower_builtin("set.add", types.Set, types.Any)
def set_add(context, builder, sig, args):
    """Lower ``set.add(item)``; evaluates to None."""
    target = SetInstance(context, builder, sig.args[0], args[0])
    target.add(args[1])
    return context.get_dummy_value()
@intrinsic
def _set_discard(typingctx, s, item):
    """Intrinsic for ``set.discard(item)``: remove *item* if present."""
    def codegen(context, builder, sig, args):
        target = SetInstance(context, builder, sig.args[0], args[0])
        target.discard(args[1])
        return context.get_dummy_value()

    return types.none(s, item), codegen


@overload_method(types.Set, "discard")
def ol_set_discard(s, item):
    """Route ``set.discard`` to the intrinsic for any item type."""
    def impl(s, item):
        return _set_discard(s, item)

    return impl
@intrinsic
def _set_pop(typingctx, s):
    """Intrinsic for ``set.pop()``; raises KeyError when the set is empty."""
    def codegen(context, builder, sig, args):
        target = SetInstance(context, builder, sig.args[0], args[0])
        is_empty = cgutils.is_null(builder, target.payload.used)
        with builder.if_then(is_empty, likely=False):
            context.call_conv.return_user_exc(builder, KeyError,
                                              ("set.pop(): empty set",))
        return target.pop()

    return s.dtype(s), codegen


@overload_method(types.Set, "pop")
def ol_set_pop(s):
    """Route ``set.pop`` to the intrinsic."""
    def impl(s):
        return _set_pop(s)

    return impl
@intrinsic
def _set_remove(typingctx, s, item):
    """Intrinsic for ``set.remove(item)``; raises KeyError if absent."""
    def codegen(context, builder, sig, args):
        target = SetInstance(context, builder, sig.args[0], args[0])
        was_present = target.discard(args[1])
        missing = builder.not_(was_present)
        with builder.if_then(missing, likely=False):
            context.call_conv.return_user_exc(
                builder, KeyError, ("set.remove(): key not in set",))
        return context.get_dummy_value()

    return types.none(s, item), codegen


@overload_method(types.Set, "remove")
def ol_set_remove(s, item):
    """Route ``set.remove`` to the intrinsic.

    Only an item whose type equals the set's dtype gets an implementation;
    any other item type yields no overload (None).
    """
    if not s.dtype == item:
        return None

    def impl(s, item):
        return _set_remove(s, item)

    return impl
# Mutating set operations
@intrinsic
def _set_clear(typingctx, s):
    """Intrinsic for ``set.clear()``: remove every element."""
    def codegen(context, builder, sig, args):
        SetInstance(context, builder, sig.args[0], args[0]).clear()
        return context.get_dummy_value()

    return types.none(s), codegen


@overload_method(types.Set, "clear")
def ol_set_clear(s):
    """Route ``set.clear`` to the intrinsic."""
    def impl(s):
        return _set_clear(s)

    return impl
@intrinsic
def _set_copy(typingctx, s):
    """Intrinsic for ``set.copy()``: returns a new set of the same type."""
    def codegen(context, builder, sig, args):
        source = SetInstance(context, builder, sig.args[0], args[0])
        duplicate = source.copy()
        return impl_ret_new_ref(context, builder, sig.return_type,
                                duplicate.value)

    return s(s), codegen


@overload_method(types.Set, "copy")
def ol_set_copy(s):
    """Route ``set.copy`` to the intrinsic."""
    def impl(s):
        return _set_copy(s)

    return impl
def set_difference_update(context, builder, sig, args):
    """Codegen: remove from the first set every element of the second."""
    target = SetInstance(context, builder, sig.args[0], args[0])
    target.difference(SetInstance(context, builder, sig.args[1], args[1]))
    return context.get_dummy_value()


@intrinsic
def _set_difference_update(typingctx, a, b):
    """Intrinsic wrapper around ``set_difference_update``."""
    return types.none(a, b), set_difference_update


@overload_method(types.Set, "difference_update")
def set_difference_update_impl(a, b):
    """Implement ``set.difference_update`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_difference_update(a, b)

    return impl
def set_intersection_update(context, builder, sig, args):
    """Codegen: keep in the first set only elements also in the second."""
    target = SetInstance(context, builder, sig.args[0], args[0])
    target.intersect(SetInstance(context, builder, sig.args[1], args[1]))
    return context.get_dummy_value()


@intrinsic
def _set_intersection_update(typingctx, a, b):
    """Intrinsic wrapper around ``set_intersection_update``."""
    return types.none(a, b), set_intersection_update


@overload_method(types.Set, "intersection_update")
def set_intersection_update_impl(a, b):
    """Implement ``set.intersection_update`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_intersection_update(a, b)

    return impl
def set_symmetric_difference_update(context, builder, sig, args):
    """Codegen: make the first set the symmetric difference of both sets."""
    target = SetInstance(context, builder, sig.args[0], args[0])
    target.symmetric_difference(
        SetInstance(context, builder, sig.args[1], args[1]))
    return context.get_dummy_value()


@intrinsic
def _set_symmetric_difference_update(typingctx, a, b):
    """Intrinsic wrapper around ``set_symmetric_difference_update``."""
    return types.none(a, b), set_symmetric_difference_update


@overload_method(types.Set, "symmetric_difference_update")
def set_symmetric_difference_update_impl(a, b):
    """Implement ``set.symmetric_difference_update`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_symmetric_difference_update(a, b)

    return impl
@lower_builtin("set.update", types.Set, types.IterableType)
def set_update(context, builder, sig, args):
    """
    Lower ``set.update(iterable)``: add every item of *iterable* to the set.

    Items are cast from the iterable's item type to the set's dtype before
    insertion.  Evaluates to None.
    """
    inst = SetInstance(context, builder, sig.args[0], args[0])
    items_type = sig.args[1]
    items = args[1]
    # ``for_iter`` yields values of the iterator's yield type.  Use that
    # rather than ``items_type.dtype``: not every IterableType defines
    # ``dtype`` (e.g. strings, generators), whereas ``yield_type`` always
    # exists and equals ``dtype`` wherever the latter is meaningful.
    item_type = items_type.iterator_type.yield_type

    # If the argument has a len(), assume there are few collisions and
    # presize to len(set) + len(items)
    n = call_len(context, builder, items_type, items)
    if n is not None:
        new_size = builder.add(inst.payload.used, n)
        inst.upsize(new_size)

    with for_iter(context, builder, items_type, items) as loop:
        # make sure that the items being added are of the same dtype as the
        # set instance
        casted = context.cast(builder, loop.value, item_type, inst.dtype)
        inst.add(casted)
        # decref each item to counter balance the incref from `for_iter`
        # `.add` will conditionally incref when the item does not already
        # exist in the set, therefore removing its incref is not enough to
        # guarantee all memory is freed
        context.nrt.decref(builder, item_type, loop.value)

    if n is not None:
        # If we pre-grew the set, downsize in case there were many collisions
        inst.downsize(inst.payload.used)

    return context.get_dummy_value()
def gen_operator_impl(op, impl):
    """
    Register an overload of the in-place operator *op* for two sets.

    *impl* is a codegen function with the ``(context, builder, sig, args)``
    signature that mutates its first set argument.  The generated operator
    returns that first argument as a borrowed reference.
    """
    @intrinsic
    def _set_operator_intr(typingctx, a, b):
        def codegen(context, builder, sig, args):
            # In-place operators must return the type of their left operand.
            assert sig.return_type == sig.args[0]
            impl(context, builder, sig, args)
            return impl_ret_borrowed(context, builder, sig.args[0], args[0])

        return a(a, b), codegen

    @overload(op)
    def _ol_set_operator(a, b):
        check_all_set(a, b)

        def operator_impl(a, b):
            return _set_operator_intr(a, b)

        return operator_impl


for op_, op_impl in (
    (operator.iand, set_intersection_update),
    (operator.ior, set_update),
    (operator.isub, set_difference_update),
    (operator.ixor, set_symmetric_difference_update),
):
    gen_operator_impl(op_, op_impl)
# Set operations creating a new set
@overload(operator.sub)
@overload_method(types.Set, "difference")
def impl_set_difference(a, b):
    """Implement ``a - b`` / ``a.difference(b)``, returning a new set."""
    check_all_set(a, b)

    def difference_impl(a, b):
        result = a.copy()
        result.difference_update(b)
        return result

    return difference_impl
@overload(operator.and_)
@overload_method(types.Set, "intersection")
def set_intersection(a, b):
    """Implement ``a & b`` / ``a.intersection(b)``, returning a new set."""
    check_all_set(a, b)

    def intersection_impl(a, b):
        # Start from a copy of the smaller set (b when sizes are equal).
        if len(a) >= len(b):
            result = b.copy()
            result.intersection_update(a)
            return result
        else:
            result = a.copy()
            result.intersection_update(b)
            return result

    return intersection_impl
@overload(operator.xor)
@overload_method(types.Set, "symmetric_difference")
def set_symmetric_difference(a, b):
    """Implement ``a ^ b`` / ``a.symmetric_difference(b)`` as a new set."""
    check_all_set(a, b)

    def symmetric_difference_impl(a, b):
        # Copy the larger set (b when sizes are equal), fold in the other.
        if len(a) <= len(b):
            result = b.copy()
            result.symmetric_difference_update(a)
            return result
        else:
            result = a.copy()
            result.symmetric_difference_update(b)
            return result

    return symmetric_difference_impl
@overload(operator.or_)
@overload_method(types.Set, "union")
def set_union(a, b):
    """Implement ``a | b`` / ``a.union(b)``, returning a new set."""
    check_all_set(a, b)

    def union_impl(a, b):
        # Copy the larger set (b when sizes are equal), add the other.
        if len(a) <= len(b):
            result = b.copy()
            result.update(a)
            return result
        else:
            result = a.copy()
            result.update(b)
            return result

    return union_impl
# Predicates
@intrinsic
def _set_isdisjoint(typingctx, a, b):
    """Intrinsic: true when the two sets share no element."""
    def codegen(context, builder, sig, args):
        lhs = SetInstance(context, builder, sig.args[0], args[0])
        rhs = SetInstance(context, builder, sig.args[1], args[1])
        return lhs.isdisjoint(rhs)

    return types.boolean(a, b), codegen


@overload_method(types.Set, "isdisjoint")
def set_isdisjoint(a, b):
    """Implement ``a.isdisjoint(b)`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_isdisjoint(a, b)

    return impl
@intrinsic
def _set_issubset(typingctx, a, b):
    """Intrinsic: true when every element of the first set is in the second."""
    def codegen(context, builder, sig, args):
        lhs = SetInstance(context, builder, sig.args[0], args[0])
        rhs = SetInstance(context, builder, sig.args[1], args[1])
        return lhs.issubset(rhs)

    return types.boolean(a, b), codegen


@overload(operator.le)
@overload_method(types.Set, "issubset")
def set_issubset(a, b):
    """Implement ``a <= b`` / ``a.issubset(b)`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_issubset(a, b)

    return impl
@overload(operator.ge)
@overload_method(types.Set, "issuperset")
def set_issuperset(a, b):
    """Implement ``a >= b`` / ``a.issuperset(b)`` as ``b.issubset(a)``."""
    check_all_set(a, b)

    def superset_impl(a, b):
        is_superset = b.issubset(a)
        return is_superset

    return superset_impl
@intrinsic
def _set_eq(typingctx, a, b):
    """Intrinsic: set equality via ``SetInstance.equals``."""
    def codegen(context, builder, sig, args):
        lhs = SetInstance(context, builder, sig.args[0], args[0])
        rhs = SetInstance(context, builder, sig.args[1], args[1])
        return lhs.equals(rhs)

    return types.boolean(a, b), codegen


@overload(operator.eq)
def set_eq(a, b):
    """Implement ``a == b`` for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_eq(a, b)

    return impl
@overload(operator.ne)
def set_ne(a, b):
    """Implement ``a != b`` as the negation of ``a == b``."""
    check_all_set(a, b)

    def ne_impl(a, b):
        equal = a == b
        return not equal

    return ne_impl
@intrinsic
def _set_lt(typingctx, a, b):
    """Intrinsic: true when the first set is a strict subset of the second."""
    def codegen(context, builder, sig, args):
        lhs = SetInstance(context, builder, sig.args[0], args[0])
        rhs = SetInstance(context, builder, sig.args[1], args[1])
        return lhs.issubset(rhs, strict=True)

    return types.boolean(a, b), codegen


@overload(operator.lt)
def set_lt(a, b):
    """Implement ``a < b`` (strict subset) for two compatible sets."""
    check_all_set(a, b)

    def impl(a, b):
        return _set_lt(a, b)

    return impl
@overload(operator.gt)
def set_gt(a, b):
    """Implement ``a > b`` (strict superset) by swapping operands of ``<``."""
    check_all_set(a, b)

    def gt_impl(a, b):
        is_strict_superset = b < a
        return is_strict_superset

    return gt_impl
@lower_builtin(operator.is_, types.Set, types.Set)
def set_is(context, builder, sig, args):
    """Lower ``a is b`` for sets: true iff both share the same meminfo."""
    lhs = SetInstance(context, builder, sig.args[0], args[0])
    rhs = SetInstance(context, builder, sig.args[1], args[1])
    lhs_addr, rhs_addr = (builder.ptrtoint(inst.meminfo, cgutils.intp_t)
                          for inst in (lhs, rhs))
    return builder.icmp_signed('==', lhs_addr, rhs_addr)
# -----------------------------------------------------------------------------
# Implicit casting
@lower_cast(types.Set, types.Set)
def set_to_set(context, builder, fromty, toty, val):
    """
    Cast between set types with the same dtype (e.g. non-reflected to
    reflected).  The value is returned unchanged.
    """
    assert fromty.dtype == toty.dtype
    return val