[go: nahoru, domu]

Skip to content

Commit

Permalink
Added a fingerprint field to PjRtStreamExecutorLoadedExecutable to avoid recalculating fingerprints when FingerprintExecutable() is called.
Browse files Browse the repository at this point in the history
This change significantly reduces idle time before execution when the GPU load tracker enqueues an executable.

PiperOrigin-RevId: 642397394
  • Loading branch information
tensorflower-gardener committed Jun 17, 2024
1 parent 8b5f1d5 commit ccac76a
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 22 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/xla_launch_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ Status RunPjRtExecutable(
pjrt_client->LookupAddressableDevice(pjrt_device_id));

gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr;
if (device_type == DEVICE_GPU && gpu::kUseGpuServingDeviceSelector) {
if (device_type == DEVICE_GPU) {
auto rm = ctx->resource_manager();
TF_RETURN_IF_ERROR(rm->LookupOrCreate<
gpu::GpuServingDeviceSelectorResource>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ namespace gpu {
class GpuServingDeviceSelector;
const char kGpuServingDeviceSelectorResourceName[] =
"gpu_serving_device_selector";
// TODO(b/335729939): Disable GPU load tracker for performance regression
// investigation. Remove when fixed.
const bool kUseGpuServingDeviceSelector = false;

class GpuServingDeviceSelectorResource : public ResourceBase {
public:
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_quantize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ class MklQuantizeV2Op : public OpKernel {

void Compute(OpKernelContext* ctx) override {
const unsigned int src_idx = 0;
const Tensor& input = ctx->input(src_idx);
const float input_min_range = ctx->input(1).scalar<float>()();
const float input_max_range = ctx->input(2).scalar<float>()();
float min_range = std::min(0.0f, input_min_range);
Expand Down
18 changes: 2 additions & 16 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2290,6 +2290,8 @@ PjRtStreamExecutorLoadedExecutable::PjRtStreamExecutorLoadedExecutable(
parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
computation_layout.parameter_shape(i)));
}
fingerprint_ = absl::StrCat(
fingerprint_, executable->executable()->module().GetFingerprint128());
executables_.emplace_back(std::move(executable));
on_device_executable_parameter_shapes_.push_back(
std::move(parameter_shapes));
Expand Down Expand Up @@ -3245,22 +3247,6 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const {
return Unimplemented("GetOutputMemoryKinds is not supported.");
}

// Produces a fingerprint string for this loaded executable.
//
// Only the single-executable case is supported. In that case the result is
// whatever GetFingerprint128() reports for the HLO module attached to the
// sole executable.
//
// Returns absl::InternalError when:
//   * executables_ does not hold exactly one entry (including zero entries), or
//   * the sole executable reports has_module() == false.
absl::StatusOr<std::string>
PjRtStreamExecutorLoadedExecutable::FingerprintExecutable() const {
  if (executables_.size() == 1) {
    auto* const sole_executable = executables_.front()->executable();
    // module() is only read after has_module() confirms a module is present.
    if (!sole_executable->has_module()) {
      return absl::InternalError("Executable does not have HLO modules.");
    }
    return sole_executable->module().GetFingerprint128();
  }
  return absl::InternalError(
      "Fingerprinting multiple executables within one "
      "PjRtStreamExecutorLoadedExecutable is not supported.");
}

absl::StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
ExecutableExtras extras;
Expand Down
5 changes: 4 additions & 1 deletion third_party/xla/xla/pjrt/pjrt_stream_executor_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,9 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
return compile_options_;
}

absl::StatusOr<std::string> FingerprintExecutable() const override;
absl::StatusOr<std::string> FingerprintExecutable() const override {
return fingerprint_;
};

protected:
bool parameter_is_tupled_arguments() const {
Expand Down Expand Up @@ -1078,6 +1080,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
// unique_ptrs to play well with the Python bindings (see xla.cc).
std::vector<PjRtDevice*> addressable_devices_;
std::string fingerprint_;
};

} // namespace xla
Expand Down

0 comments on commit ccac76a

Please sign in to comment.