-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
bert_handler.py
207 lines (186 loc) · 7.69 KB
/
bert_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-self-use,too-many-arguments,unused-argument,not-callable,no-member,attribute-defined-outside-init
""" Bert Custom Handler."""
from captum.attr import IntegratedGradients
import json
import logging
import os
import numpy as np
import torch
from transformers import BertTokenizer
from ts.torch_handler.base_handler import BaseHandler
from captum.attr import visualization
import torch.nn.functional as F
from bert_train import BertNewsClassifier
from wrapper import AGNewsmodelWrapper
logger = logging.getLogger(__name__)
class NewsClassifierHandler(BaseHandler):
    """
    NewsClassifierHandler class. This handler takes a review / sentence
    and returns the label as either world / sports / business / sci-tech.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.class_mapping_file = None
        self.VOCAB_FILE = None

    def initialize(self, ctx):
        """Load the eager-mode state_dict based model.

        :param ctx: TorchServe context carrying system properties
        :raises RuntimeError: if the model definition or vocab file is missing
        """
        properties = ctx.system_properties
        # Fix: use a torch.device for both branches (the original returned a
        # plain string "cpu" on the CPU path, inconsistent with the CUDA path).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = properties.get("model_dir")

        # Serialized weights produced by training (state_dict, not TorchScript).
        model_pt_path = os.path.join(model_dir, "bert.pth")
        # Model definition must ship alongside the weights in the model archive.
        model_def_path = os.path.join(model_dir, "bert_train.py")
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model definition file")

        self.VOCAB_FILE = os.path.join(model_dir, "bert-base-uncased-vocab.txt")
        if not os.path.isfile(self.VOCAB_FILE):
            raise RuntimeError("Missing the vocab file")

        # Optional class-index -> label-name mapping; checked again in postprocess.
        self.class_mapping_file = os.path.join(model_dir, "index_to_name.json")

        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model = BertNewsClassifier()
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Model file %s loaded successfully", model_pt_path)
        self.initialized = True

    def preprocess(self, data):
        """Convert the incoming JSON request into a BERT input encoding.

        :param data: TorchServe request list; text is read from "data" or "body"
        :return: tensor of token ids on ``self.device`` (also kept as state
                 for ``explain_handle``)
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        # Kept on self because explain_handle re-tokenizes the same text.
        self.text = text
        self.tokenizer = BertTokenizer(self.VOCAB_FILE)
        self.input_ids = torch.tensor(
            [self.tokenizer.encode(self.text, add_special_tokens=True)]
        ).to(self.device)
        return self.input_ids

    def inference(self, input_ids):
        """Predict whether the text belongs to world / sports / business / sci-tech.

        :param input_ids: encoded input returned by :meth:`preprocess`
        :return: single-element list with the predicted class index
        """
        # Fix: use the argument instead of reaching back to self.input_ids, so
        # the method honors whatever preprocess (or a caller) hands it.
        inputs = input_ids.to(self.device)
        self.outputs = self.model.forward(inputs)
        # torch.argmax keeps the computation in torch; equivalent to the
        # previous np.argmax on the detached tensor.
        self.out = torch.argmax(self.outputs.cpu().detach())
        return [self.out.item()]

    def postprocess(self, inference_output):
        """Map the predicted class index to a human-readable label.

        :param inference_output: output of :meth:`inference`
        :return: list with the mapped label if index_to_name.json exists,
                 otherwise the raw inference output unchanged
        """
        if os.path.exists(self.class_mapping_file):
            with open(self.class_mapping_file) as json_file:
                data = json.load(json_file)
            inference_output = json.dumps(data[str(inference_output[0])])
            return [inference_output]
        return inference_output

    def add_attributions_to_visualizer(
        self,
        attributions,
        tokens,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        delta,
        vis_data_records,
    ):
        """Normalize attributions and append a Captum visualization record.

        :param attributions: raw Integrated-Gradients attributions (summed
                             over the embedding dim before normalization)
        :param tokens: token strings aligned with the attributions
        :param vis_data_records: list mutated in place with the new record
        """
        attributions = attributions.sum(dim=2).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        attributions = attributions.cpu().detach().numpy()
        # Storing a couple of samples in an array for visualization purposes.
        vis_data_records.append(
            visualization.VisualizationDataRecord(
                attributions,
                pred_prob,
                pred_class,
                true_class,
                attr_class,
                attributions.sum(),
                tokens,
                delta,
            )
        )

    def score_func(self, o):
        """Return the argmax over softmaxed logits.

        NOTE(review): despite being fed to the visualizer's ``pred_prob``
        slot, this returns the predicted class index, not a probability —
        preserved as-is; confirm intended semantics with the visualizer.
        """
        output = F.softmax(o, dim=1)
        pre_pro = np.argmax(output.cpu().detach())
        return pre_pro

    def summarize_attributions(self, attributions):
        """Summarize the attribution across multiple runs.

        Args:
            attributions (list): attributions from the Integrated Gradients
        Returns:
            list: the attributions after normalizing them.
        """
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions

    def explain_handle(self, model_wraper, text, target=1):
        """Captum explanations handler.

        Args:
            model_wraper: unused; a fresh AGNewsmodelWrapper is built around
                ``self.model`` instead (kept in the signature for
                backward compatibility)
            text: unused; ``self.text`` captured by preprocess is re-tokenized
            target (int): class index to attribute against (default 1)
        Returns:
            list: a single dict with "words", "importances" and "delta".
        """
        vis_data_records_base = []
        model_wrapper = AGNewsmodelWrapper(self.model)
        tokenizer = BertTokenizer(self.VOCAB_FILE)
        model_wrapper.eval()
        model_wrapper.zero_grad()
        encoding = tokenizer.encode_plus(
            self.text,
            return_attention_mask=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        # Attribute on the embedding output so IG can interpolate inputs.
        input_embedding_test = model_wrapper.model.bert_model.embeddings(input_ids)
        preds = model_wrapper(input_embedding_test, attention_mask)
        out = np.argmax(preds.cpu().detach(), axis=1).item()
        ig_1 = IntegratedGradients(model_wrapper)
        # Fix: honor the ``target`` argument (previously hardcoded to 1,
        # silently ignoring the caller's requested class).
        attributions, delta = ig_1.attribute(  # pylint: disable=no-member
            input_embedding_test,
            n_steps=500,
            return_convergence_delta=True,
            target=target,
        )
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy().tolist())
        feature_imp_dict = {}
        feature_imp_dict["words"] = tokens
        attributions_sum = self.summarize_attributions(attributions)
        feature_imp_dict["importances"] = attributions_sum.tolist()
        feature_imp_dict["delta"] = delta[0].tolist()
        self.add_attributions_to_visualizer(
            attributions,
            tokens,
            self.score_func(preds),
            out,
            2,
            1,
            delta,
            vis_data_records_base,
        )
        return [feature_imp_dict]