diff --git a/tools/benchmark/triton_kernels/__main__.py b/tools/benchmark/triton_kernels/__main__.py index e09b92976..950307f22 100644 --- a/tools/benchmark/triton_kernels/__main__.py +++ b/tools/benchmark/triton_kernels/__main__.py @@ -33,6 +33,9 @@ # mutated inputs" when using max-autotune. The fallback is correct; suppress noise. warnings.filterwarnings("ignore", message=".*[Ss]kipping (cuda|CUDA)[Gg]raphs.*") logging.getLogger("torch._inductor.cudagraph_trees").setLevel(logging.ERROR) +# The per-measurement profiler session emits a one-time note about event cycles; the +# runner reads key_averages right after each session, so it doesn't apply here. +warnings.filterwarnings("ignore", message=".*Profiler clears events.*") _BENCHMARKS = { "entropy_loss": bench_entropy_loss, diff --git a/tools/benchmark/triton_kernels/runner.py b/tools/benchmark/triton_kernels/runner.py index 9a00c0cd0..0d8334fff 100644 --- a/tools/benchmark/triton_kernels/runner.py +++ b/tools/benchmark/triton_kernels/runner.py @@ -4,9 +4,9 @@ Each benchmark file defines a list of `Case` objects (input shape/dtype sweep) and a list of `Variant` objects (implementations to compare — e.g. pytorch eager, pytorch compiled, Triton). The runner invokes each variant -on each case, measures timing (median + mean + percentiles via CUDA events), -measures peak/final memory, and compares outputs against an fp32 reference -using RMS error. Results are printed as a table per case. +on each case, measures GPU kernel time (summed per-kernel device time via the +profiler), measures peak/final memory, and compares outputs against an fp32 +reference using RMS error. Results are printed as a table per case. """ import dataclasses @@ -152,24 +152,34 @@ class VariantResult: # --------------------------------------------------------------------------- timing +# Upper bound on profiled reps per measurement. Device kernel time is stable, so a +# small cap gives tight stats without profiling thousands of reps. 
+_MAX_PROFILE_REPS = 50 + + def bench_fn( fn: typing.Callable[[], typing.Any], reset: typing.Callable[[], None] | None = None, warmup_ms: float = 25.0, rep_ms: float = 100.0, min_reps: int = 5, - max_reps: int = 10_000, ) -> TimingStats: """Benchmark `fn` — it should be a no-arg callable that invokes the kernel being timed (close over inputs). Returns timing statistics in ms. - Mirrors `triton.testing.do_bench` logic but returns raw per-rep list so we - can compute {median, mean, min, max, std} from one set of runs. + Reports GPU kernel time: each rep sums the profiler's per-kernel device + self-time, and stats are computed over the per-rep samples. A CUDA-event + window around `fn` instead counts GPU-idle bubbles from per-call allocation + and (for autograd variants) Python/engine launch overhead — benchmarking + artifacts that don't occur in a real run where the CPU runs ahead and the + allocator serves cached buffers, and which vary by variant (the eager C++ + engine starves the GPU far less than a Python autograd Function). Summing + device self-time isolates the work the GPU actually does. """ if not torch.cuda.is_available(): # CPU / Triton interpret: single timed run with wall clock. min_reps, - # max_reps, warmup_ms, rep_ms are ignored — this path is for smoke - # testing kernel correctness, not measurement. + # warmup_ms, rep_ms are ignored — this path is for smoke testing kernel + # correctness, not measurement. 
if reset is not None: reset() fn() # warmup @@ -215,28 +225,34 @@ def bench_fn( torch.cuda.synchronize() one_rep_ms = max(post_start.elapsed_time(post_end), 0.001) - num_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms))) + num_reps = max(min_reps, min(_MAX_PROFILE_REPS, int(rep_ms / one_rep_ms))) - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] - for i in range(num_reps): + # Profile each rep separately and sum its kernels' device self-time, building a + # per-rep sample distribution. The L2 flush stays outside the profiled region and + # is synced before the profiler opens, so its fill kernel can't land in the capture + # window while each kernel still reads cold. CUDA-only profiling records no CPU op + # nodes (which would otherwise carry their children's device time and double-count); + # the device_type == CUDA filter drops the zero-device-time runtime/launch entries. + samples_ms: list[float] = [] + for _ in range(num_reps): if reset is not None: reset() - # The L2 flush is enqueued before start_events[i] on the same stream, so - # the timed window starts after the zero completes — only fn() is timed. 
flush_buffer.zero_() - start_events[i].record() - fn() - end_events[i].record() - torch.cuda.synchronize() + torch.cuda.synchronize() + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + fn() + torch.cuda.synchronize() + kernel_us = sum( + k.self_device_time_total for k in prof.key_averages() if k.device_type == torch.autograd.DeviceType.CUDA + ) + samples_ms.append(kernel_us / 1000) - times = [start_events[i].elapsed_time(end_events[i]) for i in range(num_reps)] return TimingStats( - median_ms=statistics.median(times), - mean_ms=statistics.fmean(times), - min_ms=min(times), - max_ms=max(times), - std_ms=statistics.pstdev(times) if len(times) > 1 else 0.0, + median_ms=statistics.median(samples_ms), + mean_ms=statistics.fmean(samples_ms), + min_ms=min(samples_ms), + max_ms=max(samples_ms), + std_ms=statistics.pstdev(samples_ms) if len(samples_ms) > 1 else 0.0, num_reps=num_reps, )