ai-content-maker/.venv/Lib/site-packages/numba/np/random/random_methods.py

366 lines
10 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import numpy as np
from numba import uint64, uint32, uint16, uint8
from numba.core.extending import register_jitable
from numba.np.random._constants import (UINT32_MAX, UINT64_MAX,
UINT16_MAX, UINT8_MAX)
from numba.np.random.generator_core import next_uint32, next_uint64
# All following implementations are direct translations from:
# https://github.com/numpy/numpy/blob/7cfef93c77599bd387ecc6a15d186c5a46024dac/numpy/random/src/distributions/distributions.c
@register_jitable
def gen_mask(max):
    """
    Return a bit mask that covers every bit position at or below the
    highest set bit of `max`.

    The argument is converted to an unsigned 64 bit integer and its set
    bits are smeared downwards by OR-ing the value with copies of itself
    shifted right by 1, 2, 4, 8, 16 and 32 positions.
    """
    # NOTE: the parameter name shadows the builtin `max`; it is kept
    # because it is part of the function's signature.
    bits = uint64(max)
    for distance in (1, 2, 4, 8, 16, 32):
        bits |= bits >> distance
    return bits
@register_jitable
def buffered_bounded_bool(bitgen, off, rng, bcnt, buf):
    """
    Produce one boolean sample, consuming a single bit of a buffered
    draw obtained from `next_uint32`.

    When `rng` is zero nothing is consumed and `off` is handed back
    unchanged.  Otherwise, if no buffered bits remain (`bcnt` is zero)
    the buffer is refilled and the counter reset to 31; if bits remain,
    the buffer is shifted right by one and the counter is decremented.

    Returns the tuple `(value, bcnt, buf)` so that the caller can feed
    the updated buffer state into the next call.
    """
    if rng == 0:
        return off, bcnt, buf

    if bcnt:
        # Bits are still available: advance to the next one.
        buf = buf >> 1
        bcnt = bcnt - 1
    else:
        # Buffer exhausted: refill it. 31 bits remain after the one
        # that is consumed below.
        buf = next_uint32(bitgen)
        bcnt = 31

    bit_is_set = (buf & 1) != 0
    return bit_is_set, bcnt, buf
@register_jitable
def buffered_uint8(bitgen, bcnt, buf):
    """
    Return the next 8 bit chunk of a buffered draw obtained from
    `next_uint32`.

    If `bcnt` is zero the buffer is refilled and the counter is set to
    3 (the number of chunks left after this one); otherwise the buffer
    is shifted right by 8 bits and the counter is decremented.

    Returns the tuple `(uint8(buf), bcnt, buf)` with the updated buffer
    state.
    """
    if bcnt:
        # Expose the next byte of the current draw.
        buf = buf >> 8
        bcnt = bcnt - 1
    else:
        # Nothing buffered: fetch a new draw.
        buf = next_uint32(bitgen)
        bcnt = 3

    chunk = uint8(buf)
    return chunk, bcnt, buf
@register_jitable
def buffered_uint16(bitgen, bcnt, buf):
    """
    Return the next 16 bit chunk of a buffered draw obtained from
    `next_uint32`.

    If `bcnt` is zero the buffer is refilled and the counter is set to
    1 (the number of chunks left after this one); otherwise the buffer
    is shifted right by 16 bits and the counter is decremented.

    Returns the tuple `(uint16(buf), bcnt, buf)` with the updated
    buffer state.
    """
    if bcnt:
        # Expose the second half of the current draw.
        buf = buf >> 16
        bcnt = bcnt - 1
    else:
        # Nothing buffered: fetch a new draw.
        buf = next_uint32(bitgen)
        bcnt = 1

    chunk = uint16(buf)
    return chunk, bcnt, buf
