76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
# Copyright ExplosionAI GmbH, released under BSD.
|
|
from __future__ import division
|
|
|
|
from hypothesis import given, assume
|
|
from math import sqrt, floor
|
|
|
|
from blis.tests.common import *
|
|
from blis.py import gemm
|
|
|
|
|
|
def _stretch_matrix(data, m, n):
|
|
orig_len = len(data)
|
|
orig_m = m
|
|
orig_n = n
|
|
ratio = sqrt(len(data) / (m * n))
|
|
m = int(floor(m * ratio))
|
|
n = int(floor(n * ratio))
|
|
data = np.ascontiguousarray(data[: m * n], dtype=data.dtype)
|
|
return data.reshape((m, n)), m, n
|
|
|
|
|
|
def _reshape_for_gemm(
    A, B, a_rows, a_cols, out_cols, dtype, trans_a=False, trans_b=False
):
    """Turn two arrays of values into conformable operands for ``gemm``.

    ``A`` is stretched to roughly ``a_rows x a_cols`` by ``_stretch_matrix``.
    ``B`` is then flattened, truncated and reshaped to ``(a_cols, b_cols)``,
    with as many columns as it has data for. ``C`` is a zero-filled output
    buffer of shape ``(rows of A, b_cols)``, i.e. the shape of ``A @ B``
    before any transposition.

    Args:
        A: Values for the left operand.
        B: Values for the right operand (flattened before use).
        a_rows: Requested row count of A before stretching.
        a_cols: Requested column count of A before stretching.
        out_cols: Unused. The number of output columns is derived from the
            size of B. Kept so existing callers keep working.
        dtype: dtype of all three returned arrays.
        trans_a: If true, A is returned transposed (as a new C-contiguous
            array), for exercising ``gemm`` with A's transpose flag.
        trans_b: If true, B is returned transposed in the same way.

    Returns:
        ``(A, B, C)``, or ``(None, None, None)`` when B has fewer elements than
        A has columns, or when the stretched A has no columns.
    """
    A, a_rows, a_cols = _stretch_matrix(A, a_rows, a_cols)
    flat_b = np.ravel(B)
    # The a_cols < 1 check also guards the division below.
    if flat_b.size < a_cols or a_cols < 1:
        return (None, None, None)
    b_cols = flat_b.size // a_cols
    B = np.ascontiguousarray(flat_b[: a_cols * b_cols], dtype=dtype)
    B = B.reshape((a_cols, b_cols))
    # Cast A as well, so it shares `dtype` with B and C.
    A = np.ascontiguousarray(A, dtype=dtype)
    C = np.zeros(shape=(A.shape[0], b_cols), dtype=dtype)
    if trans_a:
        A = np.ascontiguousarray(A.T, dtype=dtype)
    if trans_b:
        B = np.ascontiguousarray(B.T, dtype=dtype)
    return A, B, C
|
|
|
|
|
|
@given(
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float64"),
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float64"),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
)
def test_memoryview_double_notrans(A, B, a_rows, a_cols, out_cols):
    """blis gemm on float64 operands (no transposes) agrees with NumPy's dot."""
    A, B, C = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, "float64")
    # _reshape_for_gemm returns (None, None, None) together, so checking A suffices.
    assume(A is not None)
    # Skip degenerate operands with no elements.
    assume(A.size >= 1 and B.size >= 1 and C.size >= 1)
    gemm(A, B, out=C)
    expected = A.dot(B)
    # NumPy's result is the reference ("desired"); rtol is relative to it.
    assert_allclose(C, expected, atol=1e-4, rtol=1e-4)
|
|
|
|
|
|
@given(
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float32"),
    ndarrays(min_len=10, max_len=100, min_val=-100.0, max_val=100.0, dtype="float32"),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
    integers(min_value=2, max_value=1000),
)
def test_memoryview_float_notrans(A, B, a_rows, a_cols, out_cols):
    """blis gemm on float32 operands (no transposes) agrees with NumPy's dot."""
    A, B, C = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, dtype="float32")
    # _reshape_for_gemm returns (None, None, None) together, so checking A suffices.
    assume(A is not None)
    # Skip degenerate operands with no elements.
    assume(A.size >= 1 and B.size >= 1 and C.size >= 1)
    gemm(A, B, out=C)
    expected = A.dot(B)
    # NumPy's result is the reference ("desired"); the looser tolerance
    # reflects single precision.
    assert_allclose(C, expected, atol=1e-3, rtol=1e-3)
|