Crafting Efficient Kernels with Epilogue Fusion

In many ML workloads, a GEMM is followed by small operations like bias, activation, scaling, or type conversion. These ops are cheap in math, but they often cost extra global memory traffic (store GEMM result, read it back, write again).

Epilogue fusion is a way to avoid this: we apply these extra ops while the GEMM result is still in registers, right before the final store to global memory. On Hopper and Blackwell, there is also more room to overlap Tensor Core work with other instructions, so doing some extra compute in the epilogue can be even more attractive.


Epilogue fusion eliminates intermediate global memory reads and writes by applying the activation function directly within the GEMM epilogue.

CUTLASS takes advantage of this by allowing epilogue fusion, where operations such as bias addition, activation functions (ReLU, Sigmoid, etc.), and type conversion are applied directly to the accumulator fragments before writing the result to global memory. By fusing these operations, we avoid additional memory reads and writes, which are often the dominant cost in modern workloads. These kinds of operations are relatively easy to fuse because they are elementwise. They do not require reductions or communication across rows or columns, and can be applied independently to each output element.

In this blog post we will go over the basics, show some prebuilt ops, and then walk through writing a custom visitor for a GEMM + gated‑SiLU (aka SwiGLU). The goal is not to be exhaustive, just to show how the visitor flow works and how you "hook" your logic into the epilogue.

We are not building a fully optimized version, nor are we going to show custom TMA store paths or pipelining details (those get complicated fast). To keep things simple, we will use BF16 GEMM for these examples. For a production-level Blackwell implementation, you should use NVFP4 block-scaled GEMM, a more robust visitor, and potentially additional fusions.

One important detail for warp-specialized epilogues: the epilogue work is split into producer-load and consumer-store roles (warps/warp-groups). The EVT callbacks run on the consumer side, where the accumulator fragments are transformed and then ultimately stored.

CUTLASS GEMM + EVT basics

Let's start with a minimal GEMM definition in CUTLASS.

In CUTLASS, defining a GEMM kernel typically involves specifying three main components:

  • The mainloop, which defines how input tiles are loaded, the tile and cluster shapes, operand layouts, and how Tensor Core MMA operations are issued.

  • The epilogue, which describes how accumulator fragments are transformed and written back to global memory.

  • The kernel wrapper, which combines the mainloop and epilogue into a complete GEMM kernel.

Below is a simple example that uses an identity epilogue.

// EVT: acc -> cast
using EVT = Sm90EVT<
    Sm90Compute<cutlass::epilogue::thread::Identity, ElementOutput, ElementAcc, RoundStyle>,
    Sm90AccFetch>;

// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClass, TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    cutlass::epilogue::collective::EpilogueScheduleAuto, EVT>::CollectiveOp;

// Collective mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OpClass, ElementInput, cutlass::layout::RowMajor, AlignInput,
    ElementInput, cutlass::layout::ColumnMajor, AlignInput, ElementAcc, TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

Here is what each piece is doing:

  • Sm90AccFetch: leaf node that provides the raw accumulator fragment (acc) from the mainloop.
  • Sm90Compute<Identity, ElementOutput, ElementAcc, ...>: a trivial "compute" node. With Identity, this effectively just converts the accumulator element type (ElementAcc) into the output element type (ElementOutput) using the chosen rounding mode.
  • Sm90EVT<...>: composes those nodes into a tiny visitor tree: acc -> cast.

Then you plug that EVT into the epilogue "collective":

  • CollectiveEpilogue: the epilogue implementation for the chosen architecture/schedule. On SM90 warp-specialized kernels this typically means: each consumer warp/group iterates over epilogue subtiles, calls the EVT callbacks (visit, reduce, etc.), and then performs the actual store path (often register -> shared memory -> TMA store to global memory).

And finally into the mainloop "collective":

  • CollectiveMainloop: defines how A/B tiles are loaded (e.g. TMA vs LDG), how many pipeline stages are used, and how MMA is issued for your TileShape/ClusterShape.

The GemmUniversal wrapper ties the mainloop and epilogue together.
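
To make that concrete, here is roughly how the assembled Gemm type is driven at runtime. This is a rough sketch: the pointers, strides, and problem size are placeholders from your own setup, error handling is trimmed to a status check, and the empty {} in the epilogue arguments is the (trivial) thread-argument struct for the identity EVT above.

// Rough usage sketch: ptr_A/ptr_B/ptr_D, the strides, and M/N/K come from your setup.
Gemm gemm;

typename Gemm::Arguments args{
  cutlass::gemm::GemmUniversalMode::kGemm,
  {M, N, K, 1},                                        // problem shape (L = 1 batch)
  {ptr_A, stride_A, ptr_B, stride_B},                  // mainloop: A/B pointers and strides
  {{}, /*ptr_C=*/nullptr, stride_C, ptr_D, stride_D}   // epilogue: EVT thread args, C, D
};

// Workspace for the kernel (tile scheduler, etc.), if any.
size_t workspace_size = Gemm::get_workspace_size(args);
void* workspace = nullptr;
cudaMalloc(&workspace, workspace_size);

cutlass::Status status = gemm.can_implement(args);
// ... check status == cutlass::Status::kSuccess ...
status = gemm.initialize(args, workspace);
status = gemm.run();   // optionally pass a cudaStream_t

cudaFree(workspace);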

Building an EVT chain (scale + bias + ReLU).

We can build a slightly more complex epilogue that scales the GEMM accumulator, adds a per-row bias, and applies ReLU at the end:

// EVT node ops 
using NodeMultiply = Sm90Compute<cutlass::multiplies, ElementAcc, ElementAcc, RoundStyle>;
using NodeAdd = Sm90Compute<cutlass::plus, ElementAcc, ElementAcc, RoundStyle>;

// EVT: (global_scale * acc) + per-row bias -> ReLU -> cast to BF16
using EVT0 = Sm90EVT<NodeMultiply, Sm90ScalarBroadcast<ElementScale>, Sm90AccFetch>;
using EVT1 = Sm90EVT<NodeAdd, Sm90ColBroadcast<0, TileShape, ElementBias, ElementBias, Stride<_1, _0, _0>>, EVT0>;
using EVT2 = Sm90EVT<
    Sm90Compute<cutlass::epilogue::thread::ReLU, ElementOutput, ElementAcc, RoundStyle>,
    EVT1>;

// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClass, TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    cutlass::epilogue::collective::EpilogueScheduleAuto, EVT2>::CollectiveOp;
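
Before visualizing the tree, it helps to spell out what it computes per element. The function below is just a plain scalar reference of the math (not how the epilogue actually executes): scale comes from the Sm90ScalarBroadcast node and bias_m from the per-row Sm90ColBroadcast.

// Scalar reference of EVT2 for output element (m, n); the real epilogue does this
// on accumulator fragments in registers and casts to BF16 on the way out.
float reference_epilogue(float acc_mn, float scale, float bias_m) {
  float x = scale * acc_mn + bias_m;   // EVT0 (scale * acc), then EVT1 (+ per-row bias)
  return x > 0.0f ? x : 0.0f;          // EVT2 (ReLU), converted to ElementOutput on store
}

At runtime, the scalar and the bias pointer are supplied through args.epilogue.thread as a nested struct that mirrors the tree nesting.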

We can visualize this epilogue like this:

[Figure: EVT tree for the scale + bias + ReLU epilogue]

If you want more details, this blog is really nice: https://research.colfax-intl.com/epilogue_visitor_tree

A few prebuilt epilogue ops.

CUTLASS already has some predefined epilogue styles, for example:

  • linear combinations: LinearCombination (D = alpha * AB + beta * C), ScaledAcc (D = alpha * acc)
  • activations: LinCombEltAct with ReLU / GELU / SiLU / Sigmoid / Tanh / HardSwish / LeakyReLU / Clamp
  • bias: LinCombPerRowBias, LinCombPerColBias (can be with activations)
  • broadcasts: Sm90ScalarBroadcast, Sm90RowBroadcast, Sm90ColBroadcast
  • reductions: Sm90ScalarReduction, Sm90RowReduction, Sm90ColReduction
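
If one of these prebuilt patterns covers your case, you don't need to hand-assemble the EVT: the fusion operation type can be handed to the CollectiveBuilder directly. A rough sketch, reusing the same ArchTag, shapes, and element types as before:

// D = ReLU(alpha * acc + beta * C), using the prebuilt linear-combination + activation fusion.
using FusionOp = cutlass::epilogue::fusion::LinCombEltAct<
    cutlass::epilogue::thread::ReLU, ElementOutput, ElementAcc>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClass, TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    cutlass::epilogue::collective::EpilogueScheduleAuto, FusionOp>::CollectiveOp;

The alpha/beta scalars (and any activation parameters) then show up as fields on args.epilogue.thread.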

Fusing gated‑SiLU

The gated‑SiLU pattern (used in Flux, Flux2, and LLMs like LLaMA) looks like this:

C = A @ B
M, N = C.shape
output = SiLU(C[:, :N//2]) * C[:, N//2:]     

This halves the output's N dimension (we take the two halves and multiply them elementwise). CUTLASS doesn't have a built-in epilogue for this exact "pair + reduce N/2" pattern, so we need a custom epilogue visitor if we want to fuse it in a single pass.

Custom epilogue visitors in CUTLASS let you intercept the per‑thread fragment, transform it, and control how and where it's stored. You define a visitor struct with Arguments/Params, then implement get_consumer_store_callbacks() and a callback like visit() or postreduce()/end_loop() to apply the gated‑SiLU on pairs and write the reduced output; a bare-bones skeleton of this structure is sketched below. The visitor is then inserted into the EVT tree (e.g., Sm90EVT<CustomVisitor, Sm90AccFetch>), so the GEMM produces the fused output in one pass without an extra kernel.
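
As an outline (GatedSiLUVisitor is our name, and most of the plumbing is elided), such a visitor looks roughly like this; the interesting pieces are filled in over the rest of the post.

// Bare-bones shape of the custom visitor; the real hooks are shown in full below.
template <class ElementOut, class StrideOut, bool EnableNullptr = true>
struct GatedSiLUVisitor {
  struct SharedStorage {};                 // this simple version needs no extra smem

  struct Arguments {                       // host-facing
    ElementOut* ptr_out = nullptr;
    StrideOut   dOut = {};
  };

  struct Params {                          // what the device code reads
    ElementOut* ptr_out = nullptr;
    StrideOut   dOut = {};
  };

  Params const* params_ptr;                // handed to the callbacks below

  // Per-thread callback object: holds tC_rOut / tC_cOut_full and implements
  //   visit()    -> stash each incoming fragment into registers
  //   end_loop() -> silu(gate) * up on adjacent pairs, store the reduced output
  template <class... Ts>
  struct ConsumerStoreCallbacks;

  template <bool ReferenceSrc, class... Args>
  CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args);
};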

Before implementing the custom epilogue visitor, we face a key challenge. The gated‑SiLU operation requires multiplying silu(gate) with its corresponding up element, but in the naive output layout those values live in different halves of the N dimension: for output column n in [0, N_out), the pair is gate = C[:, n] and up = C[:, n + N_out]. This means the paired values are separated by N_out columns and can easily land in different fragments / subtiles (or even different CTAs), which would require cross-tile communication.

[Figure: in the naive layout, the gate and up elements of a pair sit N_out columns apart]

Trying to synchronize across tiles would significantly hurt performance and break the GEMM epilogue’s parallel execution model. Instead, we permute the columns of B (the model weight matrix) once during weight packing so the GEMM output C = A @ B is laid out differently: instead of producing [gate(0..N_out-1), up(0..N_out-1)] as two separated halves, it produces interleaved pairs [gate0, up0, gate1, up1, ...]. Concretely, the pair for output column n ends up adjacent in C at gate -> 2n and up -> 2n + 1 (a packing sketch follows below). With adjacency guaranteed, the visitor can fuse the gated‑SiLU multiply and write the reduced N_out result in the epilogue without cross-tile coordination.
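
The permutation itself is a one-time, host-side shuffle during weight packing. A minimal sketch, assuming the weights are held as a K x N_full row-major buffer of BF16 bits whose first N_out columns are the gate projection and last N_out columns are the up projection:

#include <cstdint>
#include <vector>

// Interleave B's columns so (gate, up) pairs become adjacent in the GEMM output:
// gate column n moves to 2n, up column n moves to 2n + 1.
std::vector<uint16_t> interleave_gate_up_columns(std::vector<uint16_t> const& B,
                                                 int K, int N_full) {
  int N_out = N_full / 2;
  std::vector<uint16_t> B_packed(B.size());
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N_out; ++n) {
      B_packed[k * N_full + 2 * n]     = B[k * N_full + n];          // gate half
      B_packed[k * N_full + 2 * n + 1] = B[k * N_full + N_out + n];  // up half
    }
  }
  return B_packed;
}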

[Figure: with B's columns interleaved, each (gate, up) pair is adjacent in C]

After this layout tweak, we can easily write a visitor that applies gated‑SiLU using adjacent elements in the same thread. CUTLASS provides several hooks you can override to intercept and transform data at different stages of the epilogue:

  • begin() -> called once before the store loop starts
  • begin_loop(epi_m, epi_n) -> called at the start of each subtile
  • previsit(...) -> called before visit, used for shared memory broadcasts
  • visit(...) -> called per-fragment, where you receive computed values
  • reduce(...) -> reduction step across fragments
  • postreduce(...) -> called after reduction, before memory fence
  • tma_store(...) -> issue TMA stores for auxiliary tensors
  • end_loop(epi_m, epi_n) -> called at the end of each subtile
  • end() -> called once after the store loop completes

For our basic gated‑SiLU visitor, we only need two hooks: visit() and end_loop(). In a "fully integrated" TMA epilogue you would typically flow the final result through reduce()/postreduce() and let the collective handle the register -> smem -> TMA store path. To keep this post simple, we will directly write the reduced output to global memory in end_loop().

Understanding the Data Flow

The epilogue processes data in a hierarchical manner. At the top level, we have tiles defined by TileShape. Each tile is further divided into subtiles (indexed by epi_m, epi_n), and each subtile is processed as multiple fragments (indexed by epi_v).

Tile (M × N)
└── Subtile (epi_m, epi_n)
    └── Fragment 0 (epi_v = 0)
    └── Fragment 1 (epi_v = 1)
    └── ...

The key insight is that visit() is called multiple times per subtile (once for each fragment), while end_loop() is called only once, after all fragments in a subtile have been visited.
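
Here is that call order as a toy, host-side sketch (stand-in loops and prints only; the real CUTLASS loop also interleaves copies, reductions, and synchronization between these hooks):

#include <cstdio>

int main() {
  const int SUBTILES_M = 2, SUBTILES_N = 2, FRAGMENTS = 4;   // made-up counts

  std::printf("begin()\n");
  for (int epi_m = 0; epi_m < SUBTILES_M; ++epi_m) {
    for (int epi_n = 0; epi_n < SUBTILES_N; ++epi_n) {
      std::printf("  begin_loop(%d, %d)\n", epi_m, epi_n);
      for (int epi_v = 0; epi_v < FRAGMENTS; ++epi_v) {
        std::printf("    visit(epi_v=%d)\n", epi_v);          // once per fragment
      }
      std::printf("  end_loop(%d, %d)\n", epi_m, epi_n);      // once per subtile
    }
  }
  std::printf("end()\n");
  return 0;
}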

The visit() hook: accumulating fragments.

Let's look at our visit() implementation:

template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
CUTLASS_DEVICE auto visit(Array<ElementAccumulator, FragmentSize> const&, int epi_v, int, int,
                          Array<ElementInput, FragmentSize> const& frg_input) {
  Tensor tC_rOut_frg = recast<Array<ElementInput, FragmentSize>>(coalesce(tC_rOut));
  tC_rOut_frg(epi_v) = frg_input;
  return frg_input;
}

The function signature tells us what data we receive:

  • Array<ElementAccumulator, FragmentSize> const& -> the raw accumulator values (we ignore this)
  • epi_v -> the fragment index within the current subtile
  • Array<ElementInput, FragmentSize> const& frg_input -> the input from previous EVT nodes (already converted)

Our visitor sits at the end of an EVT chain. By the time data reaches us, it has already been fetched from the accumulator and converted to our working precision. The frg_input contains FragmentSize elements that this thread is responsible for.

The critical operation here is storing fragments into tC_rOut:

tC_rOut_frg(epi_v) = frg_input;

We use recast to view our register tensor as an array of fragments, then index by epi_v to store each fragment in its correct position. This accumulates all fragments across multiple visit() calls, building up the complete subtile data in registers.

The end_loop() hook: computing gated‑SiLU.

Once all fragments for a subtile have been collected, end_loop() is called:

CUTLASS_DEVICE void end_loop(int epi_m, int epi_n) {
  if constexpr (EnableNullptr) {
    if (params_ptr->ptr_out == nullptr) {
      return;
    }
  }

  auto [M, N_full, L] = problem_shape_mnl_full;
  int N_out = N_full / 2;

  // Create output tensor view with halved N dimension
  Tensor gOut = make_tensor(
    make_gmem_ptr(params_ptr->ptr_out),
    make_shape(M, N_out, L),
    params_ptr->dOut
  );

  // Flatten tensors for easy iteration
  Tensor tC_rOut_flat = coalesce(tC_rOut);
  Tensor tC_cOut_full_flat = coalesce(tC_cOut_full(_,_,_,epi_m,epi_n));

  using ConvertOutput = NumericConverter<ElementOut, float, RoundStyle>;
  ConvertOutput convert_output{};

  int total_elements = size(tC_cOut_full_flat);

  // Process pairs: (gate, up) -> silu(gate) * up
  CUTLASS_PRAGMA_UNROLL
  for (int flat_idx = 0; flat_idx < total_elements; flat_idx += 2) {
    float gate = tC_rOut_flat(flat_idx);
    float up = tC_rOut_flat(flat_idx + 1);
    
    // SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
    float silu = gate / (1.0f + expf(-gate));
    float out = silu * up;

    // Map back to output coordinates (column index halved)
    auto [m, n, l] = tC_cOut_full_flat(flat_idx);
    n >>= 1;  // Divide column by 2
    gOut(m, n, l) = convert_output(out);
  }
}

Let's break down what happens:

  1. Output tensor setup: We create gOut with shape (M, N_out, L) where N_out = N_full / 2. The output has half the columns because we're fusing pairs.

  2. Coordinate tracking: tC_cOut_full stores the original (m, n, l) coordinates for each element. We slice it by (epi_m, epi_n) to get coordinates for the current subtile.

  3. Pair processing: We iterate through elements two at a time. Due to our column reordering of matrix B, adjacent elements in the flattened tensor are guaranteed to be (gate, up) pairs.

  4. SiLU computation: For each pair, we compute silu(gate) * up. This code shows the simple expf form for clarity, but CUTLASS’s Sigmoid/SiLU implementations can use a fast tanh-based path (sigmoid(x) ≈ 0.5 * fast_tanh(0.5 * x) + 0.5) depending on build/config.

  5. Coordinate mapping: The original column index n corresponds to the wide matrix. We right-shift by 1 (n >>= 1) to get the output column index, since two input columns map to one output column (e.g. wide columns 6 and 7 both map to output column 3).

Setting up the callbacks.

The get_consumer_store_callbacks() function initializes our callback object with the necessary tensors:

template <bool ReferenceSrc, class... Args>
CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
  auto [M, N, K, L] = args.problem_shape_mnkl;
  auto problem_shape_mnl_full = make_shape(M, N, L);  // "wide" output shape (before N/2 reduction)

  // Allocate a per-thread register tile to accumulate fragments for the current subtile.
  // The exact shape comes from CUTLASS's epilogue partitioning (omitted here).
  Tensor tC_rOut = make_tensor<float>(/* same (CPY,CPY_M,CPY_N) shape as the thread's output tile */);

  // Also build a matching coordinate tensor for the wide output so we can map each register element
  // back to (m, n, l) and then apply n >>= 1 when storing the reduced output.
  Tensor coordOut_full = make_identity_tensor(make_shape(M, N, L));
  Tensor tC_cOut_full = sm90_partition_for_epilogue<ReferenceSrc>(coordOut_full, /* ... */);

  return ConsumerStoreCallbacks</* ... */>(
    cute::move(tC_rOut),
    cute::move(tC_cOut_full),
    problem_shape_mnl_full,
    params_ptr
  );
}

Key points:

  • tC_rOut is a register tensor that accumulates fragments during visit() calls
  • tC_cOut_full maps flat indices to (m, n, l) coordinates in the original wide matrix
  • sm90_partition_for_epilogue handles the complex tiling and thread mapping for us

Putting it all together.

The complete data flow looks like this:

┌─────────────────────────────────────────────────────────────┐
│  GEMM Mainloop: Compute A @ B_reordered                     │
│  Output shape: (M, N_full) where N_full = 2 * N_out         │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│  visit() called multiple times per subtile                  │
│  Each call: store frg_input into tC_rOut[epi_v]             │
│                                                             │
│  tC_rOut: [frag0][frag1][frag2]...                          │
│            ↑      ↑      ↑                                  │
│          epi_v=0  =1     =2                                 │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│  end_loop() called once per subtile                         │
│                                                             │
│  Process pairs from this thread's tC_rOut:                  │
│  [gate₀, up₀, gate₁, up₁, ...]                              │
│     ↓      ↓                                                │
│  silu(gate₀) * up₀  →  output[m, n>>1, l]                   │
│                                                             │
│  Output shape: (M, N_out)                                   │
└─────────────────────────────────────────────────────────────┘
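
Wiring-wise, the custom node drops into the same EVT and CollectiveBuilder structure we used at the start. GatedSiLUVisitor below is the hypothetical visitor sketched earlier, StrideOut stands for whatever stride type the reduced output uses, and the builder arguments are unchanged from the identity example.

// The accumulator is fetched and fed straight into our custom node.
using EVT = Sm90EVT<GatedSiLUVisitor<ElementOutput, StrideOut>, Sm90AccFetch>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OpClass, TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    ElementOutput, cutlass::layout::RowMajor, AlignOutput,
    cutlass::epilogue::collective::EpilogueScheduleAuto, EVT>::CollectiveOp;

At runtime, ptr_out and dOut then travel through args.epilogue.thread like the arguments of any other EVT node.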

What a fully integrated implementation looks like

For a real production implementation, you will typically:

  • Integrate with the TMA epilogue pipeline: instead of directly storing to global memory in end_loop(), compute/store through reduce()/postreduce() and issue TMA stores via tma_store() so you keep the standard register→smem→TMA path.
  • Use an FP8/NVFP4 mainloop on Blackwell (rather than BF16) to take advantage of the higher Tensor Core throughput.
  • Fuse “what comes next” (e.g. output quantization and/or auxiliary tensors) so you minimize memory traffic. gated‑SiLU already halves the logical N dimension; if you also quantize the output (e.g. to NVFP4, 4 bits/elem), the write footprint drops sharply: the unfused wide BF16 intermediate is 2 * N_out columns at 2 bytes/elem (4 * N_out bytes per row), while the fused NVFP4 result is N_out columns at 0.5 bytes/elem (0.5 * N_out bytes per row), roughly 8× fewer output bytes, and you also avoid the extra read+write round trip of the intermediate.

Some production numbers

These are numbers from a production model (time per invocation, us):

That’s 166 us saved (about 1.28× faster) by folding the post-GEMM work into the epilogue.

The cool part of epilogue fusion is that you are not approximating anything. You are just doing the same math earlier (before a round-trip to global memory), so you don't need to sacrifice quality to get the speedup.