[fix] use opmath_t for activation functions in Activation.cu (#77949)
khushi-411 authored and pytorchmergebot committed May 22, 2022
1 parent 57fab66 commit 8da4993
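
The diff below replaces at::acc_type<scalar_t, true> with at::opmath_type<scalar_t> in the CUDA activation kernels and converts the scalar arguments (min, max, beta, threshold, negval) to that type, so for reduced-precision inputs (Half, BFloat16) the per-element arithmetic runs in float and only the final result is narrowed back to scalar_t. A minimal sketch of the pattern, assuming the ATen CUDA build environment; silu_like is an illustrative name, not code from this commit:

#include <ATen/OpMathType.h>
#include <cmath>

template <typename scalar_t>
__device__ scalar_t silu_like(scalar_t x) {
  // For Half/BFloat16, opmath_t is float; for float/double it is the type itself.
  using opmath_t = at::opmath_type<scalar_t>;
  const opmath_t x_acc = static_cast<opmath_t>(x);              // widen once on load
  const opmath_t y = x_acc / (opmath_t(1) + std::exp(-x_acc));  // do the math in opmath_t
  return static_cast<scalar_t>(y);                              // narrow once on return
}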
Showing 1 changed file with 114 additions and 90 deletions.
204 changes: 114 additions & 90 deletions aten/src/ATen/native/cuda/Activation.cu
@@ -24,12 +24,12 @@ namespace native {
// -----------------------------------
void glu_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
using acc_t = at::acc_type<scalar_t, true>;
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
const acc_t a = a_;
const acc_t b = b_;
const acc_t one = acc_t(1);
const acc_t sigmoid = one / (one + std::exp(-b));
const opmath_t a = a_;
const opmath_t b = b_;
const opmath_t one = opmath_t(1);
const opmath_t sigmoid = one / (one + std::exp(-b));
return a * sigmoid;
});
});
@@ -40,19 +40,19 @@ void glu_kernel(TensorIteratorBase& iter) {
// -----------------------------------
void glu_jvp_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
using acc_t = at::acc_type<scalar_t, true>;
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter, [] GPU_LAMBDA (
scalar_t res_,
scalar_t b_,
scalar_t da_,
scalar_t db_) -> scalar_t {
const acc_t res = res_;
const acc_t b = b_;
const acc_t da = da_;
const acc_t db = db_;
const acc_t one = acc_t(1);
const opmath_t res = res_;
const opmath_t b = b_;
const opmath_t da = da_;
const opmath_t db = db_;
const opmath_t one = opmath_t(1);

const acc_t sig_b = one / (one + std::exp(-b));
const opmath_t sig_b = one / (one + std::exp(-b));
return (
da * sig_b + res * (db - sig_b * db)
);
@@ -80,7 +80,7 @@ __global__ void glu_backward_kernel(
int numel, scalar_t* gI, const scalar_t* I, const scalar_t* gO,
OffsetCalc offset_calculator,
int64_t gI_byte_offset, int64_t I_byte_offset) {
using acc_t = at::acc_type<scalar_t, true>;
using opmath_t = at::opmath_type<scalar_t>;

const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
if (linear_index >= numel) {
@@ -91,12 +91,12 @@ __global__ void glu_backward_kernel(
// We explicitly iterate over the first half of the input tensor, and
// gI_byte_offset and I_byte_offset are the offsets to access the
// corresponding index in the second half of the tensor.
const acc_t a = I[offsets[1]];
const acc_t b = *byte_offset(I + offsets[1], I_byte_offset);
const acc_t gO_val = gO[offsets[2]];
const opmath_t a = I[offsets[1]];
const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset);
const opmath_t gO_val = gO[offsets[2]];

const auto one = acc_t(1);
const acc_t sigmoid = one / (one + std::exp(-b));
const auto one = opmath_t(1);
const opmath_t sigmoid = one / (one + std::exp(-b));

auto* gA = gI + offsets[0];
*gA = sigmoid * gO_val;
@@ -349,32 +349,43 @@ void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) {
}

void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.dtype(), "hardtanh_backward_cuda", [&]() {
auto min_val = min.to<scalar_t>();
auto max_val = max.to<scalar_t>();
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half,
iter.dtype(), "hardtanh_backward_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
auto min_val = min.to<opmath_t>();
auto max_val = max.to<opmath_t>();
gpu_kernel(iter, [min_val, max_val]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return (b <= min_val) || (b >= max_val) ? scalar_t(0) : a;
opmath_t aop = static_cast<opmath_t>(a);
opmath_t bop = static_cast<opmath_t>(b);
return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop;
});
});
}

void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_cuda", [&]() {
auto beta = beta_.to<scalar_t>();
auto threshold = threshold_.to<scalar_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(), "softplus_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
auto beta = beta_.to<opmath_t>();
auto threshold = threshold_.to<opmath_t>();
gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
opmath_t aop = static_cast<opmath_t>(a);
return (aop * beta) > threshold ? aop : (::log1p(std::exp(aop * beta))) / beta;
});
});
}

void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_backward_cuda", [&]() {
auto beta = beta_.to<scalar_t>();
auto threshold = threshold_.to<scalar_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(), "softplus_backward_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
auto beta = beta_.to<opmath_t>();
auto threshold = threshold_.to<opmath_t>();
gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
scalar_t z = std::exp(b * beta);
return (b * beta) > threshold ? a : a * z / (z + scalar_t(1.));
opmath_t aop = static_cast<opmath_t>(a);
opmath_t bop = static_cast<opmath_t>(b);
opmath_t z = std::exp(bop * beta);
return (bop * beta) > threshold ? aop : aop * z / (z + opmath_t(1.));
});
});
}
@@ -494,49 +505,56 @@ void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
namespace {

void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_cuda", [&]() {
auto negval = negval_.to<scalar_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(), "leaky_relu_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
auto negval = negval_.to<opmath_t>();
gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t {
return a > scalar_t(0) ? a : a * negval;
opmath_t aop = static_cast<opmath_t>(a);
return aop > opmath_t(0) ? aop : aop * negval;
});
});
}

void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_backward_cuda", [&]() {
auto negval = negval_.to<scalar_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(), "leaky_relu_backward_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
auto negval = negval_.to<opmath_t>();
gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a > scalar_t(0) ? b : b * negval;
opmath_t aop = static_cast<opmath_t>(a);
opmath_t bop = static_cast<opmath_t>(b);
return aop > opmath_t(0) ? bop : bop * negval;
});
});
}

void hardswish_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC zero(0.0f);
const T_ACC one_sixth(1.0f / 6.0f);
const T_ACC three(3.0f);
const T_ACC six(6.0f);
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t zero(0.0f);
const opmath_t one_sixth(1.0f / 6.0f);
const opmath_t three(3.0f);
const opmath_t six(6.0f);
gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t {
T_ACC x = static_cast<T_ACC>(self_val);
opmath_t x = static_cast<opmath_t>(self_val);
return x * std::min(std::max(x + three, zero), six) * one_sixth;
});
});
}

void hardswish_backward_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC zero(0.0f);
const T_ACC three(3.0f);
const T_ACC neg_three(-3.0f);
const T_ACC one_half(0.5f);
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t zero(0.0f);
const opmath_t three(3.0f);
const opmath_t neg_three(-3.0f);
const opmath_t one_half(0.5f);
gpu_kernel(
iter,
[zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
T_ACC grad_val = static_cast<T_ACC>(grad_val_);
T_ACC self_val = static_cast<T_ACC>(self_val_);
opmath_t grad_val = static_cast<opmath_t>(grad_val_);
opmath_t self_val = static_cast<opmath_t>(self_val_);
if (self_val < neg_three) {
return zero;
} else if (self_val <= three) {
@@ -549,36 +567,42 @@ void hardswish_backward_kernel(TensorIterator& iter) {
}

void hardsigmoid_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_cuda", [&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC zero(0.0f);
const T_ACC one_sixth(1.0f / 6.0f);
const T_ACC three(3.0f);
const T_ACC six(6.0f);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(), "hardsigmoid_cuda", [&]() {
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t zero(0.0f);
const opmath_t one_sixth(1.0f / 6.0f);
const opmath_t three(3.0f);
const opmath_t six(6.0f);
gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t {
T_ACC x = static_cast<T_ACC>(self_val);
opmath_t x = static_cast<opmath_t>(self_val);
return std::min(std::max(x + three, zero), six) * one_sixth;
});
});
}

void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_backward_cuda", [&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC zero(0.0f);
const T_ACC three(3.0f);
const T_ACC neg_three(-3.0f);
const T_ACC one_sixth(1.0f / 6.0f);
gpu_kernel(
iter,
[zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
T_ACC grad_val = static_cast<T_ACC>(grad_val_);
T_ACC self_val = static_cast<T_ACC>(self_val_);
return (self_val > neg_three && self_val < three)
? grad_val * one_sixth
: zero;
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
iter.dtype(),
"hardsigmoid_backward_cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t zero(0.0f);
const opmath_t three(3.0f);
const opmath_t neg_three(-3.0f);
const opmath_t one_sixth(1.0f / 6.0f);
gpu_kernel(
iter,
[zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
opmath_t grad_val = static_cast<opmath_t>(grad_val_);
opmath_t self_val = static_cast<opmath_t>(self_val_);
return (self_val > neg_three && self_val < three)
? grad_val * one_sixth
: zero;
});
});
});
}

void silu_kernel(TensorIteratorBase& iter) {
@@ -591,9 +615,9 @@ void silu_kernel(TensorIteratorBase& iter) {
gpu_kernel(
iter,
[] GPU_LAMBDA(scalar_t x) -> scalar_t {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC x_acc = static_cast<T_ACC>(x);
return x_acc / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t x_acc = static_cast<opmath_t>(x);
return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
});
});
}
@@ -608,12 +632,12 @@ void silu_backward_kernel(TensorIteratorBase& iter) {
gpu_kernel(
iter,
[] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC dy_acc = static_cast<T_ACC>(dy);
const T_ACC x_acc = static_cast<T_ACC>(x);
const T_ACC s_acc =
T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
return dy_acc * s_acc * (T_ACC(1) + x_acc * (T_ACC(1) - s_acc));
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t dy_acc = static_cast<opmath_t>(dy);
const opmath_t x_acc = static_cast<opmath_t>(x);
const opmath_t s_acc =
opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc));
});
});
}
@@ -628,8 +652,8 @@ void mish_kernel(TensorIteratorBase& iter) {
gpu_kernel(
iter,
[] GPU_LAMBDA(scalar_t x) -> scalar_t {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC x_acc = static_cast<T_ACC>(x);
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t x_acc = static_cast<opmath_t>(x);
return x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)));
});
});
@@ -645,14 +669,14 @@ void mish_backward_kernel(TensorIterator& iter) {
gpu_kernel(
iter,
[] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC dy_acc = static_cast<T_ACC>(dy);
const T_ACC x_acc = static_cast<T_ACC>(x);
const T_ACC s_acc =
T_ACC(1) / (T_ACC(1) + c10::cuda::compat::exp(-x_acc));
const T_ACC t_acc =
using opmath_t = at::opmath_type<scalar_t>;
const opmath_t dy_acc = static_cast<opmath_t>(dy);
const opmath_t x_acc = static_cast<opmath_t>(x);
const opmath_t s_acc =
opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc));
const opmath_t t_acc =
c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc)));
return dy_acc * (t_acc + x_acc * s_acc * (T_ACC(1) - t_acc * t_acc));
return dy_acc * (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc));
});
});
}
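Why the wider intermediate matters: fp16 has a maximum finite value of 65504 and roughly three decimal digits of precision, so intermediates computed or rounded at scalar_t precision (for example, the a * beta product and the static_cast<scalar_t> around log1p in the previous softplus code) lose range and accuracy, while the same computation kept in opmath_t (float) does not. A hypothetical standalone check of the range difference, not part of this commit; it assumes a CUDA toolchain, and the hexp fp16 intrinsic from cuda_fp16.h requires sm_53 or newer:

#include <cuda_fp16.h>
#include <cstdio>

__global__ void exp_range_demo(float x) {
  __half half_result  = hexp(__float2half(x));  // fp16 math: inf once exp(x) > 65504
  float  float_result = expf(x);                // fp32 math: finite far beyond that
  printf("x = %.1f  fp16 exp = %f  fp32 exp = %f\n",
         x, __half2float(half_result), float_result);
}

int main() {
  exp_range_demo<<<1, 1>>>(12.0f);  // exp(12) is about 162754.8, above the fp16 max
  cudaDeviceSynchronize();
  return 0;
}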
