86 lines
2.6 KiB
Python
86 lines
2.6 KiB
Python
|
"""Example of the Timer and Fuzzer APIs:

$ python -m examples.fuzzer
"""
|
||
|
|
||
|
import sys
|
||
|
|
||
|
import torch.utils.benchmark as benchmark_utils
|
||
|
|
||
|
|
||
|
def main():
    """Time ``x + y`` over fuzzed tensor pairs and print the extremes.

    A seeded ``Fuzzer`` produces 250 configurations, each holding two tensors
    named ``x`` and ``y``.  Every configuration is timed with
    ``Timer.blocked_autorange(min_run_time=0.1)``, progress is echoed on a
    single console line, and the 15 lowest and 15 highest per-element timings
    are printed as two tables.

    The function takes no arguments and returns nothing; all results go to
    standard output.
    """
    # Three candidate sizes k0..k2, each drawn "loguniform" from [16, 16 * 1024].
    size_parameters = [
        benchmark_utils.FuzzedParameter(
            name=f"k{index}",
            minval=16,
            maxval=16 * 1024,
            distribution="loguniform",
        )
        for index in range(3)
    ]
    # "d" is drawn from the discrete distribution {2: 0.6, 3: 0.4}; it is
    # handed to the tensors below as their ``dim_parameter``.
    dim_parameter = benchmark_utils.FuzzedParameter(
        name="d",
        distribution={2: 0.6, 3: 0.4},
    )
    # "x" and "y" share one specification, so they are described identically.
    fuzzed_tensors = [
        benchmark_utils.FuzzedTensor(
            name=tensor_name,
            size=("k0", "k1", "k2"),
            dim_parameter="d",
            probability_contiguous=0.75,
            min_elements=64 * 1024,
            max_elements=128 * 1024,
        )
        for tensor_name in ("x", "y")
    ]
    add_fuzzer = benchmark_utils.Fuzzer(
        parameters=[size_parameters, dim_parameter],
        tensors=[fuzzed_tensors],
        seed=0,  # fixed seed, as in the original example
    )

    n = 250
    measurements = []
    for count, (tensors, tensor_properties, _) in enumerate(
        add_fuzzer.take(n=n), start=1
    ):
        x = tensors["x"]
        x_order = str(tensor_properties["x"]["order"])
        y = tensors["y"]
        y_order = str(tensor_properties["y"]["order"])

        # Right-align each extent of x's shape in a 4-character column.
        shape = ", ".join(f"{extent:>4}" for extent in x.shape)
        x_layout = "contiguous" if x.is_contiguous() else x_order
        y_layout = "contiguous" if y.is_contiguous() else y_order

        # One pre-formatted table row: element count, shape, x layout, y layout.
        description = (
            f"{x.numel():>7} | {shape:<16} | "
            f"{x_layout:<12} | "
            f"{y_layout:<12} | "
        )

        timer = benchmark_utils.Timer(
            stmt="x + y",
            globals=tensors,  # the statement resolves ``x`` and ``y`` from here
            description=description,
        )

        measurement = timer.blocked_autorange(min_run_time=0.1)
        # Keep the element count alongside the result so timings can be
        # normalised per element when sorting and printing.
        measurement.metadata = {"numel": x.numel()}
        measurements.append(measurement)

        # "\r" rewrites the same console line; flush so progress shows at once.
        print(f"\r{count} / {n}", end="")
        sys.stdout.flush()
    print()

    # Everything below is string formatting for the final report.
    average_attempts = 1. / (1. - add_fuzzer.rejection_rate)
    print(f"Average attempts per valid config: {average_attempts:.1f}")

    def time_fn(m):
        """Return the median time of measurement ``m`` divided by its element count."""
        return m.median / m.metadata["numel"]

    measurements.sort(key=time_fn)

    template = f"{{:>6}}{' ' * 19}Size Shape{' ' * 13}X order Y order\n{'-' * 80}"
    # (text printed before the header, header label, rows to list)
    sections = (
        ("", "Best:", measurements[:15]),
        ("\n", "Worst:", measurements[-15:]),
    )
    for lead, label, rows in sections:
        print(lead + template.format(label))
        for m in rows:
            # Scaled by 1e9 and labelled as ns per element.
            print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}")
|
||
|
|
||
|
|
||
|
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|