
able to override lora R value when adding a new finetuning scope
lucidrains committed Dec 29, 2022
1 parent 6126cde commit 0e45bac
Showing 3 changed files with 17 additions and 7 deletions.
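This commit lets each finetuning scope carry its own LoRA rank: add_finetune_params now accepts an optional lora_r argument and falls back to the model-wide self.lora_r when it is omitted. A minimal sketch of the new call, assuming the PaLM constructor arguments from the repository README (num_tokens, dim, depth) and a top-level lora_r keyword:

from palm_rlhf_pytorch import PaLM

# model-wide default rank, assumed to be settable at construction time
palm = PaLM(num_tokens = 20000, dim = 512, depth = 12, lora_r = 8)

# new in this commit: a scope may override the default rank
palm.add_finetune_params('reward', lora_r = 4)   # rank-4 adapters for the 'reward' scope
palm.add_finetune_params('actor')                # lora_r = None, falls back to self.lora_r (8)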
18 changes: 12 additions & 6 deletions palm_rlhf_pytorch/palm_rlhf_pytorch.py
@@ -22,6 +22,9 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def identity(t, *args, **kwargs):
     return t

@@ -330,9 +333,9 @@ def set_dropout(self, dropout):
                 module.p = dropout
         return self
 
-    def add_finetune_params(self, scope):
+    def add_finetune_params(self, scope, lora_r = None):
         assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
-        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, self.lora_r, self.device
+        dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device
 
         q_inner_dim = heads * dim_head
         kv_inner_dim = dim_head
@@ -497,7 +500,8 @@ def __init__(
         dropout = 0.1,
         num_binned_output = 0.,
         use_lora = True,
-        reward_lora_scope = 'reward'
+        lora_r = 8,
+        reward_lora_scope = 'reward',
     ):
         super().__init__()

@@ -507,7 +511,7 @@ def __init__(
         self.reward_lora_scope = reward_lora_scope if use_lora else None
 
         if exists(self.reward_lora_scope):
-            self.palm.add_finetune_params(reward_lora_scope)
+            self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)
 
         dim = palm.dim

@@ -601,6 +605,8 @@ def __init__(
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,
+        actor_lora_r = 8,
+        critic_lora_r = 8,
         actor_lora_scope = 'actor',
         critic_lora_scope = 'critic',
         actor_dropout = 0.,
@@ -624,10 +630,10 @@ def __init__(
         self.critic_lora_scope = critic_lora_scope if critic_lora else None
 
         if self.actor_lora:
-            self.actor_palm.add_finetune_params(actor_lora_scope)
+            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
 
         if self.critic_lora:
-            self.critic_palm.add_finetune_params(critic_lora_scope)
+            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
 
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
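The reward model wrapper changed above simply forwards the new lora_r into add_finetune_params for its 'reward' scope. A short sketch of how that surfaces to a user, assuming the class is the RewardModel exported by the package (as in the README) and reusing the palm instance from the sketch above:

from palm_rlhf_pytorch import RewardModel

# lora_r is new; it is passed through as palm.add_finetune_params('reward', lora_r = 4)
reward_model = RewardModel(
    palm,
    num_binned_output = 5,   # as in the README example
    use_lora = True,
    lora_r = 4
)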
4 changes: 4 additions & 0 deletions palm_rlhf_pytorch/ppo.py
@@ -138,6 +138,8 @@ def __init__(
         critic_adam_eps = 1e-7,
         actor_lora = True,
         critic_lora = True,
+        actor_lora_r = 8,
+        critic_lora_r = 8,
         critic_pooled_values = True,
         actor_dropout = 0.,
         critic_dropout = 0.,
@@ -182,6 +184,8 @@ def __init__(
            palm = palm,
            actor_lora = actor_lora,
            critic_lora = critic_lora,
+           actor_lora_r = actor_lora_r,
+           critic_lora_r = critic_lora_r,
            pooled_values = critic_pooled_values,
            actor_dropout = actor_dropout,
            critic_dropout = critic_dropout
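On the RL side, the two new trainer keywords are threaded straight into the actor-critic the trainer builds internally, so the actor and critic scopes can use different ranks. A sketch, assuming the trainer class and its other keyword arguments match the repository README (RLHFTrainer with palm, reward_model, prompt_token_ids) and reusing palm and reward_model from the sketches above:

import torch
from palm_rlhf_pytorch import RLHFTrainer

prompts = torch.randint(0, 20000, (256, 512))   # hypothetical prompt token ids

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts,
    actor_lora_r = 16,   # rank for the 'actor' finetuning scope
    critic_lora_r = 4    # rank for the 'critic' finetuning scope
)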
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.43',
+  version = '0.0.44',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
