Commit

Merge pull request #223 from casassg/casassg/multi-output-model
Add support for multiple output models
codesue committed Dec 1, 2022
2 parents 8e45afb + e0e85c9 commit 87c0b4d
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions model_card_toolkit/utils/tfx_util.py
@@ -442,31 +442,42 @@ def _parse_array_value(array: Dict[str, Any]) -> str:
    logging.warning('Received unexpected array %s', str(array))
    return ''

-  for slice_repr, metrics_for_slice in (
-      eval_result.get_metrics_for_all_slices().items()):
-    # Parse the slice name
-    if not isinstance(slice_repr, tuple):
-      raise ValueError(
-          f'Expected EvalResult slices to be tuples; found {type(slice_repr)}')
-    slice_name = '_X_'.join(f'{a}_{b}' for a, b in slice_repr)
-    for metric_name, metric_value in metrics_for_slice.items():
-      # Parse the metric value
-      parsed_value = ''
-      if 'doubleValue' in metric_value:
-        parsed_value = metric_value['doubleValue']
-      elif 'boundedValue' in metric_value:
-        parsed_value = metric_value['boundedValue']['value']
-      elif 'arrayValue' in metric_value:
-        parsed_value = _parse_array_value(metric_value['arrayValue'])
-      else:
-        logging.warning(
-            'Expected doubleValue, boundedValue, or arrayValue; found %s',
-            metric_value.keys())
-      if parsed_value:
-        # Create the PerformanceMetric and append to the ModelCard
-        metric = model_card_module.PerformanceMetric(
-            type=metric_name, value=str(parsed_value), slice=slice_name)
-        model_card.quantitative_analysis.performance_metrics.append(metric)
+  # NOTE: When multiple outputs are passed, each will be under its own
+  # output_name key. In that case, namespace each metric as
+  # output_name.metric_name in quantitative_analysis to distinguish the outputs.
+  output_names = set()
+  for slicing_metric in eval_result.slicing_metrics:
+    for output_name in slicing_metric[1]:
+      output_names.add(output_name)
+  for output_name in sorted(output_names):
+    for slice_repr, metrics_for_slice in (
+        eval_result.get_metrics_for_all_slices(output_name=output_name).items()):
+      # Parse the slice name
+      if not isinstance(slice_repr, tuple):
+        raise ValueError(
+            f'Expected EvalResult slices to be tuples; found {type(slice_repr)}')
+      slice_name = '_X_'.join(f'{a}_{b}' for a, b in slice_repr)
+      for metric_name, metric_value in metrics_for_slice.items():
+        # Parse the metric value
+        parsed_value = ''
+        if 'doubleValue' in metric_value:
+          parsed_value = metric_value['doubleValue']
+        elif 'boundedValue' in metric_value:
+          parsed_value = metric_value['boundedValue']['value']
+        elif 'arrayValue' in metric_value:
+          parsed_value = _parse_array_value(metric_value['arrayValue'])
+        else:
+          logging.warning(
+              'Expected doubleValue, boundedValue, or arrayValue; found %s',
+              metric_value.keys())
+        if parsed_value:
+          metric_type = metric_name
+          if output_name:
+            metric_type = f"{output_name}.{metric_name}"
+          # Create the PerformanceMetric and append to the ModelCard
+          metric = model_card_module.PerformanceMetric(
+              type=metric_type, value=str(parsed_value), slice=slice_name)
+          model_card.quantitative_analysis.performance_metrics.append(metric)


def filter_metrics(
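The NOTE comment in the added lines describes the new behavior: when the TFMA EvalResult carries metrics for more than one model output, each metric's type is prefixed with its output name. The sketch below (not part of the commit) illustrates how that looks from a caller's point of view. It assumes the hunk above sits inside tfx_util.annotate_eval_result_metrics(model_card, eval_result), which is not shown in this hunk, and the output names and evaluation path are made up for illustration.

# Minimal usage sketch, assuming the hunk above modifies
# tfx_util.annotate_eval_result_metrics(model_card, eval_result); the output
# names and evaluation path below are hypothetical.
import tensorflow_model_analysis as tfma

from model_card_toolkit.model_card import ModelCard
from model_card_toolkit.utils import tfx_util

# EvalResult from a TFX Evaluator run over a two-headed model whose output
# names are, say, 'income' and 'default_risk' (illustrative).
eval_result = tfma.load_eval_result('/path/to/evaluator/output')

model_card = ModelCard()
tfx_util.annotate_eval_result_metrics(model_card, eval_result)

# For a single-output model the metric types stay plain ('accuracy', 'auc', ...).
# With multiple outputs they are namespaced per output, e.g. 'income.accuracy'
# and 'default_risk.accuracy', so the same metric computed for different heads
# remains distinguishable in the model card.
for metric in model_card.quantitative_analysis.performance_metrics:
  print(metric.slice, metric.type, metric.value)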
