From 6e1baf655a458b43bade810728c225d919081875 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 8 Jan 2023 10:13:01 -0800 Subject: [PATCH] fix masking logic when using palm as encoder --- palm_rlhf_pytorch/palm_rlhf_pytorch.py | 6 +++++- setup.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/palm_rlhf_pytorch/palm_rlhf_pytorch.py b/palm_rlhf_pytorch/palm_rlhf_pytorch.py index 4cb97eb..6824c3b 100644 --- a/palm_rlhf_pytorch/palm_rlhf_pytorch.py +++ b/palm_rlhf_pytorch/palm_rlhf_pytorch.py @@ -477,7 +477,11 @@ def forward( # mask if encoder # treat any token ids that are negative as tokens to mask out - only needed if not autoregressive - mask = (x < 0) if not self.causal else None + if not self.causal: + mask = x >= 0 + x = x.masked_fill(~mask, 0) + else: + mask = None # get token embedding diff --git a/setup.py b/setup.py index 1446fe0..d81d4b1 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'PaLM-rlhf-pytorch', packages = find_packages(exclude=[]), - version = '0.0.47', + version = '0.0.48', license='MIT', description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch', author = 'Phil Wang',