From 17b480aa6511e20a71c60b9541d2491263eab921 Mon Sep 17 00:00:00 2001
From: Marcin Bogdanski
Date: Sat, 7 Mar 2026 01:31:48 +0000
Subject: [PATCH] add fallback FA3 kernel for non-Hopper GPUs

---
 train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index f9f0065..99c81a4 100644
--- a/train.py
+++ b/train.py
@@ -18,7 +18,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kernels import get_kernel
 
-fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+cap = torch.cuda.get_device_capability()
+# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
+repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
+fa3 = get_kernel(repo).flash_attn_interface
 
 from constants import MAX_SEQ_LEN, TIME_BUDGET
 from prepare import Tokenizer, make_dataloader, evaluate_bpb
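
Note (not part of the patch): a minimal standalone sketch of the selection logic above, for checking which repo a given machine would load. It assumes a CUDA device is available; torch.cuda.get_device_capability() returns a (major, minor) tuple, and Hopper GPUs (H100/H200) report (9, 0).

    # Sketch only: prints the compute capability and the FA3 repo
    # the patched train.py would select on this machine.
    import torch

    cap = torch.cuda.get_device_capability()
    repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
    print(f"compute capability {cap} -> loading FA3 from {repo}")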