# Source: ai-content-maker/.venv/Lib/site-packages/torch/utils/benchmark/examples/compare.py
#
# 99 lines
# 2.8 KiB
# Python

"""Example of Timer and Compare APIs:
$ python -m examples.compare
"""
import pickle
import sys
import time
import torch
import torch.utils.benchmark as benchmark_utils
class FauxTorch:
    """Emulate different versions of pytorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements, but this simplifies that model to
    make the example clearer.

    Each wrapped op (`add`, `mul`, `cat`, `matmul`) forwards its arguments
    to the wrapped module unchanged and then passes the result through
    `extra_overhead`, which may sleep to imitate a slower build.
    """

    def __init__(self, real_torch, extra_ns_per_element):
        """Store the wrapped module and the simulated per-element cost.

        Args:
            real_torch: Object providing `add`, `mul`, `cat` and `matmul`;
                the ops are delegated to it.
            extra_ns_per_element: Artificial delay, in nanoseconds per
                result element, applied to sufficiently large results.
        """
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element

    def extra_overhead(self, result):
        """Sleep in proportion to `result.numel()` and return `result`.

        Results with 5000 elements or fewer are returned immediately.
        """
        # time.sleep has a ~65 us overhead, so only fake a
        # per-element overhead if numel is large enough.
        numel = int(result.numel())
        if numel > 5000:
            # Factor 1e-9 converts nanoseconds to the seconds time.sleep expects.
            time.sleep(numel * self._extra_ns_per_element * 1e-9)
        return result

    def add(self, *args, **kwargs):
        """Delegate to the wrapped `add`, then apply the fake overhead."""
        return self.extra_overhead(self._real_torch.add(*args, **kwargs))

    def mul(self, *args, **kwargs):
        """Delegate to the wrapped `mul`, then apply the fake overhead."""
        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))

    def cat(self, *args, **kwargs):
        """Delegate to the wrapped `cat`, then apply the fake overhead."""
        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))

    def matmul(self, *args, **kwargs):
        """Delegate to the wrapped `matmul`, then apply the fake overhead."""
        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
def main():
    """Benchmark `torch.add` under simulated branches and print the results.

    Builds one Timer per combination of branch, task, input size and thread
    count, measures each timer `repeats` times, and prints the resulting
    Compare table twice: once as-is, and once after calling
    `trim_significant_figures()` and `colorize()`.
    """
    # Each task is (label, sub_label, stmt).
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]

    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                # "master" uses the real module; the other branches get a
                # FauxTorch wrapper that adds a per-element delay.
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": torch.ones((size, 4)),
                "y": torch.ones((1, 4)),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000]
        for num_threads in [1, 4]
    ]

    total_runs = len(timers) * repeats
    for i, timer in enumerate(timers * repeats):
        # Measurements are stored pickled and loaded back below, standing in
        # for results that separate processes would have written out.
        serialized_results.append(pickle.dumps(
            timer.blocked_autorange(min_run_time=0.05)
        ))
        # "\r" rewrites the same console line to show progress in place.
        print(f"\r{i + 1} / {total_runs}", end="")
        sys.stdout.flush()
    print()

    comparison = benchmark_utils.Compare([
        pickle.loads(blob) for blob in serialized_results
    ])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()