diff --git a/palm_rlhf_pytorch/palm_rlhf_pytorch.py b/palm_rlhf_pytorch/palm_rlhf_pytorch.py
index d35011d..4673be5 100644
--- a/palm_rlhf_pytorch/palm_rlhf_pytorch.py
+++ b/palm_rlhf_pytorch/palm_rlhf_pytorch.py
@@ -22,6 +22,9 @@ def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def identity(t, *args, **kwargs):
     return t
 
 
@@ -330,9 +333,9 @@ def set_dropout(self, dropout):
                 module.p = dropout
         return self
 
-    def add_finetune_params(self, scope):
+    def add_finetune_params(self, scope, lora_r = None):
         assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
-        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, self.lora_r, self.device
+        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device
 
         q_inner_dim = heads * dim_head
         kv_inner_dim = dim_head
@@ -497,7 +500,8 @@ def __init__(
         dropout = 0.1,
         num_binned_output = 0.,
         use_lora = True,
-        reward_lora_scope = 'reward'
+        lora_r = 8,
+        reward_lora_scope = 'reward',
     ):
         super().__init__()
 
@@ -507,7 +511,7 @@ def __init__(
         self.reward_lora_scope = reward_lora_scope if use_lora else None
 
         if exists(self.reward_lora_scope):
-            self.palm.add_finetune_params(reward_lora_scope)
+            self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)
 
         dim = palm.dim
 
@@ -601,6 +605,8 @@ def __init__(
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,
+        actor_lora_r = 8,
+        critic_lora_r = 8,
         actor_lora_scope = 'actor',
         critic_lora_scope = 'critic',
         actor_dropout = 0.,
@@ -624,10 +630,10 @@ def __init__(
         self.critic_lora_scope = critic_lora_scope if critic_lora else None
 
         if self.actor_lora:
-            self.actor_palm.add_finetune_params(actor_lora_scope)
+            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
 
         if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope)
+            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
 
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
diff --git a/palm_rlhf_pytorch/ppo.py b/palm_rlhf_pytorch/ppo.py
index 519ca91..712fd71 100644
--- a/palm_rlhf_pytorch/ppo.py
+++ b/palm_rlhf_pytorch/ppo.py
@@ -138,6 +138,8 @@ def __init__(
         critic_adam_eps = 1e-7,
         actor_lora = True,
         critic_lora = True,
+        actor_lora_r = 8,
+        critic_lora_r = 8,
         critic_pooled_values = True,
         actor_dropout = 0.,
         critic_dropout = 0.,
@@ -182,6 +184,8 @@ def __init__(
                 palm = palm,
                 actor_lora = actor_lora,
                 critic_lora = critic_lora,
+                actor_lora_r = actor_lora_r,
+                critic_lora_r = critic_lora_r,
                 pooled_values = critic_pooled_values,
                 actor_dropout = actor_dropout,
                 critic_dropout = critic_dropout
diff --git a/setup.py b/setup.py
index 97d96d8..ce1b5fa 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.43',
+  version = '0.0.44',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
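The patch above threads a configurable LoRA rank through the finetuning scopes: PaLM.add_finetune_params gains a lora_r override (falling back, via the new default helper, to the rank set at PaLM construction), RewardModel takes a lora_r keyword, and ActorCritic / RLHFTrainer take actor_lora_r and critic_lora_r which are forwarded to the actor and critic scopes. A minimal usage sketch follows, assuming the PaLM and RewardModel exports shown in the repository README; the scope name and rank values below are illustrative, not library defaults.

    from palm_rlhf_pytorch import PaLM, RewardModel

    palm = PaLM(num_tokens = 20000, dim = 512, depth = 12)

    # reward finetuning scope with its own LoRA rank
    reward_model = RewardModel(palm, num_binned_output = 5, lora_r = 4)

    # or register a finetuning scope directly; omitting lora_r falls back to the
    # rank configured on the PaLM instance (default(lora_r, self.lora_r))
    palm.add_finetune_params('my-custom-scope', lora_r = 4)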