366 lines
10 KiB
Python
366 lines
10 KiB
Python
|
import numpy as np
|
||
|
|
||
|
from numba import uint64, uint32, uint16, uint8
|
||
|
from numba.core.extending import register_jitable
|
||
|
|
||
|
from numba.np.random._constants import (UINT32_MAX, UINT64_MAX,
|
||
|
UINT16_MAX, UINT8_MAX)
|
||
|
from numba.np.random.generator_core import next_uint32, next_uint64
|
||
|
|
||
|
# All following implementations are direct translations from:
|
||
|
# https://github.com/numpy/numpy/blob/7cfef93c77599bd387ecc6a15d186c5a46024dac/numpy/random/src/distributions/distributions.c
|
||
|
|
||
|
|
||
|
@register_jitable
def gen_mask(max):
    """
    Return the smallest all-ones bit mask that covers `max`.

    The highest set bit of `max` is smeared into every lower bit
    position, so the result has the form ``2**k - 1`` and is always
    greater than or equal to `max`. A `max` of zero yields zero.
    """
    mask = uint64(max)
    # Every step doubles the length of the run of ones below the top bit;
    # shifts of 1, 2, 4, 8, 16 and 32 together reach all 64 positions.
    mask |= mask >> 1
    mask |= mask >> 2
    mask |= mask >> 4
    mask |= mask >> 8
    mask |= mask >> 16
    mask |= mask >> 32
    return mask
|
||
|
|
||
|
|
||
|
@register_jitable
def buffered_bounded_bool(bitgen, off, rng, bcnt, buf):
    """
    Draw one random boolean out of a buffered 32 bit word.

    `buf` holds a word obtained from `next_uint32` and `bcnt` counts the
    bits of it that have not been handed out yet, so a single call to the
    BitGenerator serves 32 consecutive draws.

    Returns the tuple ``(value, bcnt, buf)`` carrying the updated buffer
    state. When `rng` is 0 nothing is drawn: `off` is returned as the
    value and `bcnt`/`buf` are passed through unchanged.
    """
    if (rng == 0):
        return off, bcnt, buf
    if not bcnt:
        # Buffer exhausted: refill it. Bit 0 is used right away, which
        # leaves 31 unused bits.
        buf = next_uint32(bitgen)
        bcnt = 31
    else:
        # Drop the bit that the previous draw consumed.
        buf >>= 1
        bcnt -= 1

    return ((buf & 1) != 0), bcnt, buf
|
||
|
|
||
|
|
||
|
@register_jitable
def buffered_uint8(bitgen, bcnt, buf):
    """
    Draw one unsigned 8 bit integer out of a buffered 32 bit word.

    `buf` holds a word obtained from `next_uint32` and `bcnt` counts the
    bytes of it that are still unused, so one call to the BitGenerator
    serves four consecutive draws.

    Returns the tuple ``(value, bcnt, buf)`` carrying the updated buffer
    state.
    """
    if not bcnt:
        # Buffer exhausted: refill it. The low byte is used right away,
        # which leaves three unused bytes.
        buf = next_uint32(bitgen)
        bcnt = 3
    else:
        # Drop the byte that the previous draw consumed.
        buf >>= 8
        bcnt -= 1

    # The cast to uint8 keeps only the low byte of the buffer.
    return uint8(buf), bcnt, buf
|
||
|
|
||
|
|
||
|
@register_jitable
def buffered_uint16(bitgen, bcnt, buf):
    """
    Draw one unsigned 16 bit integer out of a buffered 32 bit word.

    `buf` holds a word obtained from `next_uint32` and `bcnt` counts the
    16 bit halves of it that are still unused, so one call to the
    BitGenerator serves two consecutive draws.

    Returns the tuple ``(value, bcnt, buf)`` carrying the updated buffer
    state.
    """
    if not bcnt:
        # Buffer exhausted: refill it. The low half is used right away,
        # which leaves one unused half.
        buf = next_uint32(bitgen)
        bcnt = 1
    else:
        # Drop the half that the previous draw consumed.
        buf >>= 16
        bcnt -= 1

    # The cast to uint16 keeps only the low half of the buffer.
    return uint16(buf), bcnt, buf
