diff --git a/RELEASE.md b/RELEASE.md
index 3a0f624848..79324b242d 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -7,6 +7,8 @@
 ## Bug fixes and other Changes
 * Add support for parsing the Predict API prediction log output to the
   experimental TFX-BSL PredictionsExtractor implementation.
+* Add support for parsing the Classification API prediction log output to the
+  experimental TFX-BSL PredictionsExtractor implementation.
 * Depends on `tensorflow>=1.15.5,<2` or `tensorflow>=2.10,<3`
diff --git a/tensorflow_model_analysis/extractors/predictions_extractor.py b/tensorflow_model_analysis/extractors/predictions_extractor.py
index a47fbeb3c2..cdf21c97a7 100644
--- a/tensorflow_model_analysis/extractors/predictions_extractor.py
+++ b/tensorflow_model_analysis/extractors/predictions_extractor.py
@@ -173,7 +173,22 @@ def _parse_prediction_log_to_tensor_value(  # pylint: disable=invalid-name
   """
   log_type = prediction_log.WhichOneof('log_type')
   if log_type == 'classify_log':
-    raise NotImplementedError('ClassifyLog processing not implemented yet.')
+    assert len(
+        prediction_log.classify_log.response.result.classifications) == 1, (
+            'We expect the number of classifications per PredictionLog to be '
+            'one because TFX-BSL RunInference expects single input/output and '
+            'handles batching entirely internally.')
+    classes = np.array([
+        c.label for c in
+        prediction_log.classify_log.response.result.classifications[0].classes
+    ],
+                       dtype=object)
+    scores = np.array([
+        c.score for c in
+        prediction_log.classify_log.response.result.classifications[0].classes
+    ],
+                      dtype=np.float32)
+    return {'classes': classes, 'scores': scores}
   elif log_type == 'regress_log':
     return np.array([
         regression.value
diff --git a/tensorflow_model_analysis/extractors/predictions_extractor_test.py b/tensorflow_model_analysis/extractors/predictions_extractor_test.py
index 2713ac6414..2ce135521f 100644
--- a/tensorflow_model_analysis/extractors/predictions_extractor_test.py
+++ b/tensorflow_model_analysis/extractors/predictions_extractor_test.py
@@ -145,16 +145,25 @@ def check_result(got):
       util.assert_that(result, check_result, label='result')

-  @parameterized.named_parameters(('ModelSignaturesDoFnInference', False),
-                                  ('TFXBSLBulkInference', True))
+  @parameterized.named_parameters(
+      ('ModelSignaturesDoFnInferenceUnspecifiedSignature', False, ''),
+      ('ModelSignaturesDoFnInferencePredictSignature', False, 'predict'),
+      ('ModelSignaturesDoFnInferenceServingDefaultSignature', False,
+       'serving_default'),
+      ('ModelSignaturesDoFnInferenceClassificationSignature', False,
+       'classification'), ('TFXBSLBulkInferenceUnspecifiedSignature', True, ''),
+      ('TFXBSLBulkInferencePredictSignature', True, 'predict'),
+      ('TFXBSLBulkInferenceServingDefaultSignature', True, 'serving_default'),
+      ('TFXBSLBulkInferenceClassificationSignature', True, 'classification'))
   def testPredictionsExtractorWithBinaryClassificationModel(
-      self, experimental_bulk_inference):
+      self, experimental_bulk_inference, signature_name):
     temp_export_dir = self._getExportDir()
     num_classes = 2
     export_dir, _ = dnn_classifier.simple_dnn_classifier(
         temp_export_dir, None, n_classes=num_classes)
-    eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
+    eval_config = config_pb2.EvalConfig(
+        model_specs=[config_pb2.ModelSpec(signature_name=signature_name)])
     eval_shared_model = self.createTestEvalSharedModel(
         eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
     schema = text_format.Parse(
@@ -216,27 +225,48 @@ def check_result(got):
         self.assertLen(got, 1)
         # We can't verify the actual predictions, but we can verify the keys.
         self.assertIn(constants.PREDICTIONS_KEY, got[0])
-        for pred_key in ('logistic', 'probabilities', 'all_classes'):
-          self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
-        self.assertEqual(
-            (num_examples, num_classes),
-            got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
+        # Prediction API cases. We default '' to 'predict'.
+        if signature_name in ('', 'predict'):
+          for pred_key in ('logistic', 'probabilities', 'all_classes'):
+            self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
+          self.assertEqual(
+              (num_examples, num_classes),
+              got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
+        # Classification API cases. The classification signature is also the
+        # 'serving_default' signature for this model.
+        if signature_name in ('serving_default', 'classification'):
+          for pred_key in ('classes', 'scores'):
+            self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
+          self.assertEqual((num_examples, num_classes),
+                           got[0][constants.PREDICTIONS_KEY]['classes'].shape)
+          self.assertEqual((num_examples, num_classes),
+                           got[0][constants.PREDICTIONS_KEY]['scores'].shape)
       except AssertionError as err:
         raise util.BeamAssertException(err)

       util.assert_that(result, check_result, label='result')

-  @parameterized.named_parameters(('ModelSignaturesDoFnInference', False),
-                                  ('TFXBSLBulkInference', True))
+  @parameterized.named_parameters(
+      ('ModelSignaturesDoFnInferenceUnspecifiedSignature', False, ''),
+      ('ModelSignaturesDoFnInferencePredictSignature', False, 'predict'),
+      ('ModelSignaturesDoFnInferenceServingDefaultSignature', False,
+       'serving_default'),
+      ('ModelSignaturesDoFnInferenceClassificationSignature', False,
+       'classification'), ('TFXBSLBulkInferenceUnspecifiedSignature', True, ''),
+      ('TFXBSLBulkInferencePredictSignature', True, 'predict'),
+      ('TFXBSLBulkInferenceServingDefaultSignature', True, 'serving_default'),
+      ('TFXBSLBulkInferenceClassificationSignature', True, 'classification'))
   def testPredictionsExtractorWithMultiClassModel(self,
-                                                  experimental_bulk_inference):
+                                                  experimental_bulk_inference,
+                                                  signature_name):
     temp_export_dir = self._getExportDir()
     num_classes = 3
     export_dir, _ = dnn_classifier.simple_dnn_classifier(
         temp_export_dir, None, n_classes=num_classes)
-    eval_config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
+    eval_config = config_pb2.EvalConfig(
+        model_specs=[config_pb2.ModelSpec(signature_name=signature_name)])
     eval_shared_model = self.createTestEvalSharedModel(
         eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])
     schema = text_format.Parse(
@@ -299,11 +329,22 @@ def check_result(got):
         self.assertLen(got, 1)
         # We can't verify the actual predictions, but we can verify the keys.
         self.assertIn(constants.PREDICTIONS_KEY, got[0])
-        for pred_key in ('probabilities', 'all_classes'):
-          self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
-        self.assertEqual(
-            (num_examples, num_classes),
-            got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
+        # Prediction API cases. We default '' to 'predict'.
+        if signature_name in ('', 'predict'):
+          for pred_key in ('probabilities', 'all_classes'):
+            self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
+          self.assertEqual(
+              (num_examples, num_classes),
+              got[0][constants.PREDICTIONS_KEY]['probabilities'].shape)
+        # Classification API cases. The classification signature is also the
+        # 'serving_default' signature for this model.
+        if signature_name in ('serving_default', 'classification'):
+          for pred_key in ('classes', 'scores'):
+            self.assertIn(pred_key, got[0][constants.PREDICTIONS_KEY])
+          self.assertEqual((num_examples, num_classes),
+                           got[0][constants.PREDICTIONS_KEY]['classes'].shape)
+          self.assertEqual((num_examples, num_classes),
+                           got[0][constants.PREDICTIONS_KEY]['scores'].shape)
       except AssertionError as err:
         raise util.BeamAssertException(err)
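
For reference, a minimal sketch of what the new `classify_log` branch produces. The proto layout (`classify_log.response.result.classifications[0].classes`, each entry carrying a `label` and a `score`) follows the TensorFlow Serving classification API used in the diff above; calling the private helper directly, and assuming it takes the `PredictionLog` as its only argument, is purely illustrative:

```python
import numpy as np
from tensorflow_serving.apis import prediction_log_pb2

from tensorflow_model_analysis.extractors import predictions_extractor

# Build a PredictionLog with exactly one classification result, mirroring the
# single input/output contract that TFX-BSL RunInference handles internally.
log = prediction_log_pb2.PredictionLog()
classification = log.classify_log.response.result.classifications.add()
for label, score in (('0', 0.3), ('1', 0.7)):
  cls = classification.classes.add()
  cls.label = label
  cls.score = score

# Illustrative only: exercising the private helper modified above. It returns
# a dict of aligned numpy arrays keyed like a classification signature output.
result = predictions_extractor._parse_prediction_log_to_tensor_value(log)
assert list(result['classes']) == ['0', '1']      # dtype=object labels
assert np.allclose(result['scores'], [0.3, 0.7])  # dtype=np.float32 scores
```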