[go: nahoru, domu]

Skip to content

Commit

Permalink
Remove unused VlogOccupancyInfo calls.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646528556
  • Loading branch information
klucke authored and tensorflower-gardener committed Jun 25, 2024
1 parent 038957a commit a2680c5
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 204 deletions.
8 changes: 6 additions & 2 deletions third_party/xla/build_tools/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def nvidia_gpu_build_with_compute_capability(
JAX_NUM_GENERATED_CASES=25,
JAX_SKIP_SLOW_TESTS=1,
),
options=_DEFAULT_BAZEL_OPTIONS,
options=dict(
**_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla"
),
)

_JAX_GPU_BUILD = Build(
Expand All @@ -294,7 +296,9 @@ def nvidia_gpu_build_with_compute_capability(
TF_CPP_MIN_LOG_LEVEL=0,
JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow",
),
options=_DEFAULT_BAZEL_OPTIONS,
options=dict(
**_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla"
),
)

_KOKORO_JOB_NAME_TO_BUILD_MAP = {
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4796,7 +4796,7 @@ cc_library(
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudnn_header",
]) + if_static([
]) + if_google([
"@com_google_protobuf//:wrappers_cc_proto",
]),
)
Expand Down
90 changes: 0 additions & 90 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,20 +465,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle();

// Only perform/print the occupancy check once. Even just checking to see
// whether we've done an occupancy check on this kernel before isn't free
// (because we have to synchronize), so we only do this at -v 2+.
if (VLOG_IS_ON(2)) {
absl::MutexLock lock(&launched_kernels_mu_);
if (!launched_kernels_.count(cufunc)) {
VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel,
thread_dims, block_dims);
// TODO(rspringer): Remove elements from launched_kernels_...if we ever
// expose a kernel/module deallocation method.
launched_kernels_.insert(cufunc);
}
}

if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
cufunc, cuda_kernel->GetGpuCacheConfig()));
Expand Down Expand Up @@ -547,82 +533,6 @@ absl::Status GpuExecutor::Submit(Stream* stream,
return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream));
}

// This is a non-essential operation; if there's a failure, proceed without
// logging an error. It's nearly certain that in case of failures, we'd never
// get here in the first place; these are very low-impact routines.
void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description,
                                    const Kernel& kernel,
                                    const ThreadDim& thread_dims,
                                    const BlockDim& block_dims) {
  VLOG(2) << "Computing kernel occupancy for kernel "
          << kernel.demangled_name();
  VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y
          << ", " << thread_dims.z << ")";

  auto regs_per_thread = kernel.metadata().registers_per_thread();
  auto smem_per_block = kernel.metadata().shared_memory_bytes();

  // Both metadata fields are dereferenced below, so bail out unless BOTH are
  // present. (The previous `&&` test only returned when both were absent and
  // would dereference an unset optional when exactly one was populated.)
  if (!regs_per_thread || !smem_per_block) {
    return;
  }

  const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
  CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle();

  // Occupancy as reported by the CUDA driver: resident blocks per SM for the
  // launch configuration actually requested.
  int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread,
                                         *smem_per_block, thread_dims, cufunc);
  VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;

  // Ask the driver for a better block size; a non-zero result means the
  // calculator found a configuration with strictly higher occupancy.
  int suggested_threads =
      CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread,
                       *smem_per_block, thread_dims, cufunc);
  if (suggested_threads != 0) {
    VLOG(2) << "The cuda occupancy calculator recommends using "
            << suggested_threads
            << " threads per block to achieve an occupancy of " << blocks_per_sm
            << " blocks per SM.";
  }
}

// Compute and return maximum blocks per core (occupancy) based on the
// device description, some kernel characteristics and the number of threads per
// block. If unable to compute occupancy, zero is returned.
int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description,
                                    uint64_t registers_per_thread,
                                    uint64_t shared_memory_per_block,
                                    const ThreadDim& thread_dims,
                                    CUfunction func) {
  int suggested_blocks = 0;
  int suggested_threads = 0;
  CUresult err = cuOccupancyMaxPotentialBlockSize(
      &suggested_blocks, &suggested_threads, func, nullptr,
      shared_memory_per_block, 0);
  // This helper only feeds VLOG output, and its contract promises zero when
  // occupancy can't be computed — so a driver error must not CHECK-crash the
  // process (the previous CHECK_EQ aborted here). Fail soft instead.
  if (err != CUDA_SUCCESS) {
    return 0;
  }
  return suggested_blocks;
}