|
||
|
|
||
|
|
||
|
# The following implementations use Lemire's algorithm:
|
||
|
# https://arxiv.org/abs/1805.10941
|
||
|
@register_jitable
def buffered_bounded_lemire_uint8(bitgen, rng, bcnt, buf):
    """
    Generates a random unsigned 8 bit integer bounded
    within a given interval using Lemire's rejection.

    The buffer acts as storage for a 32 bit integer
    drawn from the associated BitGenerator so that
    multiple integers of smaller bitsize can be generated
    from a single draw of the BitGenerator.

    The drawn value lies in the closed interval ``[0, rng]``. Returns
    the tuple ``(value, bcnt, buf)`` carrying the updated buffer state.
    """
    # Note: `rng` should not be 0xFF. When this happens `rng_excl` becomes
    # zero.
    rng_excl = uint8(rng) + uint8(1)

    assert (rng != 0xFF)

    # Generate a scaled random number: the high byte of the 16 bit product
    # is the candidate result, the low byte decides acceptance.
    n, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
    m = uint16(n * rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`, so the
        # modulo is only computed when a rejection is possible at all.
        threshold = ((uint8(UINT8_MAX) - rng) % rng_excl)

        while (leftover < threshold):
            n, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            m = uint16(n * rng_excl)
            leftover = m & 0xFF

    return m >> 8, bcnt, buf
|
||
|
|
||
|
|
||
|
@register_jitable
def buffered_bounded_lemire_uint16(bitgen, rng, bcnt, buf):
    """
    Generates a random unsigned 16 bit integer bounded
    within a given interval using Lemire's rejection.

    The buffer acts as storage for a 32 bit integer
    drawn from the associated BitGenerator so that
    multiple integers of smaller bitsize can be generated
    from a single draw of the BitGenerator.

    The drawn value lies in the closed interval ``[0, rng]``. Returns
    the tuple ``(value, bcnt, buf)`` carrying the updated buffer state.
    """
    # Note: `rng` should not be 0xFFFF. When this happens `rng_excl` becomes
    # zero.
    rng_excl = uint16(rng) + uint16(1)

    assert (rng != 0xFFFF)

    # Generate a scaled random number: the high half of the 32 bit product
    # is the candidate result, the low half decides acceptance.
    n, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
    m = uint32(n * rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFFFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`, so the
        # modulo is only computed when a rejection is possible at all.
        threshold = ((uint16(UINT16_MAX) - rng) % rng_excl)

        while (leftover < threshold):
            n, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            m = uint32(n * rng_excl)
            leftover = m & 0xFFFF

    return m >> 16, bcnt, buf
|
||
|
|
||
|
|
||
|
@register_jitable
def buffered_bounded_lemire_uint32(bitgen, rng):
    """
    Generates a random unsigned 32 bit integer bounded
    within a given interval using Lemire's rejection.

    The drawn value lies in the closed interval ``[0, rng]``. Unlike the
    8 and 16 bit variants no buffer state is taken or returned: every
    attempt consumes one full `next_uint32` draw.
    """
    # Note: `rng` should not be 0xFFFFFFFF. When this happens `rng_excl`
    # becomes zero.
    rng_excl = uint32(rng) + uint32(1)

    assert (rng != 0xFFFFFFFF)

    # Generate a scaled random number: the high word of the 64 bit product
    # is the candidate result, the low word decides acceptance.
    m = uint64(next_uint32(bitgen)) * uint64(rng_excl)

    # Rejection sampling to remove any bias
    leftover = m & 0xFFFFFFFF

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`, so the
        # modulo is only computed when a rejection is possible at all.
        threshold = (UINT32_MAX - rng) % rng_excl

        while (leftover < threshold):
            m = uint64(next_uint32(bitgen)) * uint64(rng_excl)
            leftover = m & 0xFFFFFFFF

    return (m >> 32)
|
||
|
|
||
|
|
||
|
@register_jitable
def bounded_lemire_uint64(bitgen, rng):
    """
    Generates a random unsigned 64 bit integer bounded
    within a given interval using Lemire's rejection.

    The drawn value lies in the closed interval ``[0, rng]``. The result
    is the upper 64 bits of the 128 bit product ``x * (rng + 1)``, which
    is assembled from 32 bit limbs.
    """
    # Note: `rng` should not be 0xFFFFFFFFFFFFFFFF. When this happens
    # `rng_excl` becomes zero.
    rng_excl = uint64(rng) + uint64(1)

    assert (rng != 0xFFFFFFFFFFFFFFFF)

    x = next_uint64(bitgen)

    # Rejection sampling to remove any bias. The uint64 multiplication
    # keeps only the lower 64 bits of the product, which is exactly the
    # part the acceptance test needs.
    leftover = uint64(x) * uint64(rng_excl)

    if (leftover < rng_excl):
        # `rng_excl` is a simple upper bound for `threshold`, so the
        # modulo is only computed when a rejection is possible at all.
        threshold = (UINT64_MAX - rng) % rng_excl

        while (leftover < threshold):
            x = next_uint64(bitgen)
            leftover = uint64(x) * uint64(rng_excl)

    # Upper 64 bits of x * rng_excl via schoolbook multiplication on the
    # 32 bit halves of both operands.
    x0 = x & uint64(0xFFFFFFFF)
    x1 = x >> 32
    rng_excl0 = rng_excl & uint64(0xFFFFFFFF)
    rng_excl1 = rng_excl >> 32
    w0 = x0 * rng_excl0
    t = x1 * rng_excl0 + (w0 >> 32)
    w1 = t & uint64(0xFFFFFFFF)
    w2 = t >> 32
    w1 += x0 * rng_excl1
    m1 = x1 * rng_excl1 + w2 + (w1 >> 32)

    return m1
|
||
|
|
||
|
|
||
|
@register_jitable
def random_bounded_uint64_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 64 bit integers
    bounded by given interval.

    Every element is ``low`` plus a draw from the closed interval
    ``[0, rng]``. The generator used depends on `rng`:

    * ``rng == 0``: no randomness is consumed, the array is filled
      with `low`.
    * ``rng <= 0xFFFFFFFF``: 32 bit draws are sufficient.
    * otherwise: 64 bit draws are used.

    In both widths the all-ones `rng` bypasses Lemire's method (which
    does not support it) and uses the raw draw instead.
    """
    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng <= 0xFFFFFFFF:
        if (rng == 0xFFFFFFFF):
            # Lemire32 doesn't support rng = 0xFFFFFFFF.
            for i in np.ndindex(size):
                out[i] = low + next_uint32(bitgen)
        else:
            for i in np.ndindex(size):
                out[i] = low + buffered_bounded_lemire_uint32(bitgen, rng)

    elif (rng == 0xFFFFFFFFFFFFFFFF):
        # Lemire64 doesn't support rng = 0xFFFFFFFFFFFFFFFF.
        for i in np.ndindex(size):
            out[i] = low + next_uint64(bitgen)
    else:
        for i in np.ndindex(size):
            out[i] = low + bounded_lemire_uint64(bitgen, rng)

    return out
|
||
|
|
||
|
|
||
|
@register_jitable
def random_bounded_uint32_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 32 bit integers
    bounded by given interval.

    Every element is ``low`` plus a draw from the closed interval
    ``[0, rng]``. When `rng` is 0 no randomness is consumed and the
    array is filled with `low`.
    """
    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFFFFFFFF:
        # Lemire32 doesn't support rng = 0xFFFFFFFF.
        for i in np.ndindex(size):
            out[i] = low + next_uint32(bitgen)
    else:
        for i in np.ndindex(size):
            out[i] = low + buffered_bounded_lemire_uint32(bitgen, rng)
    return out
|
||
|
|
||
|
|
||
|
@register_jitable
def random_bounded_uint16_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 16 bit integers
    bounded by given interval.

    Every element is ``low`` plus a draw from the closed interval
    ``[0, rng]``. When `rng` is 0 no randomness is consumed and the
    array is filled with `low`.
    """
    # Buffer state shared by all draws of this call; it starts empty so
    # the first draw fetches a fresh 32 bit word.
    buf = 0
    bcnt = 0

    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFFFF:
        # Lemire16 doesn't support rng = 0xFFFF.
        for i in np.ndindex(size):
            val, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            out[i] = low + val

    else:
        for i in np.ndindex(size):
            val, bcnt, buf = \
                buffered_bounded_lemire_uint16(bitgen, rng,
                                               bcnt, buf)
            out[i] = low + val
    return out
|
||
|
|
||
|
|
||
|
@register_jitable
def random_bounded_uint8_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with 8 bit integers
    bounded by given interval.

    Every element is ``low`` plus a draw from the closed interval
    ``[0, rng]``. When `rng` is 0 no randomness is consumed and the
    array is filled with `low`.
    """
    # Buffer state shared by all draws of this call; it starts empty so
    # the first draw fetches a fresh 32 bit word.
    buf = 0
    bcnt = 0

    out = np.empty(size, dtype=dtype)
    if rng == 0:
        for i in np.ndindex(size):
            out[i] = low
    elif rng == 0xFF:
        # Lemire8 doesn't support rng = 0xFF.
        for i in np.ndindex(size):
            val, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            out[i] = low + val
    else:
        for i in np.ndindex(size):
            val, bcnt, buf = \
                buffered_bounded_lemire_uint8(bitgen, rng,
                                              bcnt, buf)
            out[i] = low + val
    return out
|
||
|
|
||
|
|
||
|
@register_jitable
def random_bounded_bool_fill(bitgen, low, rng, size, dtype):
    """
    Returns a new array of given size with boolean values.

    Each element is ``low`` plus the value produced by
    `buffered_bounded_bool`; the bit buffer is shared across all
    elements of one call.
    """
    # Buffer state shared by all draws of this call; it starts empty so
    # the first draw fetches a fresh 32 bit word.
    buf = 0
    bcnt = 0
    out = np.empty(size, dtype=dtype)
    for i in np.ndindex(size):
        val, bcnt, buf = buffered_bounded_bool(bitgen, low, rng, bcnt, buf)
        # NOTE(review): when rng == 0 the helper hands back `low` itself,
        # so this stores low + low. Presumably harmless for a boolean
        # dtype (any nonzero value becomes True) -- confirm.
        out[i] = low + val
    return out
|
||
|
|
||
|
|
||
|
@register_jitable
def _randint_arg_check(low, high, endpoint, lower_bound, upper_bound):
    """
    Check that low and high are within the bounds
    for the given datatype.

    Raises
    ------
    ValueError
        If `low` is below `lower_bound`, if `high` is above
        `upper_bound`, or if `low` is greater than `high`.
    """

    if low < lower_bound:
        raise ValueError("low is out of bounds")

    # This is being done to avoid high being accidentally
    # casted to int64/32 while subtracting 1 before
    # checking bounds, avoids overflow.
    # The two branches repeat the same comparisons on purpose: the values
    # compared in the first branch have been converted to uint64, the
    # ones in the second have not.
    if high > 0:
        high = uint64(high)
        if not endpoint:
            # Turn the half-open upper limit into a closed one.
            high -= uint64(1)
        upper_bound = uint64(upper_bound)
        if low > 0:
            low = uint64(low)
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        if low > high:  # -1 already subtracted, closed interval
            raise ValueError("low is greater than high in given interval")
    else:
        # NOTE(review): on this path `high` is never decremented, even
        # when `endpoint` is false, so low == high is accepted here --
        # confirm this is intended.
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        if low > high:
            raise ValueError("low is greater than high in given interval")
|
||
|
|
||
|
|
||
|
@register_jitable
def random_interval(bitgen, max_val):
    """
    Return a random integer in the closed interval ``[0, max_val]``.

    Uses masked rejection sampling: raw draws are reduced with the
    smallest all-ones mask covering `max_val` and redrawn until the
    masked value does not exceed `max_val`. A `max_val` of 0 returns 0
    without consuming any randomness.
    """
    if (max_val == 0):
        return 0

    max_val = uint64(max_val)
    mask = uint64(gen_mask(max_val))

    if (max_val <= 0xffffffff):
        # The bound fits in 32 bits, so 32 bit draws are sufficient.
        value = uint64(next_uint32(bitgen)) & mask
        while value > max_val:
            value = uint64(next_uint32(bitgen)) & mask
    else:
        value = next_uint64(bitgen) & mask
        while value > max_val:
            value = next_uint64(bitgen) & mask

    return uint64(value)
|