108 lines
2.6 KiB
Python
108 lines
2.6 KiB
Python
|
# Copyright ExplsionAI GmbH, released under BSD.
|
||
|
import numpy
|
||
|
import numpy.random
|
||
|
from .py import gemm, einsum
|
||
|
from timeit import default_timer as timer
|
||
|
|
||
|
numpy.random.seed(0)
|
||
|
|
||
|
|
||
|
def create_data(nO, nI, batch_size):
|
||
|
X = numpy.zeros((batch_size, nI), dtype="f")
|
||
|
X += numpy.random.uniform(-1.0, 1.0, X.shape)
|
||
|
W = numpy.zeros((nO, nI), dtype="f")
|
||
|
W += numpy.random.uniform(-1.0, 1.0, W.shape)
|
||
|
return X, W
|
||
|
|
||
|
|
||
|
def get_numpy_blas():
|
||
|
blas_libs = numpy.__config__.blas_opt_info["libraries"]
|
||
|
return blas_libs[0]
|
||
|
|
||
|
|
||
|
def numpy_gemm(X, W, n=1000):
|
||
|
nO, nI = W.shape
|
||
|
batch_size = X.shape[0]
|
||
|
total = 0.0
|
||
|
y = numpy.zeros((batch_size, nO), dtype="f")
|
||
|
for i in range(n):
|
||
|
numpy.dot(X, W, out=y)
|
||
|
total += y.sum()
|
||
|
y.fill(0)
|
||
|
print("Total:", total)
|
||
|
|
||
|
|
||
|
def blis_gemm(X, W, n=1000):
|
||
|
nO, nI = W.shape
|
||
|
batch_size = X.shape[0]
|
||
|
total = 0.0
|
||
|
y = numpy.zeros((batch_size, nO), dtype="f")
|
||
|
for i in range(n):
|
||
|
gemm(X, W, out=y)
|
||
|
total += y.sum()
|
||
|
y.fill(0.0)
|
||
|
print("Total:", total)
|
||
|
|
||
|
|
||
|
def numpy_einsum(X, W, n=1000):
|
||
|
nO, nI = W.shape
|
||
|
batch_size = X.shape[0]
|
||
|
total = 0.0
|
||
|
y = numpy.zeros((nO, batch_size), dtype="f")
|
||
|
for i in range(n):
|
||
|
numpy.einsum("ab,cb->ca", X, W, out=y)
|
||
|
total += y.sum()
|
||
|
y.fill(0.0)
|
||
|
print("Total:", total)
|
||
|
|
||
|
|
||
|
def blis_einsum(X, W, n=1000):
|
||
|
nO, nI = W.shape
|
||
|
batch_size = X.shape[0]
|
||
|
total = 0.0
|
||
|
y = numpy.zeros((nO, batch_size), dtype="f")
|
||
|
for i in range(n):
|
||
|
einsum("ab,cb->ca", X, W, out=y)
|
||
|
total += y.sum()
|
||
|
y.fill(0.0)
|
||
|
print("Total:", total)
|
||
|
|
||
|
|
||
|
def main(nI=128 * 3, nO=128 * 3, batch_size=2000):
|
||
|
print(
|
||
|
"Setting up data for gemm. 1000 iters, "
|
||
|
"nO={nO} nI={nI} batch_size={batch_size}".format(**locals())
|
||
|
)
|
||
|
numpy_blas = get_numpy_blas()
|
||
|
X1, W1 = create_data(nI, nO, batch_size)
|
||
|
X2 = X1.copy()
|
||
|
W2 = W1.copy()
|
||
|
print("Blis gemm...")
|
||
|
start = timer()
|
||
|
blis_gemm(X2, W2, n=1000)
|
||
|
end = timer()
|
||
|
blis_time = end - start
|
||
|
print("%.2f seconds" % blis_time)
|
||
|
print("Numpy (%s) gemm..." % numpy_blas)
|
||
|
start = timer()
|
||
|
numpy_gemm(X1, W1)
|
||
|
end = timer()
|
||
|
numpy_time = end - start
|
||
|
print("%.2f seconds" % numpy_time)
|
||
|
print("Blis einsum ab,cb->ca")
|
||
|
start = timer()
|
||
|
blis_einsum(X2, W2, n=1000)
|
||
|
end = timer()
|
||
|
blis_time = end - start
|
||
|
print("%.2f seconds" % blis_time)
|
||
|
print("Numpy (%s) einsum ab,cb->ca" % numpy_blas)
|
||
|
start = timer()
|
||
|
numpy_einsum(X2, W2)
|
||
|
end = timer()
|
||
|
numpy_time = end - start
|
||
|
print("%.2f seconds" % numpy_time)
|
||
|
|
||
|
|
||
|
if __name__:
|
||
|
main()
|