[go: nahoru, domu]

Skip to content

Commit

Permalink
Update with Matmulfree LM
Browse files Browse the repository at this point in the history
  • Loading branch information
ruijie-zhu committed May 21, 2024
1 parent 41d10f0 commit 9e6fd30
Show file tree
Hide file tree
Showing 27 changed files with 5,604 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mmfreelm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from mmfreelm.models import (
HGRNBitForCausalLM,
HGRNBitModel)

# Public names re-exported at the package root.
__all__ = ["HGRNBitModel", "HGRNBitForCausalLM"]

# Package version string.
__version__ = "0.1"
7 changes: 7 additions & 0 deletions mmfreelm/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .hgrn_bit import HGRNBitAttention

# Public names of the layers subpackage.
__all__ = ["HGRNBitAttention"]
162 changes: 162 additions & 0 deletions mmfreelm/layers/hgrn_bit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-

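# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
# (element-wise HGRN recurrence; see also "HGRN2: Gated Linear RNNs with State Expansion" [https://arxiv.org/abs/2404.07904])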

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache

from mmfreelm.modules import FusedRMSNormSwishGate, ShortConvolution
from mmfreelm.modules.activations import swiglu
from mmfreelm.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn

#from mmfreelm.ops.bitnet import BitLinear_Fuse as BitLinear
from mmfreelm.ops.fusedbitnet import FusedBitLinear as BitLinear


class HGRNBitAttention(nn.Module):
    """Gated linear-recurrent token mixer built from ``BitLinear`` projections.

    ``i_proj`` / ``f_proj`` produce the input and forget gates (optionally
    passed through short convolutions), ``fused_recurrent_hgrn`` runs the
    recurrence, and the result is gated by ``g_proj`` through ``g_norm``
    before ``o_proj`` maps it back to ``hidden_size``.

    Per-layer cache layout (produced by ``init_state`` and written back by
    ``forward``):
        * ``use_short_conv=False``: ``(recurrent_state,)``
        * ``use_short_conv=True, share_conv_kernel=True``:
          ``(conv_state, recurrent_state)``
        * ``use_short_conv=True, share_conv_kernel=False``:
          ``(unused_q_slot, conv_state_f, conv_state_i, recurrent_state)``
    """

    def __init__(
        self,
        mode: str = 'fused_recurrent',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        layernorm_eps: float = 1e-5,
        layer_idx: Optional[int] = None
    ) -> None:
        """Build the projections, optional short convolutions and output norm.

        Args:
            mode: recurrence kernel; only ``'fused_recurrent'`` is supported.
            hidden_size: model width (input and output feature size).
            num_heads: number of heads the gate features are split into;
                ``None`` means a single head.
            expand_ratio: ``input_dim = int(hidden_size * expand_ratio)``.
            use_short_conv: insert ``ShortConvolution`` layers before the gates.
            conv_size: kernel size of the short convolutions.
            conv_bias: stored on the module; NOTE(review): not passed to
                ``ShortConvolution`` here.
            share_conv_kernel: one convolution over the block input instead of
                separate ones per gate.
            layernorm_eps: epsilon for ``g_norm``.
            layer_idx: index of this layer in the cache.

        Raises:
            AssertionError: unsupported ``mode`` or ``input_dim`` not divisible
                by ``num_heads``.
        """
        super().__init__()

        # ``None`` used to crash in the head_dim division below; treat it as a
        # single head spanning the whole input dimension.
        if num_heads is None:
            num_heads = 1

        self.mode = mode
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)

        # Validate before dividing: the later 'b l (h d)' rearrange needs
        # input_dim (not hidden_size) to split evenly into num_heads.
        assert mode in ['fused_recurrent'], f"Not supported mode `{mode}`."
        assert num_heads > 0 and self.input_dim % num_heads == 0, \
            f"input dim ({self.input_dim}) must be divisible by num_heads of {num_heads}"
        self.head_dim = self.input_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel

        self.layer_idx = layer_idx

        self.i_proj = BitLinear(hidden_size, self.input_dim, bias=False)
        self.f_proj = BitLinear(hidden_size, self.input_dim, bias=False)
        self.g_proj = BitLinear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            if share_conv_kernel:
                # One convolution over the block input, shared by both gates.
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                # NOTE: q_conv1d is never used in forward(); it is kept so the
                # module's parameter names (and saved checkpoints) are unchanged.
                self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
                self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
                self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')

        self.g_norm = FusedRMSNormSwishGate(self.input_dim, layernorm_eps)
        self.o_proj = BitLinear(self.input_dim, hidden_size, bias=False)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Xavier-init (gain 2**-2.5) linear weights once; zero their biases."""
        # Skip modules already initialised (flag set below or by HF loading).
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, (nn.Linear, BitLinear)):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Run the gated recurrence over ``hidden_states``.

        Args:
            hidden_states: input activations; dim 1 is treated as the sequence
                length (assumed ``(batch, seq_len, hidden_size)`` — TODO confirm).
            attention_mask: optional mask, unsqueezed on the last dim and
                multiplied into the input gate (assumed ``(batch, seq_len)``
                with 0 at left-padding positions — TODO confirm).
            past_key_values: cache indexed by ``layer_idx``; updated in place
                through ``past_key_values.update(...)``.
            use_cache: read the previous state from the cache and request the
                final recurrent state from the kernel.
            output_attentions: accepted for API compatibility; unused.
            lower_bound: optional forget-gate lower bound; not applied when
                ``layer_idx`` is 0.

        Returns:
            ``(output, None, past_key_values)``.
        """
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        last_state = past_key_values[self.layer_idx] if use_cache else None
        if self.use_short_conv:
            # Slot 0: shared conv state, or the unused q-conv slot.
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                i = self.i_proj(hidden_states)
                f = self.f_proj(hidden_states)
            else:
                conv_state_i = last_state[2] if use_cache else None
                conv_state_f = last_state[1] if use_cache else None
                i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
                f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        f = f.sigmoid()
        # the lower bound for the first layer is zero
        if lower_bound is not None and self.layer_idx > 0:
            f = lower_bound + (1 - lower_bound) * f
        i = swiglu(i, 1 - f)
        # dealing with left-padding: zero the input gate at masked positions
        if attention_mask is not None:
            i = i.mul_(attention_mask.unsqueeze(-1))
        i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))

        recurrent_state = last_state[-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    last_state = (conv_state, recurrent_state)
                else:
                    # Write back in the same slot layout that was read above
                    # and that init_state() creates: (q-slot, f, i, recurrent).
                    # Writing (i, f, recurrent) shifted the recurrent state
                    # into a conv slot.
                    last_state = (conv_state, conv_state_f, conv_state_i, recurrent_state)
            else:
                last_state = (recurrent_state,)
            past_key_values.update(last_state, self.layer_idx, i.shape[2])

        o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor, ...]:
        """Return zero-filled cache tensors in the layout ``forward`` reads.

        Conv states are ``(batch_size, channels, conv_size)`` with channels equal
        to the width each convolution was built with (``hidden_size`` for the
        shared kernel, ``input_dim`` otherwise); the recurrent state is
        ``(batch_size, num_heads, head_dim)``. Tensors share the dtype/device of
        the module's first parameter.
        """
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            if self.share_conv_kernel:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
            else:
                # q-slot (unused), f, i — the per-gate convs operate on input_dim.
                state += (param.new_zeros(batch_size, self.input_dim, self.conv_size),
                          param.new_zeros(batch_size, self.input_dim, self.conv_size),
                          param.new_zeros(batch_size, self.input_dim, self.conv_size))
        state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
        return state

    def state_size(self, **kwargs) -> int:
        """Per-sample state size: recurrent state plus every short-conv state."""
        # The recurrent state built by init_state is num_heads * head_dim wide.
        state_size = self.num_heads * self.head_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
47 changes: 47 additions & 0 deletions mmfreelm/layers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch
from einops import rearrange

from mmfreelm.modules.utils import checkpoint

try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None


@checkpoint
def proj_then_conv1d(
    x: torch.Tensor,
    proj_weight: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
    cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Project ``x`` with ``proj_weight`` then apply a causal conv1d with SiLU.

    Args:
        x: input rearranged as ``"b l d"`` (batch, length, features).
        proj_weight: matrix left-multiplied onto the ``d`` axis.
        conv1d_weight: conv weight laid out as ``"d 1 w"``.
        conv1d_bias: optional conv bias.
        cache: conv state for single-token decoding; ``None`` for a full pass.

    Returns:
        Full pass: ``(b, l, d_out)``. Decoding (``cache`` given): the output of
        ``causal_conv1d_update`` on the squeezed ``(b, d_out)`` input.
        NOTE(review): the two paths return different ranks — confirm callers
        expect this.

    Raises:
        ImportError: if the optional ``causal-conv1d`` package is missing.
    """
    # Fail fast, before doing any work, if the optional kernels are missing.
    if causal_conv1d_fn is None or causal_conv1d_update is None:
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")

    # We do matmul and transpose BLH -> HBL at the same time
    x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])
    weight = rearrange(conv1d_weight, "d 1 w -> d w")

    if cache is None:
        return causal_conv1d_fn(
            x=x,
            weight=weight,
            bias=conv1d_bias,
            activation="silu",
        ).transpose(1, 2)

    assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
    # causal_conv1d_update names its state argument ``conv_state``; passing
    # ``cache=`` raised a TypeError for an unexpected keyword.
    return causal_conv1d_update(
        x=x.squeeze(-1),
        conv_state=cache,
        weight=weight,
        bias=conv1d_bias,
        activation="silu",
    )
7 changes: 7 additions & 0 deletions mmfreelm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from mmfreelm.models.hgrn_bit import HGRNBitConfig, HGRNBitForCausalLM, HGRNBitModel
# Public names of the models subpackage.
__all__ = ["HGRNBitConfig", "HGRNBitForCausalLM", "HGRNBitModel"]
13 changes: 13 additions & 0 deletions mmfreelm/models/hgrn_bit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from mmfreelm.models.hgrn_bit.configuration_hgrn_bit import HGRNBitConfig
from mmfreelm.models.hgrn_bit.modeling_hgrn_bit import HGRNBitForCausalLM, HGRNBitModel

# Register the HGRN-Bit classes with the transformers Auto* factories so that
# AutoConfig / AutoModel / AutoModelForCausalLM can resolve
# HGRNBitConfig.model_type to these classes. Runs at import time.
AutoConfig.register(HGRNBitConfig.model_type, HGRNBitConfig)
AutoModel.register(HGRNBitConfig, HGRNBitModel)
AutoModelForCausalLM.register(HGRNBitConfig, HGRNBitForCausalLM)


# Public names of the hgrn_bit subpackage.
__all__ = ['HGRNBitConfig', 'HGRNBitForCausalLM', 'HGRNBitModel']
64 changes: 64 additions & 0 deletions mmfreelm/models/hgrn_bit/configuration_hgrn_bit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class HGRNBitConfig(PretrainedConfig):
    """Configuration for the HGRN-Bit models (``model_type = 'hgrn_bit'``).

    Every architecture argument is stored as an attribute of the same name;
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``,
    ``tie_word_embeddings`` and any extra ``**kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.

    Notable arguments (how they are consumed lives in the model code, not here):
        attn_mode: recurrence kernel name passed to the attention layers.
        num_heads / expand_ratio: head count and gate width multiplier.
        use_short_conv / conv_size / share_conv_kernel: short-convolution setup.
        use_lower_bound: presumably enables the per-layer forget-gate lower
            bound — TODO confirm in the model code.
        hidden_ratio / intermediate_size: presumably MLP sizing — TODO confirm.
        fuse_cross_entropy: presumably selects a fused loss — TODO confirm.
    """

    model_type = 'hgrn_bit'
    # Cache entries that HF utilities should ignore when comparing outputs.
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        attn_mode: str = "fused_recurrent",
        num_heads: Optional[int] = 1,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        share_conv_kernel: bool = True,
        use_lower_bound: bool = True,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.attn_mode = attn_mode
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.use_lower_bound = use_lower_bound
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy

        # Token ids and embedding tying are handled by the HF base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
Loading

0 comments on commit 9e6fd30

Please sign in to comment.