forked from ridgerchu/matmulfreellm
Commit 9e6fd30 · 1 parent 41d10f0
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 27 changed files with 5,604 additions and 0 deletions.
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from mmfreelm.models import (
    HGRNBitForCausalLM,
    HGRNBitModel)

__all__ = [
    'HGRNBitModel',
    'HGRNBitForCausalLM',
]

__version__ = '0.1'
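
A quick smoke test of the package root, assuming `mmfreelm` is installed and importable; it only checks that the re-exports and version string resolve:

# Hypothetical sanity check (not part of the commit); assumes mmfreelm is installed.
import mmfreelm
from mmfreelm import HGRNBitForCausalLM, HGRNBitModel

print(mmfreelm.__version__)                                # '0.1'
print(HGRNBitForCausalLM.__name__, HGRNBitModel.__name__)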
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .hgrn_bit import HGRNBitAttention

__all__ = [
    'HGRNBitAttention'
]
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-

# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache

from mmfreelm.modules import FusedRMSNormSwishGate, ShortConvolution
from mmfreelm.modules.activations import swiglu
from mmfreelm.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn

# from mmfreelm.ops.bitnet import BitLinear_Fuse as BitLinear
from mmfreelm.ops.fusedbitnet import FusedBitLinear as BitLinear


class HGRNBitAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'fused_recurrent',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        layernorm_eps: float = 1e-5,
        layer_idx: Optional[int] = None
    ) -> HGRNBitAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)
        self.head_dim = self.input_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel

        self.layer_idx = layer_idx

        assert mode in ['fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}"

        self.i_proj = BitLinear(hidden_size, self.input_dim, bias=False)
        self.f_proj = BitLinear(hidden_size, self.input_dim, bias=False)
        self.g_proj = BitLinear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
                self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
                self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')

        self.g_norm = FusedRMSNormSwishGate(self.input_dim, layernorm_eps)
        self.o_proj = BitLinear(self.input_dim, hidden_size, bias=False)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, (nn.Linear, BitLinear)):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        last_state = past_key_values[self.layer_idx] if use_cache else None
        if self.use_short_conv:
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                i = self.i_proj(hidden_states)
                f = self.f_proj(hidden_states)
            else:
                conv_state_i = last_state[2] if use_cache else None
                conv_state_f = last_state[1] if use_cache else None
                i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
                f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        f = f.sigmoid()
        # the lower bound for the first layer is zero
        if lower_bound is not None and self.layer_idx > 0:
            f = lower_bound + (1 - lower_bound) * f
        i = swiglu(i, 1 - f)
        # dealing with left-padding
        if attention_mask is not None:
            i = i.mul_(attention_mask.unsqueeze(-1))
        i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))

        recurrent_state = last_state[-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    last_state = (conv_state, recurrent_state)
                else:
                    last_state = (conv_state_i, conv_state_f, recurrent_state)
            else:
                last_state = (recurrent_state,)
            past_key_values.update(last_state, self.layer_idx, i.shape[2])

        o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            if self.share_conv_kernel:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
            else:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),
                          param.new_zeros(batch_size, self.hidden_size, self.conv_size),
                          param.new_zeros(batch_size, self.hidden_size, self.conv_size))
        state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
        return state

    def state_size(self, **kwargs) -> int:
        state_size = self.hidden_size
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
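
A rough usage sketch of the layer above. It assumes the layer is importable as `mmfreelm.layers.hgrn_bit.HGRNBitAttention` (the file path is not shown in this diff) and that a CUDA device is available, since `FusedBitLinear` and `fused_recurrent_hgrn` are Triton kernels; inputs follow the `(batch, seq_len, hidden)` layout used in `forward`:

# Hypothetical standalone forward pass (not part of the commit); requires CUDA + Triton.
import torch
from mmfreelm.layers.hgrn_bit import HGRNBitAttention  # assumed module path

attn = HGRNBitAttention(hidden_size=1024, num_heads=1, layer_idx=0).cuda()
x = torch.randn(2, 16, 1024, device='cuda')  # (batch, seq_len, hidden)
o, _, _ = attn(x)                            # no mask, no cache
print(o.shape)                               # torch.Size([2, 16, 1024])

Note that `num_heads` defaults to `None`, so it must be set explicitly: `head_dim` is computed as `input_dim // num_heads` in `__init__`.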
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch
from einops import rearrange

from mmfreelm.modules.utils import checkpoint

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_update = None


@checkpoint
def proj_then_conv1d(
    x: torch.Tensor,
    proj_weight: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
    cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # We do matmul and transpose BLH -> HBL at the same time
    x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])

    if causal_conv1d_fn is None:
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
    if cache is None:
        x = causal_conv1d_fn(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        ).transpose(1, 2)
    else:
        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
        x = x.squeeze(-1)
        x = causal_conv1d_update(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            cache=cache,
            activation="silu",
        )
    return x
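
A shape-level sketch of the prefill path (`cache=None`) of the helper above, assuming the `causal-conv1d` package is installed and a CUDA device is available; the projection weight is `(out_channels, hidden)` and the conv kernel is stored depthwise as `(channels, 1, width)`. Whether the `@checkpoint` wrapper changes anything outside training depends on `mmfreelm.modules.utils`, which is not shown in this diff:

# Hypothetical call of proj_then_conv1d (not part of the commit); requires causal-conv1d on CUDA.
import torch

b, l, d, w = 2, 16, 512, 4                                                 # batch, seq_len, hidden, conv width
x = torch.randn(b, l, d, device='cuda', dtype=torch.float16)
proj_weight = torch.randn(d, d, device='cuda', dtype=torch.float16)        # (out, in)
conv1d_weight = torch.randn(d, 1, w, device='cuda', dtype=torch.float16)   # depthwise kernel

y = proj_then_conv1d(x, proj_weight, conv1d_weight)   # -> (b, l, d) after the final transpose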
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from mmfreelm.models.hgrn_bit import HGRNBitConfig, HGRNBitForCausalLM, HGRNBitModel

__all__ = [
    'HGRNBitConfig', 'HGRNBitForCausalLM', 'HGRNBitModel',
]
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from mmfreelm.models.hgrn_bit.configuration_hgrn_bit import HGRNBitConfig
from mmfreelm.models.hgrn_bit.modeling_hgrn_bit import HGRNBitForCausalLM, HGRNBitModel

AutoConfig.register(HGRNBitConfig.model_type, HGRNBitConfig)
AutoModel.register(HGRNBitConfig, HGRNBitModel)
AutoModelForCausalLM.register(HGRNBitConfig, HGRNBitForCausalLM)


__all__ = ['HGRNBitConfig', 'HGRNBitForCausalLM', 'HGRNBitModel']
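
Since the classes are registered with the Auto* factories at import time, the model can be constructed through the standard transformers entry points. A minimal sketch (importing `mmfreelm.models.hgrn_bit` runs the registrations above; the weights here are randomly initialized):

# Sketch: build a model from a config via the Auto* registry (not part of the commit).
from transformers import AutoModelForCausalLM

from mmfreelm.models.hgrn_bit import HGRNBitConfig  # import triggers the registrations above

config = HGRNBitConfig(hidden_size=512, num_hidden_layers=4, vocab_size=32000)
model = AutoModelForCausalLM.from_config(config)     # -> HGRNBitForCausalLM
print(type(model).__name__)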
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class HGRNBitConfig(PretrainedConfig):

    model_type = 'hgrn_bit'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        attn_mode: str = "fused_recurrent",
        num_heads: Optional[int] = 1,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        share_conv_kernel: bool = True,
        use_lower_bound: bool = True,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.use_lower_bound = use_lower_bound
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
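
A small sketch of constructing the config above; any argument not passed keeps the default from `__init__`, and serialization goes through the standard `PretrainedConfig` machinery:

# Sketch: build, inspect, and save an HGRNBitConfig (not part of the commit).
from mmfreelm.models.hgrn_bit.configuration_hgrn_bit import HGRNBitConfig

config = HGRNBitConfig(hidden_size=1024, num_hidden_layers=12, use_short_conv=True)
print(config.model_type)       # 'hgrn_bit'
print(config.attn_mode)        # 'fused_recurrent' (default)
print(config.use_short_conv)   # True
config.save_pretrained('hgrn-bit-1024')   # writes config.json via PretrainedConfig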