[NVIDIA TF] Support building against CUDA 12.0 #58867

Merged
merged 11 commits on Jan 18, 2023
Reduce memory overheads in sparse-sparse matmul.
Memory reduction comes at the cost of an additional device-side copy
to concat the gemm results across the batch.
nluehr committed Jan 4, 2023
commit 51cb95a2ce37988f0d6bb6f100ffb0cfdfaa8291
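
For orientation, the buffer-reuse pattern introduced by this commit can be modeled in plain C++: one grow-only scratch buffer is reused across the batch loop instead of keeping a separate workspace alive per batch element, and the per-batch outputs are gathered and concatenated once at the end. The sketch below is illustrative only; BatchResult, Multiply, and the byte sizes are hypothetical stand-ins for the cuSPARSE SpGEMM calls and device buffers in the actual kernel.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one batch element's SpGEMM output.
struct BatchResult {
  std::vector<int32_t> col_ind;
  std::vector<float> values;
};

// Hypothetical per-element multiply; in the real kernel this is the
// SpGEMM_workEstimation / SpGEMM_compute / SpGEMM_copy sequence, and
// 'scratch' plays the role of buffer1/buffer2.
BatchResult Multiply(int batch_index, std::vector<uint8_t>* scratch) {
  const std::size_t required = 1024u * (batch_index + 1);     // size the library reports
  if (required > scratch->size()) scratch->resize(required);  // grow-only reuse
  return BatchResult{{0, 1}, {1.0f, 2.0f}};
}

int main() {
  const int batch_size = 4;
  std::vector<uint8_t> scratch;      // reused across the whole batch
  std::vector<BatchResult> partial;  // per-batch outputs kept until the concat
  partial.reserve(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    partial.push_back(Multiply(i, &scratch));
  }

  // Single concatenation into the final flat buffers: the extra
  // device-side copy mentioned in the commit message.
  std::vector<int32_t> c_col_ind;
  std::vector<float> c_values;
  for (const BatchResult& r : partial) {
    c_col_ind.insert(c_col_ind.end(), r.col_ind.begin(), r.col_ind.end());
    c_values.insert(c_values.end(), r.values.begin(), r.values.end());
  }
  return c_col_ind.size() == 8u ? 0 : 1;
}

Compared with the previous version, which kept batch_size buffer2 tensors and per-element descriptors alive until the final outputs were allocated, this trades that retained memory for one extra copy when the partial results are concatenated.
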
196 changes: 117 additions & 79 deletions tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc
@@ -22,18 +22,19 @@ limitations under the License.
#include <memory>
#include <numeric>

#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"
#include "third_party/eigen3/Eigen/SparseCore"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/util/cuda_sparse.h"
@@ -57,6 +58,28 @@ void SwapDimSizes(const int dim_a, const int dim_b, TensorShape* shape) {
shape->set_dim(dim_b, size_a);
}

#if GOOGLE_CUDA

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T, all inputs
// have more than zero elements, and the output is preallocated.
template <typename T>
void ConcatHelper(OpKernelContext* context, const std::vector<Tensor>& inputs,
Tensor* output) {
// ConcatGPU expects 2D {1, vec_size} shapes.
std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
inputs_flat.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
const Tensor& input = inputs[i];
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
input.shaped<T, 2>({1, input.NumElements()})));
}
auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
ConcatGPU<T>(context, inputs_flat, output, &output_flat);
}

#endif

} // namespace

// Op to compute the matrix multiplication of two CSR Sparse Matrices.
@@ -417,125 +440,140 @@ class CSRSparseMatMulGPUOp : public OpKernel {
GpuSparse cudaSparse(ctx);
OP_REQUIRES_OK(ctx, cudaSparse.Initialize());

// The SpGEMMDescr contains intermediate results so a separate descr must be
// created for each multiply within the batch.
std::vector<GpuSparseSpGEMMDescr> gemmDesc(batch_size);
std::vector<GpuSparseConstSpMatDescr> matA(batch_size);
std::vector<GpuSparseConstSpMatDescr> matB(batch_size);
std::vector<GpuSparseSpMatDescr> matC(batch_size);
size_t maxBufferSize1 = 0;
// Intermediate products, to be concatenated into final result
std::vector<Tensor> colidx_vec, values_vec;
colidx_vec.reserve(batch_size);
values_vec.reserve(batch_size);

// Temporary buffers reused across batch
Tensor buffer1_t, buffer2_t;

// Compute intermediate results
for (int i = 0; i < batch_size; ++i) {
OP_REQUIRES_OK(ctx, gemmDesc[i].Initialize());
GpuSparseSpGEMMDescr gemmDesc;
GpuSparseConstSpMatDescr matA;
GpuSparseConstSpMatDescr matB;
GpuSparseSpMatDescr matC;
OP_REQUIRES_OK(ctx, gemmDesc.Initialize());
OP_REQUIRES_OK(ctx,
matA[i].InitializeCsr(
matA.InitializeCsr(
a_input_dense_shape(a_input_dense_shape.size() - 2),
a_input_dense_shape(a_input_dense_shape.size() - 1),
a_input_matrix->col_indices_vec(i).size(),
a_input_matrix->row_pointers_vec(i).data(),
a_input_matrix->col_indices_vec(i).data(),
a_input_matrix->values_vec<T>(i).data()));
OP_REQUIRES_OK(ctx,
matB[i].InitializeCsr(
matB.InitializeCsr(
b_input_dense_shape(b_input_dense_shape.size() - 2),
b_input_dense_shape(b_input_dense_shape.size() - 1),
b_input_matrix->col_indices_vec(i).size(),
b_input_matrix->row_pointers_vec(i).data(),
b_input_matrix->col_indices_vec(i).data(),
b_input_matrix->values_vec<T>(i).data()));
OP_REQUIRES_OK(ctx,
matC[i].InitializeCsr<int, T>(
matC.InitializeCsr<int, T>(
a_input_dense_shape(a_input_dense_shape.size() - 2),
b_input_dense_shape(b_input_dense_shape.size() - 1), 0,
nullptr, nullptr, nullptr));

// Determine the maximum size for buffer1 across the batch.
// buffer1 will be reused across the batch as each individual gemm will
// overwrite it in the second workEstimation call below.
// Check required size for buffer1 and possibly re-allocate
size_t bufferSize1;
OP_REQUIRES_OK(ctx, cudaSparse.SpGEMM_workEstimation<T>(
matA[i], matB[i], matC[i], gemmDesc[i],
&bufferSize1, nullptr));
if (bufferSize1 > maxBufferSize1) {
maxBufferSize1 = bufferSize1;
OP_REQUIRES_OK(
ctx, cudaSparse.SpGEMM_workEstimation<T>(matA, matB, matC, gemmDesc,
&bufferSize1, nullptr));
if (bufferSize1 > buffer1_t.NumElements()) {
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64_t>(bufferSize1)}),
&buffer1_t));
}
} // End for over batch_size
void* buffer1 = buffer1_t.flat<int8>().data();

Tensor buffer1Tensor;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64_t>(maxBufferSize1)}),
&buffer1Tensor));
void* buffer1 = buffer1Tensor.flat<int8>().data();

// Buffer2 is generated along with the number of non-zero elements (NNZ) and
// (along with the gemmDesc) contains intermediate products needed for
// writing the final output. Because the final outputs cannot be allocated
// until the NNZ is computed for the entire batch, we need to keep
// batch_size buffer2 temporaries.
std::vector<Tensor> buffer2Tensors(batch_size);
for (int i = 0; i < batch_size; ++i) {
// Do workEstimation reusing buffer1 across batch.
OP_REQUIRES_OK(ctx, cudaSparse.SpGEMM_workEstimation<T>(
matA[i], matB[i], matC[i], gemmDesc[i],
&maxBufferSize1, buffer1));
// Do workEstimation using buffer1.
// buffer1 implicitly captured in gemmDesc for use in the compute call.
OP_REQUIRES_OK(
ctx, cudaSparse.SpGEMM_workEstimation<T>(matA, matB, matC, gemmDesc,
&bufferSize1, buffer1));

// Compute size for buffer2 in this batch element.
// Compute size for buffer2 and possibly re-allocate
size_t bufferSize2;
OP_REQUIRES_OK(ctx, cudaSparse.SpGEMM_compute<T>(matA[i], matB[i],
matC[i], gemmDesc[i],
&bufferSize2, nullptr));
OP_REQUIRES_OK(ctx,
cudaSparse.SpGEMM_compute<T>(matA, matB, matC, gemmDesc,
&bufferSize2, nullptr));
if (bufferSize2 > buffer2_t.NumElements()) {
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64_t>(bufferSize2)}),
&buffer2_t));
}
void* buffer2 = buffer2_t.flat<int8>().data();

OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64_t>(bufferSize2)}),
&buffer2Tensors[i]));
void* buffer2 = buffer2Tensors[i].flat<int8>().data();
// Populate buffer2 and gemmDesc[i] for this product.
// Note that gemmDesc[i] captures a reference to buffer2 which gets used
// during the SpGEMM_copy() below. So we must preserve buffer2 until the
// copy is complete.
OP_REQUIRES_OK(ctx, cudaSparse.SpGEMM_compute<T>(matA[i], matB[i],
matC[i], gemmDesc[i],
&bufferSize2, buffer2));
// Compute the gemm.
// Note that buffer1 is implicitly consumed here and buffer2 is implicitly
captured for use by the copy call.
OP_REQUIRES_OK(ctx,
cudaSparse.SpGEMM_compute<T>(matA, matB, matC, gemmDesc,
&bufferSize2, buffer2));

// Get output dimensions and update batch pointer.
int64_t cRows, cCols, cNnz;
OP_REQUIRES(
ctx,
cusparseSpMatGetSize(matC[i].get(), &cRows, &cCols, &cNnz) ==
cusparseSpMatGetSize(matC.get(), &cRows, &cCols, &cNnz) ==
CUSPARSE_STATUS_SUCCESS,
errors::Internal("Failed to obtain dimensions from SpMatDescr."));
c_batch_ptr(i + 1) = c_batch_ptr(i) + cNnz;
}

Tensor c_col_ind_t;
Tensor c_values_t;
Tensor colidx_tmp, values_tmp;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(DT_INT32, TensorShape({cNnz}), &colidx_tmp));
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({cNnz}), &values_tmp));

// Copy product to final c_row_ptr and intermediate column and values
// tensors.
void* row_ptr = &c_row_ptr(i * (rows + 1));
void* col_ptr = colidx_tmp.flat<int32>().data();
void* val_ptr = values_tmp.flat<T>().data();
cusparseStatus_t cusp_status =
cusparseCsrSetPointers(matC.get(), row_ptr, col_ptr, val_ptr);
OP_REQUIRES(
ctx, cusp_status == CUSPARSE_STATUS_SUCCESS,
errors::Internal("Failed to update CSR pointers in SpMatDesc."));
OP_REQUIRES_OK(ctx,
cudaSparse.SpGEMM_copy<T>(matA, matB, matC, gemmDesc));

const int total_nnz = c_batch_ptr(batch_size);
// We don't record empty column index or value tensors because Concat
// expects only non-empty inputs.
if (cNnz != 0) {
colidx_vec.emplace_back(std::move(colidx_tmp));
values_vec.emplace_back(std::move(values_tmp));
}
} // End for over batch_size

// Create final buffers
Tensor c_col_ind_t, c_values_t;
int total_nnz = c_batch_ptr(batch_size);
if (colidx_vec.size() == 1) {
c_col_ind_t = std::move(colidx_vec[0]);
c_values_t = std::move(values_vec[0]);
} else if (total_nnz > 0) {
// Multiple intermediates must be concatenated together
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
&c_col_ind_t));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({total_nnz}), &c_values_t));
ConcatHelper<int>(ctx, colidx_vec, &c_col_ind_t);
ConcatHelper<T>(ctx, values_vec, &c_values_t);
}

OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
&c_col_ind_t));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({total_nnz}), &c_values_t));
OP_REQUIRES_OK(ctx,
CSRSparseMatrix::CreateCSRSparseMatrix(
DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t,
c_row_ptr_t, c_col_ind_t, c_values_t, &c));

for (int i = 0; i < batch_size; ++i) {
cusparseStatus_t cusp_status = cusparseCsrSetPointers(
matC[i].get(), c.row_pointers_vec(i).data(),
c.col_indices_vec(i).data(), c.values_vec<T>(i).data());
OP_REQUIRES(
ctx, cusp_status == CUSPARSE_STATUS_SUCCESS,
errors::Internal("Failed to update CSR pointers in SpMatDesc."));
// Copy the final result into c from gemmDesc and (implicitly)
// buffer2Tensors[i] which gemmDesc captured during SpGEMM_compute.
OP_REQUIRES_OK(ctx, cudaSparse.SpGEMM_copy<T>(matA[i], matB[i], matC[i],
gemmDesc[i]));
}

#else
functor::CSRSparseSparseMatrixMatMul<Device, T> csr_gemm(
ctx, /*transpose_a=*/false, /*adjoint_a=*/false, /*transpose_b=*/false);