Simplify Checked Add, Sub, and Mul implementations

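Simplify some checked arithmetic implementations to reduce both source and
generated code size. SFINAE constraints on the checked Add, Sub, Div, and Mod
helpers become static_asserts, the three CheckedMulImpl and two
SafeUnsignedAbs overloads are merged into single templates, and integer
multiplication now promotes via FastIntegerArithmeticPromotion.

This also fixes IsIntegerArithmeticSafe, which previously required the Rhs
range to *not* be contained in the destination type.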

Review-Url: https://codereview.chromium.org/2607293002
Cr-Commit-Position: refs/heads/master@{#441201}
diff --git a/base/numerics/safe_conversions_impl.h b/base/numerics/safe_conversions_impl.h
index df7130d..6fd063b1 100644
--- a/base/numerics/safe_conversions_impl.h
+++ b/base/numerics/safe_conversions_impl.h
@@ -67,42 +67,13 @@
       (static_cast<UnsignedT>(x) ^ -SignedT(is_negative)) + is_negative);
 }
 
-// Wrapper for the sign mask used in the absolute value function.
+// Returns |value| as an unsigned type; negation wraps, so lowest() is safe.
 template <typename T>
-constexpr T SignMask(T x) {
-  using SignedT = typename std::make_signed<T>::type;
-  // Right shift on a signed number is implementation defined, but it's often
-  // implemented as arithmetic shift. If the compiler uses an arithmetic shift,
-  // then use that to avoid the extra negation.
-  return static_cast<T>(
-      (static_cast<SignedT>(-1) >> PositionOfSignBit<T>::value) ==
-              static_cast<SignedT>(-1)
-          ? (static_cast<SignedT>(x) >> PositionOfSignBit<T>::value)
-          : -static_cast<SignedT>(static_cast<SignedT>(x) < 0));
-}
-static_assert(SignMask(-2) == -1,
-              "Inconsistent handling of signed right shift.");
-static_assert(SignMask(-3L) == -1L,
-              "Inconsistent handling of signed right shift.");
-static_assert(SignMask(-4LL) == -1LL,
-              "Inconsistent handling of signed right shift.");
-
-// This performs a safe, non-branching absolute value via unsigned overflow.
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value &&
-                                  std::is_signed<T>::value>::type* = nullptr>
 constexpr typename std::make_unsigned<T>::type SafeUnsignedAbs(T value) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
   using UnsignedT = typename std::make_unsigned<T>::type;
-  return static_cast<T>(static_cast<UnsignedT>(value ^ SignMask(value)) -
-                        static_cast<UnsignedT>(SignMask(value)));
-}
-
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value &&
-                                  !std::is_signed<T>::value>::type* = nullptr>
-constexpr T SafeUnsignedAbs(T value) {
-  // T is unsigned, so |value| must already be positive.
-  return static_cast<T>(value);
+  return IsValueNegative(value) ? 0 - static_cast<UnsignedT>(value)
+                                : static_cast<UnsignedT>(value);
 }
 
 enum IntegerRepresentation {
@@ -511,18 +482,46 @@
 // can skip the checked operations if they're not needed. So, for an integer we
 // care if the destination type preserves the sign and is twice the width of
 // the source.
-template <typename T, typename Lhs, typename Rhs>
+template <typename T, typename Lhs, typename Rhs = Lhs>
 struct IsIntegerArithmeticSafe {
   static const bool value =
       !std::is_floating_point<T>::value &&
-      StaticDstRangeRelationToSrcRange<T, Lhs>::value ==
-          NUMERIC_RANGE_CONTAINED &&
+      !std::is_floating_point<Lhs>::value &&
+      !std::is_floating_point<Rhs>::value &&
+      std::is_signed<T>::value >= std::is_signed<Lhs>::value &&
       IntegerBitsPlusSign<T>::value >= (2 * IntegerBitsPlusSign<Lhs>::value) &&
-      StaticDstRangeRelationToSrcRange<T, Rhs>::value !=
-          NUMERIC_RANGE_CONTAINED &&
+      std::is_signed<T>::value >= std::is_signed<Rhs>::value &&
       IntegerBitsPlusSign<T>::value >= (2 * IntegerBitsPlusSign<Rhs>::value);
 };
 
+// Promotes to a type that can hold any result of a binary arithmetic op on
+// the source types; if none exists, falls back to BigEnoughPromotion.
+template <typename Lhs,
+          typename Rhs,
+          bool is_promotion_possible = IsIntegerArithmeticSafe<
+              typename std::conditional<std::is_signed<Lhs>::value ||
+                                            std::is_signed<Rhs>::value,
+                                        intmax_t,
+                                        uintmax_t>::type,
+              typename MaxExponentPromotion<Lhs, Rhs>::type>::value>
+struct FastIntegerArithmeticPromotion;
+
+template <typename Lhs, typename Rhs>
+struct FastIntegerArithmeticPromotion<Lhs, Rhs, true> {
+  using type =
+      typename TwiceWiderInteger<typename MaxExponentPromotion<Lhs, Rhs>::type,
+                                 std::is_signed<Lhs>::value ||
+                                     std::is_signed<Rhs>::value>::type;
+  static_assert(IsIntegerArithmeticSafe<type, Lhs, Rhs>::value, "");
+  static const bool is_contained = true;
+};
+
+template <typename Lhs, typename Rhs>
+struct FastIntegerArithmeticPromotion<Lhs, Rhs, false> {
+  using type = typename BigEnoughPromotion<Lhs, Rhs>::type;
+  static const bool is_contained = false;
+};
+
 // This hacks around libstdc++ 4.6 missing stuff in type_traits.
 #if defined(__GLIBCXX__)
 #define PRIV_GLIBCXX_4_7_0 20120322
diff --git a/base/numerics/safe_math_impl.h b/base/numerics/safe_math_impl.h
index 9a47a27a..9956115 100644
--- a/base/numerics/safe_math_impl.h
+++ b/base/numerics/safe_math_impl.h
@@ -51,9 +51,9 @@
 #define USE_OVERFLOW_BUILTINS (0)
 #endif
 
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
+template <typename T>
 bool CheckedAddImpl(T x, T y, T* result) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
   // Since the value of x+y is undefined if we have a signed type, we compute
   // it using the unsigned type of the same size.
   using UnsignedDst = typename std::make_unsigned<T>::type;
@@ -102,9 +102,9 @@
   }
 };
 
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
+template <typename T>
 bool CheckedSubImpl(T x, T y, T* result) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
-  // Since the value of x+y is undefined if we have a signed type, we compute
+  // Since the value of x-y is undefined if we have a signed type, we compute
   // it using the unsigned type of the same size.
   using UnsignedDst = typename std::make_unsigned<T>::type;
@@ -153,54 +153,24 @@
   }
 };
 
-// Integer multiplication is a bit complicated. In the fast case we just
-// we just promote to a twice wider type, and range check the result. In the
-// slow case we need to manually check that the result won't be truncated by
-// checking with division against the appropriate bound.
-template <typename T,
-          typename std::enable_if<
-              std::is_integral<T>::value &&
-              ((IntegerBitsPlusSign<T>::value * 2) <=
-               IntegerBitsPlusSign<intmax_t>::value)>::type* = nullptr>
+template <typename T>
 bool CheckedMulImpl(T x, T y, T* result) {
-  using IntermediateType = typename TwiceWiderInteger<T>::type;
-  IntermediateType tmp =
-      static_cast<IntermediateType>(x) * static_cast<IntermediateType>(y);
-  *result = static_cast<T>(tmp);
-  return DstRangeRelationToSrcRange<T>(tmp) == RANGE_VALID;
-}
-
-template <typename T,
-          typename std::enable_if<
-              std::is_integral<T>::value && std::is_signed<T>::value &&
-              ((IntegerBitsPlusSign<T>::value * 2) >
-               IntegerBitsPlusSign<intmax_t>::value)>::type* = nullptr>
-bool CheckedMulImpl(T x, T y, T* result) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
-  // Since the value of x*y is potentially undefined if we have a signed type,
-  // we compute it using the unsigned type of the same size.
+  // x*y may overflow T (UB for signed T), so multiply the magnitudes as
+  // unsigned. The leading 1u keeps narrow types from promoting to signed int.
   using UnsignedDst = typename std::make_unsigned<T>::type;
+  using SignedDst = typename std::make_signed<T>::type;
   const UnsignedDst ux = SafeUnsignedAbs(x);
   const UnsignedDst uy = SafeUnsignedAbs(y);
-  UnsignedDst uresult = static_cast<UnsignedDst>(ux * uy);
+  UnsignedDst uresult = static_cast<UnsignedDst>(1u * ux * uy);
-  // This is a non-branching conditional negation.
-  const T is_negative = (x ^ y) < 0;
-  *result = static_cast<T>((uresult ^ -is_negative) + is_negative);
-  // This uses the unsigned overflow check on the absolute value, with a +1
-  // bound for a negative result.
-  return (uy == 0 ||
-          ux <= (static_cast<UnsignedDst>(std::numeric_limits<T>::max()) +
-                 is_negative) /
-                    uy);
-}
-
-template <typename T,
-          typename std::enable_if<
-              std::is_integral<T>::value && !std::is_signed<T>::value &&
-              ((IntegerBitsPlusSign<T>::value * 2) >
-               IntegerBitsPlusSign<uintmax_t>::value)>::type* = nullptr>
-bool CheckedMulImpl(T x, T y, T* result) {
-  *result = x * y;
-  return (y == 0 || x <= std::numeric_limits<T>::max() / y);
+  const bool is_negative =
+      std::is_signed<T>::value && static_cast<SignedDst>(x ^ y) < 0;
+  *result = is_negative ? 0 - uresult : uresult;
+  // Fast out when uy is 0, or uy is 1 and T is unsigned or x*y is negative.
+  // Otherwise check ux <= bound / uy, where bound is max() for a non-negative
+  // result and max() + 1 (the magnitude of lowest()) for a negative one.
+  return uy <= UnsignedDst(!std::is_signed<T>::value || is_negative) ||
+         ux <= (std::numeric_limits<T>::max() + UnsignedDst(is_negative)) / uy;
 }
 
 template <typename T, typename U, class Enable = void>
@@ -233,7 +203,7 @@
     if (kUseMaxInt)
       return !__builtin_mul_overflow(x, y, result);
 #endif
-    using Promotion = typename BigEnoughPromotion<T, U>::type;
+    using Promotion = typename FastIntegerArithmeticPromotion<T, U>::type;
     Promotion presult;
     // Fail if either operand is out of range for the promoted type.
     // TODO(jschuh): This could be made to work for a broader range of values.
@@ -256,9 +226,9 @@
 
 // Division just requires a check for a zero denominator or an invalid negation
 // on signed min/-1.
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
+template <typename T>
 bool CheckedDivImpl(T x, T y, T* result) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
   if (y && (!std::is_signed<T>::value ||
             x != std::numeric_limits<T>::lowest() || y != static_cast<T>(-1))) {
     *result = x / y;
@@ -291,9 +261,9 @@
   }
 };
 
-template <typename T,
-          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
+template <typename T>
 bool CheckedModImpl(T x, T y, T* result) {
+  static_assert(std::is_integral<T>::value, "Type must be integral");
   if (y > 0) {
     *result = static_cast<T>(x % y);
     return true;
diff --git a/base/numerics/safe_numerics_unittest.cc b/base/numerics/safe_numerics_unittest.cc
index 5304593..58fac24 100644
--- a/base/numerics/safe_numerics_unittest.cc
+++ b/base/numerics/safe_numerics_unittest.cc
@@ -79,6 +79,57 @@
 
 namespace base {
 namespace internal {
+
+// Test corner-case promotions used by the checked arithmetic operations.
+static_assert(IsIntegerArithmeticSafe<int32_t, int8_t, int8_t>::value, "");
+static_assert(IsIntegerArithmeticSafe<int32_t, int16_t, int8_t>::value, "");
+static_assert(IsIntegerArithmeticSafe<int32_t, int8_t, int16_t>::value, "");
+static_assert(!IsIntegerArithmeticSafe<int32_t, int32_t, int8_t>::value, "");
+static_assert(BigEnoughPromotion<int16_t, int8_t>::is_contained, "");
+static_assert(BigEnoughPromotion<int32_t, uint32_t>::is_contained, "");
+static_assert(BigEnoughPromotion<intmax_t, int8_t>::is_contained, "");
+static_assert(!BigEnoughPromotion<uintmax_t, int8_t>::is_contained, "");
+static_assert(
+    std::is_same<BigEnoughPromotion<int16_t, int8_t>::type, int16_t>::value,
+    "");
+static_assert(
+    std::is_same<BigEnoughPromotion<int32_t, uint32_t>::type, int64_t>::value,
+    "");
+static_assert(
+    std::is_same<BigEnoughPromotion<intmax_t, int8_t>::type, intmax_t>::value,
+    "");
+static_assert(
+    std::is_same<BigEnoughPromotion<uintmax_t, int8_t>::type, uintmax_t>::value,
+    "");
+static_assert(IsIntegerArithmeticSafe<int64_t, int32_t>::value, "");
+static_assert(IsIntegerArithmeticSafe<int32_t, uint8_t, int16_t>::value, "");
+static_assert(!IsIntegerArithmeticSafe<uint32_t, int8_t, int8_t>::value, "");
+static_assert(!IsIntegerArithmeticSafe<int32_t, int8_t, uint32_t>::value, "");
+static_assert(
+    std::is_same<FastIntegerArithmeticPromotion<int16_t, int8_t>::type,
+                 int32_t>::value,
+    "");
+static_assert(
+    std::is_same<FastIntegerArithmeticPromotion<int32_t, uint32_t>::type,
+                 int64_t>::value,
+    "");
+static_assert(
+    std::is_same<FastIntegerArithmeticPromotion<intmax_t, int8_t>::type,
+                 intmax_t>::value,
+    "");
+static_assert(
+    std::is_same<FastIntegerArithmeticPromotion<uintmax_t, int8_t>::type,
+                 uintmax_t>::value,
+    "");
+static_assert(FastIntegerArithmeticPromotion<int16_t, int8_t>::is_contained,
+              "");
+static_assert(FastIntegerArithmeticPromotion<int32_t, uint32_t>::is_contained,
+              "");
+static_assert(!FastIntegerArithmeticPromotion<intmax_t, int8_t>::is_contained,
+              "");
+static_assert(!FastIntegerArithmeticPromotion<uintmax_t, int8_t>::is_contained,
+              "");
+
 template <typename U>
 U GetNumericValueForTest(const CheckedNumeric<U>& src) {
   return src.state_.value();
@@ -166,6 +217,14 @@
   TEST_EXPECTED_FAILURE(CheckedNumeric<Dst>(DstLimits::lowest()) / -1);
   TEST_EXPECTED_VALUE(0, CheckedNumeric<Dst>(-1) / 2);
   TEST_EXPECTED_FAILURE(CheckedNumeric<Dst>(DstLimits::lowest()) * -1);
+  TEST_EXPECTED_VALUE(DstLimits::max(),
+                      CheckedNumeric<Dst>(DstLimits::lowest() + 1) * Dst(-1));
+  TEST_EXPECTED_VALUE(DstLimits::max(),
+                      CheckedNumeric<Dst>(-1) * Dst(DstLimits::lowest() + 1));
+  TEST_EXPECTED_VALUE(DstLimits::lowest(),
+                      CheckedNumeric<Dst>(DstLimits::lowest()) * Dst(1));
+  TEST_EXPECTED_VALUE(DstLimits::lowest(),
+                      CheckedNumeric<Dst>(1) * Dst(DstLimits::lowest()));
   TEST_EXPECTED_VALUE(DstLimits::lowest(),
                       MakeCheckedNum(DstLimits::lowest()).UnsignedAbs());
   TEST_EXPECTED_VALUE(DstLimits::max(),