# The following implementations use Lemire's algorithm:
# https://arxiv.org/abs/1805.10941
@register_jitable
def buffered_bounded_lemire_uint8(bitgen, rng, bcnt, buf):
    """
    Draw a random unsigned 8 bit integer bounded by `rng` using
    Lemire's multiply-and-reject method.

    `bcnt` and `buf` describe a buffered draw from the BitGenerator
    (see `buffered_uint8`), which lets several 8 bit values be taken
    from a single draw.  `rng` must not be 0xFF.

    Returns the tuple `(value, bcnt, buf)` where `value` is the high
    byte of the scaled draw and `bcnt`/`buf` are the updated buffer
    state.
    """
    # Note: `rng` should not be 0xFF. When this happens `bound` becomes
    # zero.
    bound = uint8(rng) + uint8(1)

    assert (rng != 0xFF)

    # Scale an 8 bit draw by the exclusive bound.
    draw, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
    scaled = uint16(draw * bound)

    # The low byte of the product decides whether the draw is rejected.
    low_byte = scaled & 0xFF

    if low_byte < bound:
        # `bound` is a simple upper bound for `reject_below`, so the
        # modulo is only evaluated when a rejection is possible.
        reject_below = (uint8(UINT8_MAX) - rng) % bound

        while low_byte < reject_below:
            draw, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            scaled = uint16(draw * bound)
            low_byte = scaled & 0xFF

    return scaled >> 8, bcnt, buf
@register_jitable
def buffered_bounded_lemire_uint16(bitgen, rng, bcnt, buf):
    """
    Draw a random unsigned 16 bit integer bounded by `rng` using
    Lemire's multiply-and-reject method.

    `bcnt` and `buf` describe a buffered draw from the BitGenerator
    (see `buffered_uint16`), which lets two 16 bit values be taken
    from a single draw.  `rng` must not be 0xFFFF.

    Returns the tuple `(value, bcnt, buf)` where `value` is the high
    half of the scaled draw and `bcnt`/`buf` are the updated buffer
    state.
    """
    # Note: `rng` should not be 0xFFFF. When this happens `bound`
    # becomes zero.
    bound = uint16(rng) + uint16(1)

    assert (rng != 0xFFFF)

    # Scale a 16 bit draw by the exclusive bound.
    draw, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
    scaled = uint32(draw * bound)

    # The low half of the product decides whether the draw is rejected.
    low_half = scaled & 0xFFFF

    if low_half < bound:
        # `bound` is a simple upper bound for `reject_below`, so the
        # modulo is only evaluated when a rejection is possible.
        reject_below = (uint16(UINT16_MAX) - rng) % bound

        while low_half < reject_below:
            draw, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            scaled = uint32(draw * bound)
            low_half = scaled & 0xFFFF

    return scaled >> 16, bcnt, buf
@register_jitable
def buffered_bounded_lemire_uint32(bitgen, rng):
    """
    Draw a random unsigned 32 bit integer bounded by `rng` using
    Lemire's multiply-and-reject method.

    Every attempt takes one value from `next_uint32`; no buffer state
    is carried between calls.  `rng` must not be 0xFFFFFFFF.

    Returns the upper 32 bits of the 64 bit product of the accepted
    draw and `rng + 1`.
    """
    bound = uint32(rng) + uint32(1)

    assert (rng != 0xFFFFFFFF)

    # Scale a 32 bit draw by the exclusive bound, in 64 bit arithmetic.
    scaled = uint64(next_uint32(bitgen)) * uint64(bound)

    # The low 32 bits of the product decide whether the draw is rejected.
    low_word = scaled & 0xFFFFFFFF

    if low_word < bound:
        # `bound` is a simple upper bound for `reject_below`, so the
        # modulo is only evaluated when a rejection is possible.
        reject_below = (UINT32_MAX - rng) % bound

        while low_word < reject_below:
            scaled = uint64(next_uint32(bitgen)) * uint64(bound)
            low_word = scaled & 0xFFFFFFFF

    return (scaled >> 32)
