[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA TF] Support building against CUDA 12.0 #58867

Merged
merged 11 commits into from
Jan 18, 2023
Prev Previous commit
Load the CUPTI DSO using the correct version.
Fixes an issue where the minor version was incorrectly included in the DSO
name with CUDA 12.
  • Loading branch information
nluehr committed Jan 10, 2023
commit 4a04c65383f333fc23d70dc72e8a76b605ccc465
3 changes: 2 additions & 1 deletion tensorflow/tsl/platform/default/dso_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace internal {
namespace {
// Version accessors for the GPU libraries TensorFlow was configured against.
// Each one converts a TF_*_VERSION string-literal macro (defined outside this
// file) into a `string`. The DSO loader uses these values to form versioned
// library names. CUPTI has its own accessor so that libcupti is looked up by
// its own version instead of the CUDA toolkit version.
string GetCudaVersion() {
  return string(TF_CUDA_VERSION);
}
string GetCudaRtVersion() {
  return string(TF_CUDART_VERSION);
}
string GetCuptiVersion() {
  return string(TF_CUPTI_VERSION);
}
string GetCudnnVersion() {
  return string(TF_CUDNN_VERSION);
}
string GetCublasVersion() {
  return string(TF_CUBLAS_VERSION);
}
string GetCusolverVersion() {
  return string(TF_CUSOLVER_VERSION);
}
Expand Down Expand Up @@ -113,7 +114,7 @@ StatusOr<void*> GetCurandDsoHandle() {

StatusOr<void*> GetCuptiDsoHandle() {
// Load specific version of CUPTI this is built.
auto status_or_handle = GetDsoHandle("cupti", GetCudaVersion());
auto status_or_handle = GetDsoHandle("cupti", GetCuptiVersion());
if (status_or_handle.ok()) return status_or_handle;
// Load whatever libcupti.so user specified.
return GetDsoHandle("cupti", "");
Expand Down
1 change: 1 addition & 0 deletions third_party/gpus/cuda/cuda_config.h.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

// Version strings for the CUDA libraries. Each quoted placeholder below is
// replaced textually by cuda_configure.bzl when the repository is configured,
// using values from cuda_config (e.g. cuda_config.cupti_version). Another
// substitution path in that file maps every placeholder to an empty string,
// presumably for a non-CUDA configuration (confirm).
// Keep the placeholder spellings exactly in sync with the keys in
// cuda_configure.bzl; a mismatch leaves the raw placeholder text in the header.
#define TF_CUDA_VERSION "%{cuda_version}"
#define TF_CUDART_VERSION "%{cudart_version}"
// CUPTI is versioned separately from the CUDA toolkit. dso_loader.cc uses this
// value to build the libcupti DSO name.
#define TF_CUPTI_VERSION "%{cupti_version}"
#define TF_CUBLAS_VERSION "%{cublas_version}"
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
#define TF_CURAND_VERSION "%{curand_version}"
Expand Down
2 changes: 2 additions & 0 deletions third_party/gpus/cuda_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ filegroup(name="cudnn-include")
{
"%{cuda_version}": "",
"%{cudart_version}": "",
"%{cupti_version}": "",
"%{cublas_version}": "",
"%{cusolver_version}": "",
"%{curand_version}": "",
Expand Down Expand Up @@ -1333,6 +1334,7 @@ def _create_local_cuda_repository(repository_ctx):
{
"%{cuda_version}": cuda_config.cuda_version,
"%{cudart_version}": cuda_config.cudart_version,
"%{cupti_version}": cuda_config.cupti_version,
"%{cublas_version}": cuda_config.cublas_version,
"%{cusolver_version}": cuda_config.cusolver_version,
"%{curand_version}": cuda_config.curand_version,
Expand Down