[go: nahoru, domu]

Skip to content

Commit

Permalink
optimize uint256 shr
Browse files Browse the repository at this point in the history
The standard library version of this function uses
     `uint256_unsigned_div_rem`, which is quite slow. The new version
     primarily uses bitwise operations and a much faster felt
     `unsigned_div_rem`.

That shaves ~300 steps from every call.
  • Loading branch information
temyurchenko committed Dec 16, 2021
1 parent ddc9d6a commit 2616119
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
11 changes: 6 additions & 5 deletions tests/benchmark/IntegralMath.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func mulModMax{range_check_ptr}(x : Uint256, y : Uint256) -> (res : Uint256):
return uint256_mulmod(x, y, Uint256(MAX_VAL, MAX_VAL))
end

func block_0{range_check_ptr}(n : Uint256, r : felt) -> (n : Uint256, r : felt):
func block_0{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(n : Uint256, r : felt) -> (n : Uint256, r : felt):
alloc_locals
let (_cond : Uint256) = is_gt(n, Uint256(1, 0))
if _cond.low + _cond.low == 0:
Expand Down Expand Up @@ -158,7 +158,7 @@ func block_6{range_check_ptr}(
return (x, n)
end

func block_3{range_check_ptr}(n : Uint256, x : Uint256, y : Uint256) -> (n : Uint256, x : Uint256):
func block_3{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(n : Uint256, x : Uint256, y : Uint256) -> (n : Uint256, x : Uint256):
alloc_locals
let (_cond_0 : Uint256) = is_gt(y, Uint256(0, 0))
if _cond_0.low + _cond_0.high == 0:
Expand All @@ -178,7 +178,7 @@ end

# @dev Compute the largest integer smaller than or equal to the cubic root of `n`
@external
func floorCbrt{range_check_ptr}(n : Uint256) -> (res : Uint256):
func floorCbrt{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(n : Uint256) -> (res : Uint256):
alloc_locals
let x : Uint256 = Uint256(0, 0)
let (y : Uint256) = u256_shl(Uint256(255, 0), Uint256(1, 0))
Expand All @@ -188,7 +188,7 @@ end

# @dev Compute the smallest integer larger than or equal to the cubic root of `n`
@external
func ceilCbrt{range_check_ptr}(n : Uint256) -> (res : Uint256):
func ceilCbrt{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(n : Uint256) -> (res : Uint256):
alloc_locals
let (x : Uint256) = floorCbrt(n)
let (_sq : Uint256) = u256_mul(x, x)
Expand All @@ -197,7 +197,8 @@ func ceilCbrt{range_check_ptr}(n : Uint256) -> (res : Uint256):
if _cond.low + _cond.high != 0:
return (res=x)
end
return u256_add(x, Uint256(1, 0))
let (res) = u256_add(x, Uint256(1, 0))
return (res)
end

# @dev Compute the nearest integer to the quotient of `n` and `d` (or `n / d`)
Expand Down
44 changes: 36 additions & 8 deletions warp/cairo-src/evm/uint256.cairo
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from starkware.cairo.common.bitwise import bitwise_and
from starkware.cairo.common.bitwise import bitwise_and, bitwise_not
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.math import assert_not_zero
from starkware.cairo.common.math import assert_not_zero, unsigned_div_rem
from starkware.cairo.common.math_cmp import is_le
from starkware.cairo.common.pow import pow
from starkware.cairo.common.uint256 import (
Uint256, uint256_add, uint256_cond_neg, uint256_eq, uint256_lt, uint256_mul, uint256_pow2,
uint256_shl, uint256_shr, uint256_signed_div_rem, uint256_signed_lt, uint256_sub,
uint256_unsigned_div_rem)
uint256_shl, uint256_signed_div_rem, uint256_signed_lt, uint256_sub, uint256_unsigned_div_rem)

const UINT128_BOUND = 2 ** 128

func u256_add{range_check_ptr}(x : Uint256, y : Uint256) -> (result : Uint256):
let (result : Uint256, _) = uint256_add(x, y)
Expand All @@ -28,9 +30,35 @@ func u256_div{range_check_ptr}(x : Uint256, y : Uint256) -> (result : Uint256):
end

# THE ORDER OF ARGUMENTS IS REVERSED, LIKE IN YUL
func u256_shr{range_check_ptr}(x : Uint256, y : Uint256) -> (result : Uint256):
let (result : Uint256) = uint256_shr(y, x)
return (result=result)
func u256_shr{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(i : Uint256, a : Uint256) -> (
result : Uint256):
# Logical right shift of the 256-bit value `a` by `i` bits: result = a >> i.
# Argument order is (shift, value), as in Yul's `shr(shift, value)`.
#   i - shift amount; any value >= 256 shifts every bit out and yields 0.
#   a - value to shift, as two 128-bit limbs (low, high).
# Returns the shifted value as a Uint256.
# No 256-bit division is performed: each limb is masked with bitwise ops and
# then divided exactly by the felt p = 2^k.
# Shift amounts >= 2^128 are certainly >= 256, so the result is zero.
if i.high != 0:
return (Uint256(0, 0))
end
# Case 1: 0 <= i <= 127. Bits move from the high limb into the low limb.
let (le_127) = is_le(i.low, 127)
if le_127 == 1:
# (h', l') := (h, l) >> i, with p := 2^i
# l' = ((h & (p-1)) << (128 - i)) + ((l & ~(p-1)) >> i)
#    = ((h & (p-1)) << 128 >> i) + ((l & ~(p-1)) >> i)
#    = (h & (p-1)) * 2^128 / p + (l & ~(p-1)) / p
#    = ((h & (p-1)) * 2^128 + (l & ~(p-1))) / p
# h' = h >> i = (h - (h & (p-1))) / p
# Each numerator has its low i bits cleared, so every division by p is exact.
# The felt `/` (multiplication by p^-1 in the field) therefore returns the
# integer quotient.
# UINT128_BOUND * high_part can reach ~2^255. Assuming the usual ~2^251 Cairo
# prime, that wraps modulo the prime. This is harmless: the true numerator is
# a multiple of p and the quotient is < 2^128, so the field result is still
# exact.
let (p) = pow(2, i.low)
# Mask that clears the low i bits.
# NOTE(review): assumes `bitwise_not` complements the full builtin word. Only
# the low 128 bits survive the AND with a 128-bit limb, so any width >= 128
# works.
let (low_mask) = bitwise_not(p - 1)
let (low_part) = bitwise_and(a.low, low_mask)
# Low i bits of the high limb; they become the top i bits of the new low limb.
let (high_part) = bitwise_and(a.high, p - 1)
return (
Uint256(low=(low_part + UINT128_BOUND * high_part) / p, high=(a.high - high_part) / p))
end
# Case 2: 128 <= i <= 255. Only the high limb contributes, and it lands in
# the low limb: result.low = a.high >> (i - 128), result.high = 0.
let (le_255) = is_le(i.low, 255)
if le_255 == 1:
let (p) = pow(2, i.low - 128)
# Clear the low (i - 128) bits so the felt division by p is exact.
let (mask) = bitwise_not(p - 1)
let (res) = bitwise_and(a.high, mask)
return (Uint256(res / p, 0))
end
# Case 3: 256 <= i < 2^128. Every bit is shifted out.
return (Uint256(0, 0))
end

# THE ORDER OF ARGUMENTS IS REVERSED, LIKE IN YUL
Expand Down Expand Up @@ -116,7 +144,7 @@ func uint256_byte{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(a : Uint256, i
res : Uint256):
let (i, _) = uint256_mul(i, cast((8, 0), Uint256))
let (i) = uint256_sub(cast((248, 0), Uint256), i)
let (res) = uint256_shr(a, i)
let (res) = u256_shr(i, a)
let (low) = bitwise_and(res.low, 255)
return (res=cast((low, 0), Uint256))
end
Expand Down
6 changes: 5 additions & 1 deletion warp/yul/BuiltinHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,11 @@ def __init__(self):

class Shr(StaticHandler):
def __init__(self):
super().__init__(function_name="u256_shr", module="evm.uint256")
super().__init__(
function_name="u256_shr",
module="evm.uint256",
used_implicits=("range_check_ptr", "bitwise_ptr"),
)


class Sar(StaticHandler):
Expand Down

0 comments on commit 2616119

Please sign in to comment.