Shortcuts

Source code for rl4co.models.nn.flash_attention

# Attention modules with Flash Attention compatibility
# Slight extension to cover our cases such as pre-computed linears in Attention
# Reference: https://github.com/HazyResearch/flash-attention/blob/57ee618170e1adecbf787365cdf330c63768abd2/flash_attn/modules/mha.py

import math

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

try:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func,
        flash_attn_unpadded_qkvpacked_func,
    )
except ImportError:
    flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None

try:
    from flash_attn.ops.flash_attn_triton import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
    )
except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import (
        ColumnParallelLinear,
        FusedDense,
        RowParallelLinear,
    )
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

try:
    import ft_attention
except ImportError:
    ft_attention = None


def flash_attn_wrapper(self, func, *args, **kwargs):
    """Call ``func``, casting its tensor inputs to a half dtype when forced.

    FlashAttention kernels only accept fp16/bf16 CUDA tensors. This function is
    meant to be attached to a module as a method (``flash_attn_wrapper =
    flash_attn_wrapper`` in the class body), so ``self`` is the owning module.

    When casting is enabled and the first positional argument is a CUDA tensor:

    * every floating-point tensor positional argument is cast to the target dtype;
    * the target dtype is the dtype of the first argument if it is already fp16/bf16,
      otherwise fp16;
    * non-tensor and non-floating positional arguments are passed through unchanged,
      in their original positions;
    * the output of ``func`` is cast back to the first argument's original dtype.

    Args:
        self: owning module. The casting flag is read from ``self.force_flash_attn``
            or, when that is absent/None, from ``self.force_dtype_flash_attn``.
            If both are missing, no casting happens.
        func: callable to invoke, e.g. a ``FlashSelfAttention`` / ``FlashCrossAttention``
            module.
        *args: positional arguments forwarded to ``func``. Tensors may be cast.
        **kwargs: keyword arguments forwarded unchanged (e.g. ``causal``, ``cu_seqlens``).

    Returns:
        The result of ``func``, cast back to the original dtype when casting was applied.
    """
    force = getattr(self, "force_flash_attn", None)
    if force is None:
        # MHA stores the flag under this name.
        force = getattr(self, "force_dtype_flash_attn", False)

    first = args[0] if args else None
    if not (force and isinstance(first, torch.Tensor) and first.is_cuda):
        return func(*args, **kwargs)

    original_dtype = first.dtype
    half_dtypes = (torch.float16, torch.bfloat16)
    # Keep bf16 as bf16: casting bf16 -> fp16 can overflow, because fp16 has a
    # smaller exponent range.
    target_dtype = original_dtype if original_dtype in half_dtypes else torch.float16
    cast_args = [
        arg.to(target_dtype)
        if isinstance(arg, torch.Tensor)
        and arg.is_floating_point()
        and arg.dtype != target_dtype
        else arg
        for arg in args
    ]
    out = func(*cast_args, **kwargs)
    return out.to(original_dtype)
