[go: nahoru, domu]

Skip to content

Commit

Permalink
Update supervised.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
hiyouga committed Jun 6, 2024
1 parent 788e823 commit c09ad8b
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/llamafactory/data/processors/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,16 @@ def preprocess_packed_supervised_dataset(
 packed_input_ids += batch_input_ids[index]
 packed_labels += batch_labels[index]

-if len(packed_input_ids) <= data_args.cutoff_len:
+if len(packed_input_ids) < data_args.cutoff_len:
 pad_length = data_args.cutoff_len - len(packed_input_ids)
 packed_input_ids += [tokenizer.pad_token_id] * pad_length
 packed_labels += [IGNORE_INDEX] * pad_length
-else:
-raise ValueError("The length of packed example exceeds the cutoff length.")
+
+if len(packed_input_ids) != data_args.cutoff_len:
+raise ValueError("The length of packed example should be identical to the cutoff length.")

 model_inputs["input_ids"].append(packed_input_ids)
-model_inputs["attention_mask"].append([1] * len(packed_input_ids))
+model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
 model_inputs["labels"].append(packed_labels)

 return model_inputs
Expand Down

0 comments on commit c09ad8b

Please sign in to comment.