// Compute and return the suggested thread count to achieve ideal occupancy.
// Returns zero when the occupancy query fails or when no configuration with
// strictly better occupancy than *initial_blocks exists; otherwise updates
// *initial_blocks to the improved block count.
int GpuExecutor::CompareOccupancy(int* initial_blocks,
                                  const DeviceDescription& device_description,
                                  uint64_t registers_per_thread,
                                  uint64_t shared_memory_per_block,
                                  const ThreadDim& thread_dims,
                                  CUfunction func) {
  int suggested_blocks = 0;
  int suggested_threads = 0;
  CUresult err = cuOccupancyMaxPotentialBlockSize(
      &suggested_blocks, &suggested_threads, func, nullptr,
      shared_memory_per_block, 0);
  // Best-effort logging helper: a driver error here must not CHECK-crash the
  // process (the previous CHECK_EQ aborted). Treat failure as "no suggestion".
  if (err != CUDA_SUCCESS) {
    return 0;
  }
  if (suggested_blocks > *initial_blocks) {
    *initial_blocks = suggested_blocks;
    return suggested_threads;
  } else {
    return 0;
  }
}

DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) {
if (memory_space == 1) {
auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size);
Expand Down
25 changes: 0 additions & 25 deletions third_party/xla/xla/stream_executor/gpu/gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,6 @@ class GpuExecutor : public StreamExecutorCommon {
absl::Status Submit(Stream* stream,
const CommandBuffer& command_buffer) override;

int CalculateOccupancy(const DeviceDescription& device_description,
uint64_t registers_per_thread,
uint64_t shared_memory_per_block,
const ThreadDim& thread_dims, GpuFunctionHandle func);

int CompareOccupancy(int* initial_blocks,
const DeviceDescription& device_description,
uint64_t registers_per_thread,
uint64_t shared_memory_per_block,
const ThreadDim& thread_dims, GpuFunctionHandle func);

DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;

void Deallocate(DeviceMemoryBase* mem) override;
Expand Down Expand Up @@ -320,12 +309,6 @@ class GpuExecutor : public StreamExecutorCommon {
absl::Status GetKernelMetadata(GpuKernel* cuda_kernel,
KernelMetadata* kernel_metadata);

// Prints to VLOG(2) information about the kernel's occupancy and how it might
// be improved.
void VlogOccupancyInfo(const DeviceDescription& device_description,
const Kernel& kernel, const ThreadDim& thread_dims,
const BlockDim& block_dims);

// (supported on CUDA only)
absl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module)
TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
Expand Down Expand Up @@ -377,14 +360,6 @@ class GpuExecutor : public StreamExecutorCommon {
std::unordered_map<const void*, std::pair<GpuModuleHandle, uint64_t>>
gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_);

// Guards the launched kernel set.
absl::Mutex launched_kernels_mu_;

// Keeps track of the set of launched kernels. Currently used to suppress the
// occupancy check on subsequent launches.
std::set<GpuFunctionHandle> launched_kernels_
ABSL_GUARDED_BY(launched_kernels_mu_);

// Handle for the CUDA device being operated on. Immutable
// post-initialization.
GpuDeviceHandle device_;
Expand Down
86 changes: 0 additions & 86 deletions third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,20 +337,6 @@ absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
const GpuKernel* rocm_kernel = AsGpuKernel(&kernel);
hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle();

// Only perform/print the occupancy check once. Even just checking to see
// whether we've done an occupancy check on this kernel before isn't free
// (because we have to synchronize), so we only do this at -v 2+.
if (VLOG_IS_ON(2)) {
absl::MutexLock lock(&launched_kernels_mu_);
if (!launched_kernels_.count(hipfunc)) {
VlogOccupancyInfo(stream->parent()->GetDeviceDescription(), kernel,
thread_dims, block_dims);
// TODO(rspringer): Remove elements from launched_kernels_...if we ever
// expose a kernel/module deallocation method.
launched_kernels_.insert(hipfunc);
}
}

if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) {
TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig(
hipfunc, rocm_kernel->GetGpuCacheConfig()));
Expand Down Expand Up @@ -458,78 +444,6 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco,
return absl::OkStatus();
}

