Merge pull request #2 from marcinbogdanski/fix/fa3-non-hopper-fallback
add fallback FA3 kernel for non-Hopper GPUs
commit bb54287479

train.py: 5 changes (4 additions, 1 deletion)
@@ -18,7 +18,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from kernels import get_kernel
-fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
+cap = torch.cuda.get_device_capability()
+# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
+repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
+fa3 = get_kernel(repo).flash_attn_interface
 
 from constants import MAX_SEQ_LEN, TIME_BUDGET
 from prepare import Tokenizer, make_dataloader, evaluate_bpb
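For reference, a minimal standalone sketch of the same fallback, useful for testing kernel selection outside train.py. It assumes the Hugging Face kernels package is installed and that PyTorch was built with CUDA; the load_fa3 wrapper and the CUDA-availability guard are illustrative additions, not part of this commit.

import torch
from kernels import get_kernel

def load_fa3():
    # Illustrative helper, not in the commit. Hopper GPUs (H100/H200) report
    # compute capability (9, 0); varunneal's FA3 build targets Hopper only,
    # so any other GPU falls back to the kernels-community build.
    if not torch.cuda.is_available():
        raise RuntimeError("FA3 kernels require a CUDA device")
    cap = torch.cuda.get_device_capability()
    repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
    return get_kernel(repo).flash_attn_interface

fa3 = load_fa3()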