Commit
Add support for parsing the Classification API prediction log output to the experimental TFX-BSL PredictionsExtractor implementation.

PiperOrigin-RevId: 473861807
tf-model-analysis-team authored and tfx-copybara committed Sep 12, 2022
1 parent fca4f6a commit 6eb1dfd
Showing 3 changed files with 77 additions and 19 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -7,6 +7,8 @@
## Bug fixes and other Changes
* Add support for parsing the Predict API prediction log output to the
experimental TFX-BSL PredictionsExtractor implementation.
* Add support for parsing the Classification API prediction log output to the
experimental TFX-BSL PredictionsExtractor implementation.

* Depends on `tensorflow>=1.15.5,<2` or `tensorflow>=2.10,<3`

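The new path is reached by pointing the model spec at a classification-style serving signature, as the updated tests below do. A minimal sketch of that configuration (the extractor wiring itself is omitted, since the exact constructor arguments are not shown in this diff):

```python
from tensorflow_model_analysis.proto import config_pb2

# Selecting the Classification API signature ('classification', or
# 'serving_default' for the estimator-exported classifiers used in the tests)
# makes serving emit ClassifyLog prediction logs, which the extractor now
# parses into 'classes'/'scores' arrays.
eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec(signature_name='classification')])
```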
17 changes: 16 additions & 1 deletion tensorflow_model_analysis/extractors/predictions_extractor.py
@@ -173,7 +173,22 @@ def _parse_prediction_log_to_tensor_value( # pylint: disable=invalid-name
"""
log_type = prediction_log.WhichOneof('log_type')
if log_type == 'classify_log':
raise NotImplementedError('ClassifyLog processing not implemented yet.')
assert len(
prediction_log.classify_log.response.result.classifications) == 1, (
'We expect the number of classifications per PredictionLog to be '
'one because TFX-BSL RunInference expects single input/output and '
'handles batching entirely internally.')
classes = np.array([
c.label for c in
prediction_log.classify_log.response.result.classifications[0].classes
],
dtype=object)
scores = np.array([
c.score for c in
prediction_log.classify_log.response.result.classifications[0].classes
],
dtype=np.float32)
return {'classes': classes, 'scores': scores}
elif log_type == 'regress_log':
return np.array([
regression.value
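For reference, a minimal sketch of the ClassifyLog shape the new branch consumes and the arrays it returns; the labels and scores below are invented, only the proto field paths come from the code above:

```python
import numpy as np
from tensorflow_serving.apis import prediction_log_pb2

# One ClassificationResponse with a single Classifications entry -- the
# single input/output case the assert above enforces.
prediction_log = prediction_log_pb2.PredictionLog()
classifications = (
    prediction_log.classify_log.response.result.classifications.add())
for label, score in (('0', 0.2), ('1', 0.8)):
  cls = classifications.classes.add()
  cls.label = label
  cls.score = score

# Same conversion as the classify_log branch: labels become an object array,
# scores a float32 array, returned under the 'classes' and 'scores' keys.
classes = np.array([c.label for c in classifications.classes], dtype=object)
scores = np.array(
    [c.score for c in classifications.classes], dtype=np.float32)
# classes: ['0' '1'] (dtype=object), scores: [0.2 0.8] (dtype=float32)
```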
77 changes: 59 additions & 18 deletions tensorflow_model_analysis/extractors/predictions_extractor_test.py
@@ -145,16 +145,25 @@ def check_result(got):

util.assert_that(result, check_result, label='result')

@parameterized.named_parameters(('ModelSignaturesDoFnInference', False),
('TFXBSLBulkInference', True))
@parameterized.named_parameters(
('ModelSignaturesDoFnInferenceUnspecifiedSignature', False, ''),
('ModelSignaturesDoFnInferencePredictSignature', False, 'predict'),
('ModelSignaturesDoFnInferenceServingDefaultSignature', False,
'serving_default'),
('ModelSignaturesDoFnInferenceClassificationSignature', False,
'classification'), ('TFXBSLBulkInferenceUnspecifiedSignature', True, ''),
('TFXBSLBulkInferencePredictSignature', True, 'predict'),
('TFXBSLBulkInferenceServingDefaultSignature', True, 'serving_default'),
('TFXBSLBulkInferenceClassificationSignature', True, 'classification'))
def testPredictionsExtractorWithBinaryClassificationModel(
self, experimental_bulk_inference):
self, experimental_bulk_inference, signature_name):
temp_export_dir = self._getExportDir()
num_classes = 2
export_dir, _ = dnn_classifier.simple_dnn_classifier(
temp_export_dir, None, n_classes=num_classes)

eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
eval_config = config_pb2.EvalConfig(
model_specs=[config_pb2.ModelSpec(signature_name=signature_name)])
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
schema = text_format.Parse(
@@ -216,27 +225,48 @@ def check_result(got):
self.assertLen(got, 1)
# We can't verify the actual predictions, but we can verify the keys.
self.assertIn(constants.PREDICTIONS_KEY, got[0])
for pred_key in ('logistic', 'probabilities', 'all_classes'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual(
(num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
# Prediction API cases. We default '' to 'predict'.
if signature_name in ('', 'predict'):
for pred_key in ('logistic', 'probabilities', 'all_classes'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual(
(num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
# Classification API cases. The classification signature is also the
# 'serving_default' signature for this model.
if signature_name in ('serving_default', 'classification'):
for pred_key in ('classes', 'scores'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual((num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['classes'].shape)
self.assertEqual((num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['scores'].shape)

except AssertionError as err:
raise util.BeamAssertException(err)

util.assert_that(result, check_result, label='result')

@parameterized.named_parameters(('ModelSignaturesDoFnInference', False),
('TFXBSLBulkInference', True))
@parameterized.named_parameters(
('ModelSignaturesDoFnInferenceUnspecifiedSignature', False, ''),
('ModelSignaturesDoFnInferencePredictSignature', False, 'predict'),
('ModelSignaturesDoFnInferenceServingDefaultSignature', False,
'serving_default'),
('ModelSignaturesDoFnInferenceClassificationSignature', False,
'classification'), ('TFXBSLBulkInferenceUnspecifiedSignature', True, ''),
('TFXBSLBulkInferencePredictSignature', True, 'predict'),
('TFXBSLBulkInferenceServingDefaultSignature', True, 'serving_default'),
('TFXBSLBulkInferenceClassificationSignature', True, 'classification'))
def testPredictionsExtractorWithMultiClassModel(self,
experimental_bulk_inference):
experimental_bulk_inference,
signature_name):
temp_export_dir = self._getExportDir()
num_classes = 3
export_dir, _ = dnn_classifier.simple_dnn_classifier(
temp_export_dir, None, n_classes=num_classes)

eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
eval_config = config_pb2.EvalConfig(
model_specs=[config_pb2.ModelSpec(signature_name=signature_name)])
eval_shared_model = self.createTestEvalSharedModel(
eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
schema = text_format.Parse(
@@ -299,11 +329,22 @@ def check_result(got):
self.assertLen(got, 1)
# We can't verify the actual predictions, but we can verify the keys.
self.assertIn(constants.PREDICTIONS_KEY, got[0])
for pred_key in ('probabilities', 'all_classes'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual(
(num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
# Prediction API cases. We default '' to 'predict'.
if signature_name in ('', 'predict'):
for pred_key in ('probabilities', 'all_classes'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual(
(num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
# Classification API cases. The classification signature is also the
# 'serving_default' signature for this model.
if signature_name in ('serving_default', 'classification'):
for pred_key in ('classes', 'scores'):
self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
self.assertEqual((num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['classes'].shape)
self.assertEqual((num_examples, num_classes),
got[0][constants.PREDICTIONS_KEY]['scores'].shape)

except AssertionError as err:
raise util.BeamAssertException(err)