// This is a non-essential operation; if there's a failure, proceed without
// logging an error. It's nearly certain that in case of failures, we'd never
// get here in the first place; these are very low-impact routines.
void GpuExecutor::VlogOccupancyInfo(const DeviceDescription& device_description,
                                    const Kernel& kernel,
                                    const ThreadDim& thread_dims,
                                    const BlockDim& block_dims) {
  VLOG(2) << "Computing kernel occupancy for kernel "
          << kernel.demangled_name();
  VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y
          << ", " << thread_dims.z << ")";

  auto regs_per_thread = kernel.metadata().registers_per_thread();
  auto smem_per_block = kernel.metadata().shared_memory_bytes();

  // Both metadata fields are dereferenced below, so bail out unless BOTH are
  // present. (The previous `&&` test only returned when both were absent and
  // would dereference an unset optional when exactly one was populated.)
  if (!regs_per_thread || !smem_per_block) {
    return;
  }

  const GpuKernel* rocm_kernel = AsGpuKernel(&kernel);
  auto hipfunc = rocm_kernel->AsGpuFunctionHandle();

  // Occupancy as reported by the ROCm occupancy API: resident blocks per SM
  // for the launch configuration actually requested.
  int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread,
                                         *smem_per_block, thread_dims, hipfunc);
  VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;

  // Ask the runtime for a better block size; a non-zero result means the
  // calculator found a configuration with strictly higher occupancy.
  int suggested_threads =
      CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread,
                       *smem_per_block, thread_dims, hipfunc);
  if (suggested_threads != 0) {
    VLOG(2) << "The rocm occupancy calculator recommends using "
            << suggested_threads
            << " threads per block to achieve an occupancy of " << blocks_per_sm
            << " blocks per SM.";
  }
}

// Reports the maximum number of resident blocks per multiprocessor for the
// given kernel function, as suggested by the ROCm occupancy API. Returns zero
// when the occupancy query fails (the query is best-effort and its status is
// deliberately discarded). `device_description`, `registers_per_thread`, and
// `thread_dims` are accepted for interface parity with the CUDA counterpart
// but are not consulted here.
int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description,
                                    uint64_t registers_per_thread,
                                    uint64_t shared_memory_per_block,
                                    const ThreadDim& thread_dims,
                                    GpuFunctionHandle func) {
  // Pre-zero both outputs so a failed query yields the documented "zero"
  // result without any branching on the status code.
  int max_blocks = 0;
  int unused_threads = 0;
  (void)rocm::OccupancyGetMaxPotentialBlockSize(
      &max_blocks, &unused_threads, func, shared_memory_per_block, 0);
  return max_blocks;
}

// Asks the ROCm occupancy API for a block size that would improve occupancy
// beyond *initial_blocks. On improvement, updates *initial_blocks and returns
// the recommended threads-per-block; otherwise returns zero. The query is
// best-effort: its status is discarded, and a failure simply yields zero.
int GpuExecutor::CompareOccupancy(int* initial_blocks,
                                  const DeviceDescription& device_description,
                                  uint64_t registers_per_thread,
                                  uint64_t shared_memory_per_block,
                                  const ThreadDim& thread_dims,
                                  GpuFunctionHandle func) {
  int recommended_blocks = 0;
  int recommended_threads = 0;
  (void)rocm::OccupancyGetMaxPotentialBlockSize(
      &recommended_blocks, &recommended_threads, func, shared_memory_per_block,
      0);
  // Guard clause: no strictly better configuration found (or the query
  // failed, leaving the pre-zeroed outputs untouched).
  if (recommended_blocks <= *initial_blocks) {
    return 0;
  }
  *initial_blocks = recommended_blocks;
  return recommended_threads;
}

DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) {
if (memory_space ==
static_cast<int64_t>(stream_executor::MemoryType::kHost)) {
Expand Down

0 comments on commit a2680c5

Please sign in to comment.