Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/benchmark/triton_kernels/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
# mutated inputs" when using max-autotune. The fallback is correct; suppress noise.
warnings.filterwarnings("ignore", message=".*[Ss]kipping (cuda|CUDA)[Gg]raphs.*")
logging.getLogger("torch._inductor.cudagraph_trees").setLevel(logging.ERROR)
# The per-measurement profiler session emits a one-time note about event cycles; the
# runner reads key_averages right after each session, so it doesn't apply here.
warnings.filterwarnings("ignore", message=".*Profiler clears events.*")

_BENCHMARKS = {
"entropy_loss": bench_entropy_loss,
Expand Down
64 changes: 40 additions & 24 deletions tools/benchmark/triton_kernels/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
Each benchmark file defines a list of `Case` objects (input shape/dtype
sweep) and a list of `Variant` objects (implementations to compare — e.g.
pytorch eager, pytorch compiled, Triton). The runner invokes each variant
on each case, measures timing (median + mean + percentiles via CUDA events),
measures peak/final memory, and compares outputs against an fp32 reference
using RMS error. Results are printed as a table per case.
on each case, measures GPU kernel time (summed per-kernel device time via the
profiler), measures peak/final memory, and compares outputs against an fp32
reference using RMS error. Results are printed as a table per case.
"""

import dataclasses
Expand Down Expand Up @@ -152,24 +152,34 @@ class VariantResult:
# --------------------------------------------------------------------------- timing


# Per-rep samples for the timing distribution. Device kernel time is stable, so a
# small fixed count gives tight stats without profiling thousands of reps.
_MAX_PROFILE_REPS = 50


def bench_fn(
fn: typing.Callable[[], typing.Any],
reset: typing.Callable[[], None] | None = None,
warmup_ms: float = 25.0,
rep_ms: float = 100.0,
min_reps: int = 5,
max_reps: int = 10_000,
) -> TimingStats:
"""Benchmark `fn` — it should be a no-arg callable that invokes the kernel
being timed (close over inputs). Returns timing statistics in ms.

Mirrors `triton.testing.do_bench` logic but returns raw per-rep list so we
can compute {median, mean, min, max, std} from one set of runs.
Reports GPU kernel time: each rep sums the profiler's per-kernel device
self-time, and stats are computed over the per-rep samples. A CUDA-event
window around `fn` instead counts GPU-idle bubbles from per-call allocation
and (for autograd variants) Python/engine launch overhead — benchmarking
artifacts that don't occur in a real run where the CPU runs ahead and the
allocator serves cached buffers, and which vary by variant (the eager C++
engine starves the GPU far less than a Python autograd Function). Summing
device self-time isolates the work the GPU actually does.
"""
if not torch.cuda.is_available():
# CPU / Triton interpret: single timed run with wall clock. min_reps,
# max_reps, warmup_ms, rep_ms are ignored — this path is for smoke
# testing kernel correctness, not measurement.
# warmup_ms, rep_ms are ignored — this path is for smoke testing kernel
# correctness, not measurement.
if reset is not None:
reset()
fn() # warmup
Expand Down Expand Up @@ -215,28 +225,34 @@ def bench_fn(
torch.cuda.synchronize()
one_rep_ms = max(post_start.elapsed_time(post_end), 0.001)

num_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms)))
num_reps = max(min_reps, min(_MAX_PROFILE_REPS, int(rep_ms / one_rep_ms)))

start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)]
for i in range(num_reps):
# Profile each rep separately and sum its kernels' device self-time, building a
# per-rep sample distribution. The L2 flush stays outside the profiled region and
# is synced before the profiler opens, so its fill kernel can't land in the capture
# window while each kernel still reads cold. CUDA-only profiling records no CPU op
# nodes (which would otherwise carry their children's device time and double-count);
# the device_type == CUDA filter drops the zero-device-time runtime/launch entries.
samples_ms: list[float] = []
for _ in range(num_reps):
if reset is not None:
reset()
# The L2 flush is enqueued before start_events[i] on the same stream, so
# the timed window starts after the zero completes — only fn() is timed.
flush_buffer.zero_()
start_events[i].record()
fn()
end_events[i].record()
torch.cuda.synchronize()
torch.cuda.synchronize()
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
fn()
torch.cuda.synchronize()
kernel_us = sum(
k.self_device_time_total for k in prof.key_averages() if k.device_type == torch.autograd.DeviceType.CUDA
)
samples_ms.append(kernel_us / 1000)

times = [start_events[i].elapsed_time(end_events[i]) for i in range(num_reps)]
return TimingStats(
median_ms=statistics.median(times),
mean_ms=statistics.fmean(times),
min_ms=min(times),
max_ms=max(times),
std_ms=statistics.pstdev(times) if len(times) > 1 else 0.0,
median_ms=statistics.median(samples_ms),
mean_ms=statistics.fmean(samples_ms),
min_ms=min(samples_ms),
max_ms=max(samples_ms),
std_ms=statistics.pstdev(samples_ms) if len(samples_ms) > 1 else 0.0,
num_reps=num_reps,
)

Expand Down
Loading