fine_tune_clsify_head.py
"""
Takes a pretrained model with classification head and uses the peft package to do Adapter + LoRA
fine tuning.
"""
from typing import Any
import torch
from lightning import LightningModule
from peft import get_peft_model, LoraConfig, TaskType
from torch.optim import AdamW, Optimizer
from torchmetrics.functional.classification import (
    binary_accuracy,
    binary_f1_score,
    binary_precision,
    binary_recall,
)
from transformers import AutoModelForSequenceClassification


class TransformerModule(LightningModule):
    def __init__(
        self,
        pretrained_model: str,
        num_classes: int,
        lr: float,
    ):
        super().__init__()

        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=pretrained_model,
            num_labels=num_classes,
        )

        # Wrap the base model so that only the LoRA adapter weights (and the
        # classification head) are trained; the pretrained weights stay frozen.
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        self.model = get_peft_model(model, peft_config)
        self.model.print_trainable_parameters()

        self.lr = lr
        self.save_hyperparameters("pretrained_model")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        label: torch.Tensor,
    ):
        """Calculate the loss by passing inputs to the model and comparing them
        against the ground-truth labels. All of the arguments of self.model come
        from the HuggingFace SequenceClassification head.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=label,
        )

    def _compute_metrics(self, batch, split) -> tuple:
        """Helper method hosting the evaluation logic common to the <split>_step methods."""
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            label=batch["label"],
        )

        # Softmax the logits into class probabilities, then take the argmax
        # along the last dimension to get the predicted class per example.
        preds = torch.argmax(torch.softmax(outputs["logits"], dim=-1), dim=-1)

        metrics = {
            f"{split}_Loss": outputs["loss"],
            f"{split}_Acc": binary_accuracy(
                preds=preds,
                target=batch["label"],
            ),
            f"{split}_F1_Score": binary_f1_score(
                preds=preds,
                target=batch["label"],
            ),
            f"{split}_Precision": binary_precision(
                preds=preds,
                target=batch["label"],
            ),
            f"{split}_Recall": binary_recall(
                preds=preds,
                target=batch["label"],
            ),
        }
        return outputs, metrics

    def training_step(self, batch, batch_idx):
        outputs, metrics = self._compute_metrics(batch, "Train")
        self.log_dict(metrics, on_epoch=True, on_step=False)
        return outputs["loss"]

    def validation_step(self, batch, batch_idx) -> dict[str, Any]:
        _, metrics = self._compute_metrics(batch, "Val")
        self.log_dict(metrics, on_epoch=True, on_step=False)
        return metrics

    def test_step(self, batch, batch_idx) -> dict[str, Any]:
        _, metrics = self._compute_metrics(batch, "Test")
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self) -> Optimizer:
        return AdamW(
            params=self.parameters(),
            lr=self.lr,
            weight_decay=0.0,
        )
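

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint name,
# hyperparameter values, and `datamodule` below are assumptions for
# illustration: batches must be dicts holding "input_ids", "attention_mask",
# and "label" tensors, matching what _compute_metrics expects above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from lightning import Trainer

    model = TransformerModule(
        pretrained_model="distilbert-base-uncased",  # assumed base checkpoint
        num_classes=2,
        lr=3e-4,  # assumed learning rate
    )

    # `datamodule` is a hypothetical LightningDataModule yielding the
    # dict-style batches described above; wire in your own implementation.
    # trainer = Trainer(max_epochs=3)
    # trainer.fit(model, datamodule=datamodule)
    # trainer.test(model, datamodule=datamodule)

    # After training, only the small LoRA adapter weights need to be saved:
    # model.model.save_pretrained("lora_adapter/")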