@register_jitable
def bounded_lemire_uint64(bitgen, rng):
    """
    Generates a random unsigned 64 bit integer bounded
    within a given interval using Lemire's rejection.

    `bitgen` is handed to `next_uint64` to obtain raw draws and `rng`
    bounds the result; `rng` must not be 0xFFFFFFFFFFFFFFFF (asserted
    below).

    Returns the upper 64 bits of the product of the accepted draw and
    `rng + 1`, computed from 32 bit halves.
    """
    # Exclusive bound used to scale the draw.
    rng_excl = uint64(rng) + uint64(1)

    assert (rng != 0xFFFFFFFFFFFFFFFF)

    x = next_uint64(bitgen)

    # NOTE(review): presumably the low 64 bits of the 128 bit product,
    # relying on uint64 multiplication wrapping around -- confirm.
    leftover = uint64(x) * uint64(rng_excl)

    if (leftover < rng_excl):
        # The modulo is only evaluated when `leftover` is small enough
        # for a rejection to be possible.
        threshold = (UINT64_MAX - rng) % rng_excl

        # Redraw until the low part reaches the threshold.
        while (leftover < threshold):
            x = next_uint64(bitgen)
            leftover = uint64(x) * uint64(rng_excl)

    # High 64 bits of x * rng_excl, assembled from 32 bit halves:
    # split both operands into low (…0) and high (…1) words.
    x0 = x & uint64(0xFFFFFFFF)
    x1 = x >> 32
    rng_excl0 = rng_excl & uint64(0xFFFFFFFF)
    rng_excl1 = rng_excl >> 32
    # Low x low partial product; only its carry (upper word) is needed.
    w0 = x0 * rng_excl0
    # High x low partial product plus the carry out of `w0`.
    t = x1 * rng_excl0 + (w0 >> 32)
    w1 = t & uint64(0xFFFFFFFF)
    w2 = t >> 32
    # Add the low x high partial product to the middle word.
    w1 += x0 * rng_excl1
    # High x high partial product plus the carries from the middle words.
    m1 = x1 * rng_excl1 + w2 + (w1 >> 32)

    return m1
@register_jitable
def random_bounded_uint64_fill(bitgen, low, rng, size, dtype):
    """
    Create an array of shape `size` and type `dtype` whose elements
    are `low` plus a random offset bounded by `rng`.

    The generator used for the offset depends on `rng`:
    zero yields a constant array, spans up to 32 bits use 32 bit
    draws, and larger spans use 64 bit draws.  The two full-width
    spans use raw draws; all others go through Lemire's rejection.
    """
    result = np.empty(size, dtype=dtype)

    if rng == 0:
        # Degenerate interval: every element equals `low`.
        for idx in np.ndindex(size):
            result[idx] = low
    elif rng == 0xFFFFFFFF:
        # Full 32 bit span: a raw draw is used directly.
        for idx in np.ndindex(size):
            result[idx] = low + next_uint32(bitgen)
    elif rng < 0xFFFFFFFF:
        for idx in np.ndindex(size):
            result[idx] = low + buffered_bounded_lemire_uint32(bitgen, rng)
    elif rng == 0xFFFFFFFFFFFFFFFF:
        # Full 64 bit span: a raw draw is used directly.
        for idx in np.ndindex(size):
            result[idx] = low + next_uint64(bitgen)
    else:
        for idx in np.ndindex(size):
            result[idx] = low + bounded_lemire_uint64(bitgen, rng)

    return result
@register_jitable
def random_bounded_uint32_fill(bitgen, low, rng, size, dtype):
    """
    Create an array of shape `size` and type `dtype` whose elements
    are `low` plus a random 32 bit offset bounded by `rng`.

    A zero `rng` yields a constant array, the full 32 bit span uses
    raw draws, and every other span goes through Lemire's rejection.
    """
    result = np.empty(size, dtype=dtype)

    if rng == 0xFFFFFFFF:
        # Lemire32 doesn't support rng = 0xFFFFFFFF.
        for idx in np.ndindex(size):
            result[idx] = low + next_uint32(bitgen)
    elif rng == 0:
        # Degenerate interval: every element equals `low`.
        for idx in np.ndindex(size):
            result[idx] = low
    else:
        for idx in np.ndindex(size):
            result[idx] = low + buffered_bounded_lemire_uint32(bitgen, rng)

    return result
@register_jitable
def random_bounded_uint16_fill(bitgen, low, rng, size, dtype):
    """
    Create an array of shape `size` and type `dtype` whose elements
    are `low` plus a random 16 bit offset bounded by `rng`.

    A draw buffer (`buf`, with `bcnt` chunks remaining) is threaded
    through the element loop so that one BitGenerator draw can serve
    more than one element.
    """
    bcnt = 0
    buf = 0
    result = np.empty(size, dtype=dtype)

    if rng == 0:
        # Degenerate interval: every element equals `low`.
        for idx in np.ndindex(size):
            result[idx] = low
    elif rng == 0xFFFF:
        # Lemire16 doesn't support rng = 0xFFFF.
        for idx in np.ndindex(size):
            sample, bcnt, buf = buffered_uint16(bitgen, bcnt, buf)
            result[idx] = low + sample
    else:
        for idx in np.ndindex(size):
            sample, bcnt, buf = buffered_bounded_lemire_uint16(
                bitgen, rng, bcnt, buf
            )
            result[idx] = low + sample

    return result
