ai-content-maker/.venv/Lib/site-packages/blis/tests/test_gemm.py

76 lines
2.4 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
# Copyright ExplosionAI GmbH, released under BSD.
from __future__ import division
from hypothesis import given, assume
from math import sqrt, floor
from blis.tests.common import *
from blis.py import gemm
def _stretch_matrix(data, m, n):
orig_len = len(data)
orig_m = m
orig_n = n
ratio = sqrt(len(data) / (m * n))
m = int(floor(m * ratio))
n = int(floor(n * ratio))
data = np.ascontiguousarray(data[: m * n], dtype=data.dtype)
return data.reshape((m, n)), m, n
def _reshape_for_gemm(
    A, B, a_rows, a_cols, out_cols, dtype, trans_a=False, trans_b=False
):
    """Build gemm operands A and B plus a zeroed output C from flat data.

    A is stretched to roughly ``a_rows x a_cols`` via ``_stretch_matrix``.
    B is truncated to the largest ``(a_cols, b_cols)`` matrix its elements
    allow, so that ``A @ B`` is well-defined. C is a zero matrix of shape
    ``(A.shape[0], B.shape[1])``, computed from the *untransposed* operands.

    Args:
        A, B: flat source arrays (presumably 1-D; TODO confirm against the
            ``ndarrays`` strategy).
        a_rows, a_cols: requested shape of A before stretching.
        out_cols: accepted for call-site compatibility but unused; the
            number of output columns is determined by ``len(B)``.
        dtype: dtype of the returned A, B and C.
        trans_a: if true, A is returned transposed (as a contiguous copy),
            presumably for tests that ask gemm to transpose it back.
        trans_b: if true, B is returned transposed in the same way.

    Returns:
        ``(A, B, C)``, or ``(None, None, None)`` when A ended up with no
        columns or B has fewer elements than A has columns.
    """
    A, a_rows, a_cols = _stretch_matrix(A, a_rows, a_cols)
    if len(B) < a_cols or a_cols < 1:
        return (None, None, None)
    b_cols = len(B) // a_cols
    # Apply ``dtype`` to A as well, so all three operands agree.
    A = np.ascontiguousarray(A, dtype=dtype)
    B = np.ascontiguousarray(B.flatten()[: a_cols * b_cols], dtype=dtype)
    B = B.reshape((a_cols, b_cols))
    # Size C before any transposition: it holds the untransposed A @ B.
    C = np.zeros(shape=(A.shape[0], B.shape[1]), dtype=dtype)
    if trans_a:
        A = np.ascontiguousarray(A.T, dtype=dtype)
    if trans_b:
        B = np.ascontiguousarray(B.T, dtype=dtype)
    return A, B, C
@given(
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float64"),
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float64"),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
)
def test_memoryview_double_notrans(A, B, a_rows, a_cols, out_cols):
    """Compare float64 ``gemm`` (no transposition) with ``numpy.dot``."""
    A, B, C = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, "float64")
    operands = (A, B, C)
    # Discard examples the helper could not shape (it returns all None)...
    for operand in operands:
        assume(operand is not None)
    # ...and examples where any matrix came out empty.
    for operand in operands:
        assume(operand.size >= 1)
    # gemm is given C as ``out=``, so the product is read back from C.
    gemm(A, B, out=C)
    expected = A.dot(B)
    assert_allclose(expected, C, atol=1e-4, rtol=1e-4)
@given(
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float32"),
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float32"),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
)
def test_memoryview_float_notrans(A, B, a_rows, a_cols, out_cols):
    """Compare float32 ``gemm`` (no transposition) with ``numpy.dot``.

    Tolerances are looser than in the float64 test to allow for
    single-precision rounding.
    """
    A, B, C = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, dtype="float32")
    operands = (A, B, C)
    # Discard examples the helper could not shape (it returns all None)...
    for operand in operands:
        assume(operand is not None)
    # ...and examples where any matrix came out empty.
    for operand in operands:
        assume(operand.size >= 1)
    # gemm is given C as ``out=``, so the product is read back from C.
    gemm(A, B, out=C)
    expected = A.dot(B)
    assert_allclose(expected, C, atol=1e-3, rtol=1e-3)