class FlashSelfAttention(nn.Module):
    """Scaled dot-product self-attention computed with FlashAttention kernels.

    Arguments
    ---------
        causal: default causal flag; may be overridden per ``forward`` call.
        softmax_scale: the temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: the dropout rate to apply to the attention (default: 0.0)
        triton: prefer the Triton kernel on the padded path whenever no dropout
            is active. The Triton kernel does not implement dropout.
    """

    def __init__(
        self, causal=False, softmax_scale=None, attention_dropout=0.0, triton=False
    ):
        super().__init__()
        # The two kernel requirements below are mutually exclusive: either the
        # Triton kernel is used (triton and no dropout) or the CUDA unpadded one.
        needs_triton_kernel = triton and attention_dropout == 0.0
        if needs_triton_kernel:
            assert (
                flash_attn_qkvpacked_func is not None
            ), "FlashAttention Triton is not installed"
        else:
            assert (
                flash_attn_unpadded_qkvpacked_func is not None
            ), "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run multi-head softmax attention on a packed query/key/value tensor.

        Arguments
        ---------
            qkv: packed query, key and value. Shape (B, S, 3, H, D) when
                ``cu_seqlens`` is None, otherwise (total, 3, H, D), where total
                is the sum of the sequence lengths in the batch.
            causal: overrides ``self.causal`` when not None.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. Cumulative sequence
                lengths used to index into ``qkv``.
            max_seqlen: int. Maximum sequence length in the batch.

        Returns
        -------
            out: (total, H, D) when ``cu_seqlens`` is given, else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            # Input is already unpadded: hand it straight to the varlen kernel.
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_unpadded_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )

        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        # Triton version doesn't support dropout
        if self.triton and (self.dropout_p == 0 or not self.training):
            return flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)

        # Padded batch: flatten to (B*S, ...) and describe it with equal-length
        # cumulative offsets so the varlen kernel can process it.
        flat_qkv = rearrange(qkv, "b s ... -> (b s) ...")
        offsets = torch.arange(
            0,
            (batch_size + 1) * seqlen,
            step=seqlen,
            dtype=torch.int32,
            device=flat_qkv.device,
        )
        flat_out = flash_attn_unpadded_qkvpacked_func(
            flat_qkv,
            offsets,
            seqlen,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )
        return rearrange(flat_out, "(b s) ... -> b s ...", b=batch_size)
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product cross-attention with softmax via FlashAttention.

    Arguments
    ---------
        causal: default causal flag; may be overridden per ``forward`` call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        triton: use the Triton kernel on the padded path when no dropout is active
            (the Triton kernel does not implement dropout).
    """

    def __init__(
        self, causal=False, softmax_scale=None, attention_dropout=0.0, triton=False
    ):
        super().__init__()
        if attention_dropout != 0.0 or not triton:
            assert (
                flash_attn_unpadded_kvpacked_func is not None
            ), "FlashAttention is not installed"
        if attention_dropout == 0.0 and triton:
            assert (
                flash_attn_kvpacked_func is not None
            ), "FlashAttention Triton is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.triton = triton

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax cross-attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D), or (total_q, H, D)
                when ``cu_seqlens`` is given.
            kv: The tensor containing the key and value. (B, Sk, 2, H, D), or
                (total_k, 2, H, D) when ``cu_seqlens`` is given.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.

        Returns
        -------
            out: (total_q, H, D) when ``cu_seqlens`` is given, else (B, Sq, H, D).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Previously this re-checked max_seqlen, leaving max_seqlen_k unvalidated.
            assert isinstance(max_seqlen_k, int)
            return flash_attn_unpadded_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                dropout_p,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]
        assert (
            kv.shape[0] == batch_size
            and kv.shape[3] == q.shape[2]
            and kv.shape[4] == q.shape[3]
        )
        # Triton version doesn't support dropout
        if self.triton and (self.dropout_p == 0.0 or not self.training):
            return flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)

        # Padded batch: flatten and describe with equal-length cumulative offsets.
        q = rearrange(q, "b s ... -> (b s) ...")
        kv = rearrange(kv, "b s ... -> (b s) ...")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=kv.device,
        )
        output = flash_attn_unpadded_kvpacked_func(
            q,
            kv,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale,
            causal=causal,
        )
        return rearrange(output, "(b s) ... -> b s ...", b=batch_size)
class SelfAttention(nn.Module):
    """Reference (non-flash) scaled dot-product self-attention with softmax.

    Arguments
    ---------
        causal: default causal flag; may be overridden per ``forward`` call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        _inf: large negative value added to masked scores (default: -10000.0)
    """

    def __init__(
        self, causal=False, softmax_scale=None, attention_dropout=0.0, _inf=-10000.0
    ):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self._inf = _inf

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Compute multi-head softmax self-attention.

        Arguments
        ---------
            qkv: packed query, key and value. (B, S, 3, H, D)
            causal: overrides ``self.causal`` when not None.
            key_padding_mask: boolean mask over keys. True means keep, False
                means mask out. (B, S)

        Returns
        -------
            out: (B, S, H, D)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        use_causal = self.causal if causal is None else causal
        query, key, value = qkv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", query, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, self._inf for masked ones.
            bias = torch.full(
                (batch_size, seqlen),
                self._inf,
                dtype=scores.dtype,
                device=scores.device,
            )
            bias.masked_fill_(key_padding_mask, 0.0)
            # Broadcast over heads and query positions: (B, S) -> (B, 1, 1, S).
            scores = scores + bias[:, None, None, :]

        if use_causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16', so the
            # mask is built in the default float dtype and cast afterwards.
            future = torch.triu(
                torch.full((seqlen, seqlen), self._inf, device=scores.device), 1
            )
            scores = scores + future.to(dtype=scores.dtype)

        weights = torch.softmax(scores, dim=-1, dtype=value.dtype)
        weights = F.dropout(weights, self.dropout_p if self.training else 0.0)
        return torch.einsum("bhts,bshd->bthd", weights, value)
class CrossAttention(nn.Module):
    """Reference (non-flash) scaled dot-product cross-attention with softmax.

    Arguments
    ---------
        causal: default causal flag; may be overridden per ``forward`` call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        neg_inf: large negative value added to masked scores (default: -10000.0)
    """

    def __init__(
        self, causal=False, softmax_scale=None, attention_dropout=0.0, neg_inf=-10000.0
    ):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.neg_inf = neg_inf

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Compute multi-head softmax cross-attention.

        Arguments
        ---------
            q: query tensor. (B, Sq, H, D)
            kv: packed key and value tensor. (B, Sk, 2, H, D)
            causal: overrides ``self.causal`` when not None.
            key_padding_mask: boolean mask over keys. True means keep, False
                means mask out. (B, Sk)

        Returns
        -------
            out: (B, Sq, H, D)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        use_causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        shapes_match = (
            kv.shape[0] == batch_size
            and kv.shape[3] == q.shape[2]
            and kv.shape[4] == q.shape[3]
        )
        assert shapes_match
        key, value = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, key * scale)

        if key_padding_mask is not None:
            # Additive bias: 0 for kept keys, self.neg_inf for masked ones.
            bias = torch.full(
                (batch_size, seqlen_k),
                self.neg_inf,
                dtype=scores.dtype,
                device=scores.device,
            )
            bias.masked_fill_(key_padding_mask, 0.0)
            # Broadcast over heads and query positions: (B, Sk) -> (B, 1, 1, Sk).
            scores = scores + bias[:, None, None, :]

        if use_causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16', so the
            # mask is built in the default float dtype and cast afterwards.
            future = torch.triu(
                torch.full((seqlen_q, seqlen_k), self.neg_inf, device=scores.device), 1
            )
            scores = scores + future.to(dtype=scores.dtype)

        weights = torch.softmax(scores, dim=-1, dtype=value.dtype)
        weights = F.dropout(weights, self.dropout_p if self.training else 0.0)
        return torch.einsum("bhts,bshd->bthd", weights, value)
class LinearResidual(nn.Linear):
    """``nn.Linear`` that also returns its input, for compatibility with FusedDense.

    ``forward`` returns ``(linear(input), input)``, matching the output of
    ``FusedDense(..., return_residual=True)`` so callers can reuse the input as
    the residual branch.
    """

    def forward(
        self, input: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        # The annotation previously claimed a single Tensor, but a pair is returned.
        projected = super().forward(input)
        return projected, input
def _update_kv_cache(kv, inference_params, layer_idx):
    """Write new key/value states into the per-layer inference cache.

    kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

    Args:
        kv: new key/value states for this step.
        inference_params: generation state object. This function reads
            ``key_value_memory_dict``, ``max_batch_size``, ``max_sequence_len``,
            ``fused_ft_kernel``, ``batch_size_offset`` and ``sequence_len_offset``.
            It stores the cache for ``layer_idx`` in ``key_value_memory_dict``.
        layer_idx: key under which this layer's cache is stored.

    Returns:
        Without the fused FT kernel: the cache slice
        ``[batch_start:batch_end, :sequence_end]``, i.e. all keys/values cached
        so far, including the new ones.
        With the fused FT kernel: the input ``kv`` unchanged. The cache is
        updated in place or replaced by the ``(k_cache, v_cache)`` tuple.
    """
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        # First call for this layer: allocate an uninitialised
        # (max_batch, max_seq, 2, heads, head_dim) buffer.
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        if not inference_params.fused_ft_kernel:
            kv_cache = inference_params.key_value_memory_dict[layer_idx]
        else:
            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
            # where packsize = 4 if fp32, 8 if fp16 or bf16.
            # v_cache has shape (b, h, s, headdim)
            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
            # kv_cache is None signals that the split FT layout is already stored.
            kv_cache = None
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.sequence_len_offset
    sequence_end = sequence_start + kv.shape[1]
    # Bounds checks: batch is dim 0 in both layouts; sequence is dim 1 of
    # kv_cache but dim 2 of the FT v_cache.
    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
    assert sequence_end <= (
        kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]
    )
    # Copy key and values.
    if not inference_params.fused_ft_kernel:
        assert kv_cache is not None
        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
        return kv
    else:
        assert inference_params.sequence_len_offset == 0
        # FT kernel requires different layouts for the k_cache and v_cache.
        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
        packsize = 4 if kv.dtype == torch.float32 else 8
        if kv_cache is not None:
            # Freshly allocated combined cache: fill it, then convert it into the
            # split FT layout and replace the stored entry with (k_cache, v_cache).
            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
            k_cache = rearrange(
                kv_cache[:, :, 0],
                "b s h (d packsize) -> b h d s packsize",
                packsize=packsize,
            ).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
        else:
            # Split FT layout already exists: write the new states into it in place.
            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            )
            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
                kv[:, :, 1], "b s h d -> b h s d"
            )
        return kv
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        cross_attn=False,
        bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_scale_base=0,
        fused_bias_fc=False,
        use_flash_attn=False,
        force_dtype_flash_attn=True,
        return_residual=False,
        device=None,
        dtype=None,
    ) -> None:
        """Build the projections and inner attention modules.

        embed_dim: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        cross_attn: if True, use separate query (``Wq``) and key/value (``Wkv``)
            projections instead of a packed ``Wqkv``.
        dwconv: apply a depthwise Conv1d (kernel 3, causal trimming) after the
            input projections.
        use_flash_attn: use FlashSelfAttention/FlashCrossAttention instead of the
            reference SelfAttention/CrossAttention.
        force_dtype_flash_attn: when using FlashAttention, cast inputs to a half
            dtype before the kernel and cast the output back (see
            ``flash_attn_wrapper``).
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        device, dtype: factory kwargs forwarded to the created layers.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.force_dtype_flash_attn = force_dtype_flash_attn and use_flash_attn
        # flash_attn_wrapper reads ``force_flash_attn``; without this alias every
        # wrapped call raised AttributeError.
        self.force_flash_attn = self.force_dtype_flash_attn
        self.return_residual = return_residual
        # self.checkpointing = checkpointing # we don't use it for now

        self.num_heads = num_heads
        assert (
            self.embed_dim % num_heads == 0
        ), "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert (
                not cross_attn
            ), "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device
            )

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual
            if not fused_bias_fc
            else partial(FusedDense, return_residual=True)
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        # The projection that sees the residual input must return it as well.
        input_proj_cls = linear_resid_cls if self.return_residual else linear_cls

        if not self.cross_attn:
            self.Wqkv = input_proj_cls(
                embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
            )
            if self.dwconv:
                self.dwconv_qkv = nn.Conv1d(
                    3 * embed_dim,
                    3 * embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=3 * embed_dim,
                    **factory_kwargs,
                )
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            self.Wkv = input_proj_cls(
                embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs
            )
            if self.dwconv:
                self.dwconv_q = nn.Conv1d(
                    embed_dim,
                    embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=embed_dim,
                    **factory_kwargs,
                )
                self.dwconv_kv = nn.Conv1d(
                    2 * embed_dim,
                    2 * embed_dim,
                    kernel_size=3,
                    padding=2,
                    groups=2 * embed_dim,
                    **factory_kwargs,
                )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        # output projection always have the bias (for now)
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert (
            self.layer_idx is not None
        ), "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    @staticmethod
    def _depthwise_conv(conv, t):
        """Apply a depthwise Conv1d over the sequence of a (b, s, d) tensor.

        padding=2 with kernel 3 adds two extra trailing positions; they are
        trimmed so each output position only depends on the current and
        previous inputs.
        """
        return rearrange(
            conv(rearrange(t, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
        ).contiguous()

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470

        Returns:
            out: same leading dimensions as the query input, last dim embed_dim;
            ``(out, x)`` when ``return_residual`` is True.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # Flash kernels take varlen metadata; reference kernels take a padding mask.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        if not self.cross_attn:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = self._depthwise_conv(self.dwconv_qkv, qkv)
            qkv = rearrange(
                qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
            )
            if inference_params is None:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv)
                # Route through the wrapper like every other path so that
                # force_dtype_flash_attn also applies to plain self-attention.
                context = self.flash_attn_wrapper(self.inner_attn, qkv, **kwargs)
            elif (
                not inference_params.fused_ft_kernel
            ) or inference_params.sequence_len_offset == 0:
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=inference_params.sequence_len_offset
                    )
                q = qkv[:, :, 0]
                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                # If we're processing the prompt, causal=None (use self.causal).
                # If we're decoding, then causal=False.
                causal = None if inference_params.sequence_len_offset == 0 else False
                context = self.flash_attn_wrapper(
                    self.inner_cross_attn, q, kv, causal=causal
                )  # NOTE: modified
            else:
                assert inference_params.fused_ft_kernel
                assert ft_attention is not None
                context = ft_attention.single_query_attention(
                    *rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1),
                    *inference_params.key_value_memory_dict[self.layer_idx],
                    inference_params.lengths_per_sample,
                    inference_params.sequence_len_offset,
                    self.rotary_emb_dim,
                )
                context = rearrange(context, "b h d -> b 1 h d")
        else:
            if not self.return_residual:
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                kv = self.Wkv(x_kv if x_kv is not None else x)
            else:
                if x_kv is not None:
                    kv, x_kv = self.Wkv(x_kv)
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            if self.dwconv:
                # The conv works on flat (b, s, d) projections, so it must run
                # before the head split (previously it ran after and always failed).
                q = self._depthwise_conv(self.dwconv_q, q)
                kv = self._depthwise_conv(self.dwconv_kv, kv)
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two h d) -> ... two h d", two=2, d=self.head_dim)
            if inference_params is None:
                context = self.flash_attn_wrapper(
                    self.inner_cross_attn, q, kv, **kwargs
                )
            else:
                # inference_params was previously omitted here (TypeError).
                kv = self._update_kv_cache(kv, inference_params)
                context = self.flash_attn_wrapper(
                    self.inner_cross_attn, q, kv, causal=False
                )
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)

    # Bind the module-level helper as a method; it reads self.force_flash_attn.
    flash_attn_wrapper = flash_attn_wrapper

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.