[go: nahoru, domu]

Skip to content

Commit

Permalink
Support reusing calibration data if it exists
Browse files Browse the repository at this point in the history
Calibration is the most time-consuming step in quantization. This CL helps avoid it when the user has already quantized the model before and the configuration changes, if any, don't affect the calibration results.

PiperOrigin-RevId: 626929619
  • Loading branch information
thaink authored and tensorflower-gardener committed May 14, 2024
1 parent 820f077 commit c0027b1
Show file tree
Hide file tree
Showing 16 changed files with 453 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:types",
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
"//tensorflow/core/protobuf:for_core_protos_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/die_if_null.h"
Expand All @@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
Expand All @@ -51,32 +53,52 @@ limitations under the License.
#include "tsl/platform/statusor.h"

namespace mlir::quant::stablehlo {
namespace {

using ::stablehlo::quantization::AddCalibrationStatistics;
using ::stablehlo::quantization::CreateRepresentativeDatasetFileMap;
using ::stablehlo::quantization::DisableDebugging;
using ::stablehlo::quantization::IsCalibrationRequired;
using ::stablehlo::quantization::QuantizationConfig;
using ::stablehlo::quantization::ReadStatistics;
using ::stablehlo::quantization::RepresentativeDatasetConfig;
using ::stablehlo::quantization::io::CreateTmpDir;
using ::stablehlo::quantization::io::GetLocalTmpFileName;
using ::stablehlo::quantization::io::ListDirectory;
using ::tensorflow::AssetFileDef;
using ::tensorflow::SignatureDef;
using ::tensorflow::calibrator::CalibrationStatistics;
using ::tensorflow::quantization::ExportedModel;
using ::tensorflow::quantization::PyFunctionLibrary;
using ::tensorflow::quantization::RunPasses;
using CalibrationStatisticsFlatMap =
absl::flat_hash_map<std::string, CalibrationStatistics>;

absl::Status RunCalibrationPasses(mlir::ModuleOp module_op, MLIRContext& ctx,
absl::string_view calibration_data_dir) {
} // namespace

absl::Status RunCalibrationPasses(
mlir::ModuleOp module_op, MLIRContext& ctx,
absl::string_view calibration_data_dir,
const bool force_regenerate_calibration_data) {
// Disable DumpTensor ops when running calibration.
DisableDebugging(module_op);

std::vector<std::string> skipping_aggregator_ops;
if (!force_regenerate_calibration_data) {
TF_ASSIGN_OR_RETURN(const CalibrationStatisticsFlatMap statistics_map,
ReadStatistics(calibration_data_dir));
absl::c_for_each(statistics_map, [&](const auto& iter) {
return skipping_aggregator_ops.push_back(iter.first);
});
}

return RunPasses(
/*name=*/
CalibrationComponent::kName,
/*add_passes_func=*/
[calibration_data_dir](PassManager& pm) {
pm.addPass(
CreateInsertCalibrationStatisticsSaverPass(calibration_data_dir));
[calibration_data_dir, &skipping_aggregator_ops](PassManager& pm) {
pm.addPass(CreateInsertCalibrationStatisticsSaverPass(
calibration_data_dir, skipping_aggregator_ops));
},
ctx, module_op);
}
Expand All @@ -97,17 +119,23 @@ CalibrationComponent::CalibrationComponent(
signature_def_map_(std::move(signature_def_map)),
signature_keys_(std::move(signature_keys)) {}

absl::StatusOr<ExportedModel> CalibrationComponent::ExportToSavedModel(
absl::Status CalibrationComponent::ExportToSavedModel(
ModuleOp module_op, absl::string_view calibration_data_dir,
const bool force_regenerate_calibration_data,
const absl::string_view dst_saved_model_path) {
TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName());

// Clone ModuleOp and function aliases so changes in this pipeline won't
// be reflected in the original values.
mlir::OwningOpRef<mlir::ModuleOp> cloned_module_ref(module_op.clone());

TF_RETURN_IF_ERROR(
RunCalibrationPasses(*cloned_module_ref, *ctx_, calibration_data_dir));
TF_RETURN_IF_ERROR(RunCalibrationPasses(*cloned_module_ref, *ctx_,
calibration_data_dir,
force_regenerate_calibration_data));

const bool is_calibration_required =
IsCalibrationRequired(*cloned_module_ref);
if (!is_calibration_required) return absl::OkStatus();

// `duplicate_shape_determining_constants = false` because the
// resulting graph of this step is not expected to be loaded on TPU.
Expand All @@ -128,13 +156,13 @@ absl::StatusOr<ExportedModel> CalibrationComponent::ExportToSavedModel(
src_saved_model_path_, tags_,
signature_def_map_);

return exported_model;
return absl::OkStatus();
}

absl::StatusOr<ModuleOp> CalibrationComponent::Run(
ModuleOp module_op, const QuantizationConfig& config) {
// Exports the pre-calibrated model to SavedModel.
TF_ASSIGN_OR_RETURN(const std::string precalibrated_saved_model_dir,
// Export the calibration model to SavedModel.
TF_ASSIGN_OR_RETURN(const std::string calibration_saved_model_dir,
CreateTmpDir());

std::string calibration_data_dir =
Expand All @@ -143,29 +171,32 @@ absl::StatusOr<ModuleOp> CalibrationComponent::Run(
TF_ASSIGN_OR_RETURN(calibration_data_dir, CreateTmpDir());
}

TF_ASSIGN_OR_RETURN(ExportedModel exported_model,
ExportToSavedModel(module_op, calibration_data_dir,
precalibrated_saved_model_dir));

// Translates `RepresentativeDatasetConfig`s to signature key ->
// `RepresentativeDatasetFile` mapping.
const auto dataset_configs =
config.calibration_options().representative_datasets();
const std::vector<RepresentativeDatasetConfig> dataset_config_vector(
dataset_configs.begin(), dataset_configs.end());
TF_ASSIGN_OR_RETURN(
const auto representative_dataset_file_map,
CreateRepresentativeDatasetFileMap(dataset_config_vector));

// Runs calibration on the exported model. The statistics will be stored in a
// separate singleton object `CalibratorSingleton` and are directly added to
// `exported_model` without re-importing it.
if (py_function_lib_->RunCalibration(
precalibrated_saved_model_dir, signature_keys_, tags_,
/*force_graph_mode_calibration=*/true,
representative_dataset_file_map) == std::nullopt) {
return absl::InternalError(
"CalibrationComponent error: Failed to run calibration.");
TF_RETURN_IF_ERROR(ExportToSavedModel(
module_op, calibration_data_dir,
config.calibration_options().force_regenerate_calibration_data(),
calibration_saved_model_dir));

TF_ASSIGN_OR_RETURN(std::vector<std::string> calibration_saved_model_files,
ListDirectory(calibration_saved_model_dir));
if (!calibration_saved_model_files.empty()) {
// Translate `RepresentativeDatasetConfig`s to signature key ->
// `RepresentativeDatasetFile` mapping.
const auto dataset_configs =
config.calibration_options().representative_datasets();
const std::vector<RepresentativeDatasetConfig> dataset_config_vector(
dataset_configs.begin(), dataset_configs.end());
TF_ASSIGN_OR_RETURN(
const auto representative_dataset_file_map,
CreateRepresentativeDatasetFileMap(dataset_config_vector));

// Run calibration on the exported model.
if (py_function_lib_->RunCalibration(
calibration_saved_model_dir, signature_keys_, tags_,
/*force_graph_mode_calibration=*/true,
representative_dataset_file_map) == std::nullopt) {
return absl::InternalError(
"CalibrationComponent error: Failed to run calibration.");
}
}

if (absl::Status status = AddCalibrationStatistics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ class CalibrationComponent : public Component {
// Exports `module_op` to SavedModel at `dst_saved_model_path`. This is used
// to export the pre-calibrated `module_op` to SavedModel so that the
// calibration process can use it to load and run the graph with the
// representative dataset.
absl::StatusOr<tensorflow::quantization::ExportedModel> ExportToSavedModel(
ModuleOp module_op, absl::string_view calibration_data_dir,
absl::string_view dst_saved_model_path);
// representative dataset. Returns a failure status if the export fails.
absl::Status ExportToSavedModel(ModuleOp module_op,
absl::string_view calibration_data_dir,
bool force_regenerate_calibration_data,
absl::string_view dst_saved_model_path);

// Imports the SavedModel at `calibrated_saved_model_path` to `ModuleOp` after
// running calibration.
Expand Down Expand Up @@ -113,7 +114,8 @@ class CalibrationComponent : public Component {

// Runs passes to prepare the calibration model.
absl::Status RunCalibrationPasses(mlir::ModuleOp module_op, MLIRContext& ctx,
absl::string_view calibration_data_dir);
absl::string_view calibration_data_dir,
bool force_regenerate_calibration_data);

} // namespace mlir::quant::stablehlo

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ using ::tensorflow::quantization::PyFunctionLibrary;
using CalibrationStatisticsFlatMap =
absl::flat_hash_map<std::string, CalibrationStatistics>;

} // namespace

// Reads the calibration statistics from the given directory.
absl::StatusOr<CalibrationStatisticsFlatMap> ReadStatistics(
absl::string_view calibration_data_dir) {
Expand All @@ -63,8 +65,6 @@ absl::StatusOr<CalibrationStatisticsFlatMap> ReadStatistics(
return statistics_map;
}

} // namespace

absl::Status AddCalibrationStatistics(
mlir::ModuleOp module_op, absl::string_view calibration_data_dir,
const CalibrationOptions& calibration_options,
Expand Down Expand Up @@ -102,4 +102,14 @@ absl::Status AddCalibrationStatistics(
return status;
}

// Reports whether calibration is required for `module_op`. Calibration is
// considered required when the module contains at least one
// `TF::CalibrationStatisticsSaverOp`.
//
// Returns true if such an op is found anywhere in `module_op`, false
// otherwise.
bool IsCalibrationRequired(mlir::ModuleOp module_op) {
  // A single saver op is enough to decide, so interrupt the walk on the first
  // match instead of traversing the rest of the module.
  return module_op
      .walk([](mlir::TF::CalibrationStatisticsSaverOp) {
        return mlir::WalkResult::interrupt();
      })
      .wasInterrupted();
}

} // namespace stablehlo::quantization
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"

namespace stablehlo::quantization {

// Reads the calibration statistics from the given directory.
absl::StatusOr<absl::flat_hash_map<
std::string, tensorflow::calibrator::CalibrationStatistics>>
ReadStatistics(absl::string_view calibration_data_dir);

// Adds calibrated min / max values to CustomAggregator nodes in `graph_def`.
// The min and max values will be added to the "min" and "max" attributes,
// respectively. `calibration_options` provides the strategy to retrieve min and
Expand All @@ -32,6 +42,9 @@ absl::Status AddCalibrationStatistics(
const stablehlo::quantization::CalibrationOptions& calibration_options,
const tensorflow::quantization::PyFunctionLibrary& py_function_library);

// Checks if the model requires calibration, i.e. whether `module_op` contains
// at least one `CalibrationStatisticsSaverOp`.
bool IsCalibrationRequired(mlir::ModuleOp module_op);

} // namespace stablehlo::quantization

#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand Down Expand Up @@ -48,11 +50,14 @@ std::string GetOutputFilePath(absl::string_view calibration_data_dir,
}

// Finds `CustomAggregator` ops and collects their outputs and attributes.
void FindCustomAggregatorOps(Region& region,
SmallVector<Value>& statistics_outputs,
SmallVector<StringRef>& ids,
SmallVector<int32_t>& calibration_methods) {
void FindCustomAggregatorOps(
Region& region,
const std::unordered_set<std::string>& aggregator_ops_to_ignore,
SmallVector<Value>& statistics_outputs, SmallVector<StringRef>& ids,
SmallVector<int32_t>& calibration_methods) {
for (auto op : region.getOps<TF::CustomAggregatorOp>()) {
if (aggregator_ops_to_ignore.count(op.getId().str())) continue;

ids.push_back(op.getId());
calibration_methods.push_back(op.getCalibrationMethod());
statistics_outputs.push_back(op.getMin());
Expand All @@ -63,11 +68,13 @@ void FindCustomAggregatorOps(Region& region,

// Inserts a `CalibrationStatisticsSaverOp` to the end of the region.
LogicalResult InsertCalibrationStatisticsSaverOp(
Region& region, MLIRContext& ctx, absl::string_view output_file_path) {
Region& region, MLIRContext& ctx, absl::string_view output_file_path,
const std::unordered_set<std::string>& aggregator_ops_to_ignore) {
SmallVector<Value> statistics_outputs;
SmallVector<StringRef> ids;
SmallVector<int32_t> calibration_methods;
FindCustomAggregatorOps(region, statistics_outputs, ids, calibration_methods);
FindCustomAggregatorOps(region, aggregator_ops_to_ignore, statistics_outputs,
ids, calibration_methods);
if (statistics_outputs.empty()) return failure();

OpBuilder builder(&ctx);
Expand Down Expand Up @@ -115,6 +122,7 @@ bool ContainCalibrationStatisticsSaverOp(Operation* op) {

} // namespace

#define GEN_PASS_DECL_INSERTCALIBRATIONSTATISTICSSAVERPASS
#define GEN_PASS_DEF_INSERTCALIBRATIONSTATISTICSSAVERPASS
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc"

Expand All @@ -126,29 +134,30 @@ class InsertCalibrationStatisticsSaverPass
InsertCalibrationStatisticsSaverPass>::
InsertCalibrationStatisticsSaverPassBase;

explicit InsertCalibrationStatisticsSaverPass(StringRef calibration_data_dir)
: calibration_data_dir_(calibration_data_dir) {}

private:
std::string calibration_data_dir_;
void runOnOperation() override;
};

void InsertCalibrationStatisticsSaverPass::runOnOperation() {
ModuleOp module_op = getOperation();
MLIRContext& ctx = getContext();

std::unordered_set<std::string> aggregator_ops_to_ignore(
aggregator_ops_to_ignore_.begin(), aggregator_ops_to_ignore_.end());

// Insert CalibrationStatisticsSaverOp to the end of each region.
for (auto func_op : module_op.getOps<func::FuncOp>()) {
int32_t output_file_idx = 0;
StringRef func_name = func_op.getSymName();

func_op.walk([&output_file_idx, &ctx, &func_name, this](Operation* op) {
func_op.walk([&output_file_idx, &ctx, &func_name, &aggregator_ops_to_ignore,
this](Operation* op) {
for (Region& region : op->getRegions()) {
if (succeeded(InsertCalibrationStatisticsSaverOp(
region, ctx,
GetOutputFilePath(calibration_data_dir_, func_name,
output_file_idx)))) {
output_file_idx),
aggregator_ops_to_ignore))) {
++output_file_idx;
};
}
Expand All @@ -167,9 +176,14 @@ void InsertCalibrationStatisticsSaverPass::runOnOperation() {
}

std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCalibrationStatisticsSaverPass(StringRef calibration_data_dir) {
return std::make_unique<InsertCalibrationStatisticsSaverPass>(
calibration_data_dir);
CreateInsertCalibrationStatisticsSaverPass(
StringRef calibration_data_dir,
const std::vector<std::string>& aggregator_ops_to_ignore) {
InsertCalibrationStatisticsSaverPassOptions options = {
.aggregator_ops_to_ignore_ = aggregator_ops_to_ignore,
.calibration_data_dir_ = calibration_data_dir.str(),
};
return std::make_unique<InsertCalibrationStatisticsSaverPass>(options);
}

} // namespace mlir::quant::stablehlo
Loading

0 comments on commit c0027b1

Please sign in to comment.