@register_jitable
def random_bounded_uint8_fill(bitgen, low, rng, size, dtype):
    """
    Create an array of shape `size` and type `dtype` whose elements
    are `low` plus a random 8 bit offset bounded by `rng`.

    A draw buffer (`buf`, with `bcnt` chunks remaining) is threaded
    through the element loop so that one BitGenerator draw can serve
    several elements.
    """
    bcnt = 0
    buf = 0
    result = np.empty(size, dtype=dtype)

    if rng == 0:
        # Degenerate interval: every element equals `low`.
        for idx in np.ndindex(size):
            result[idx] = low
    elif rng == 0xFF:
        # Lemire8 doesn't support rng = 0xFF.
        for idx in np.ndindex(size):
            sample, bcnt, buf = buffered_uint8(bitgen, bcnt, buf)
            result[idx] = low + sample
    else:
        for idx in np.ndindex(size):
            sample, bcnt, buf = buffered_bounded_lemire_uint8(
                bitgen, rng, bcnt, buf
            )
            result[idx] = low + sample

    return result
@register_jitable
def random_bounded_bool_fill(bitgen, low, rng, size, dtype):
    """
    Create an array of shape `size` and type `dtype` filled with
    boolean samples.

    Each element is `low` plus the value returned by
    `buffered_bounded_bool`; the bit buffer (`buf`, `bcnt`) is carried
    from one element to the next.
    """
    bcnt = 0
    buf = 0
    result = np.empty(size, dtype=dtype)

    for idx in np.ndindex(size):
        sample, bcnt, buf = buffered_bounded_bool(
            bitgen, low, rng, bcnt, buf
        )
        result[idx] = low + sample

    return result
@register_jitable
def _randint_arg_check(low, high, endpoint, lower_bound, upper_bound):
    """
    Check that low and high are within the bounds
    for the given datatype.

    `low` is the inclusive lower limit of the requested interval and
    `high` its upper limit, which is inclusive when `endpoint` is true
    and exclusive otherwise.  `lower_bound` and `upper_bound` are the
    smallest and largest values allowed for the datatype.

    Raises ValueError when `low` is below `lower_bound`, when the
    (closed) upper limit exceeds `upper_bound`, or when the requested
    interval is empty.
    """
    if low < lower_bound:
        raise ValueError("low is out of bounds")

    # This is being done to avoid high being accidentally
    # casted to int64/32 while subtracting 1 before
    # checking bounds, avoids overflow.
    if high > 0:
        high = uint64(high)
        if not endpoint:
            high -= uint64(1)
        upper_bound = uint64(upper_bound)
        if low > 0:
            low = uint64(low)
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        if low > high:  # -1 already subtracted, closed interval
            raise ValueError("low is greater than high in given interval")
    else:
        if high > upper_bound:
            raise ValueError("high is out of bounds")
        # `high` is not decremented on this path, so for an exclusive
        # upper limit `low == high` also describes an empty interval
        # and must be rejected, matching the `high > 0` path above.
        if low > high or (not endpoint and low == high):
            raise ValueError("low is greater than high in given interval")
@register_jitable
def random_interval(bitgen, max_val):
    """
    Return a random unsigned integer that does not exceed `max_val`,
    using masked rejection sampling.

    A draw is masked with `gen_mask(max_val)` and redrawn for as long
    as the masked value is larger than `max_val`.  Draws come from
    `next_uint32` when `max_val` fits in 32 bits and from `next_uint64`
    otherwise.  A `max_val` of zero returns 0 without drawing.
    """
    if max_val == 0:
        return 0

    max_val = uint64(max_val)
    mask = uint64(gen_mask(max_val))

    if max_val > 0xffffffff:
        # Wide bound: 64 bit draws are required.
        candidate = next_uint64(bitgen) & mask
        while candidate > max_val:
            candidate = next_uint64(bitgen) & mask
    else:
        # Bound fits in 32 bits: 32 bit draws are used.
        candidate = uint64(next_uint32(bitgen)) & mask
        while candidate > max_val:
            candidate = uint64(next_uint32(bitgen)) & mask

    return uint64(candidate)