1from opcodes import AND, SHL, SHR
2from rule import Rule
3from z3 import BitVec, BitVecVal, If, Int2BV, IntVal, UGT, ULT
4
5"""
6Rule:
7mask = shlWorkaround(u256(-1), unsigned(A.d())) >> unsigned(B.d())
SHR(B, SHL(A, X)) -> AND(SH[L/R]([A - B / B - A], X), mask)
i.e. SHL by A - B when A > B, SHR by B - A when B > A, and AND(X, mask) when A == B
9Requirements:
10A < BitWidth
11B < BitWidth
12"""
13
# Check that SHR(B, SHL(A, X)) equals a single shift of X by the difference
# of the two shift amounts, followed by an AND with a precomputed mask.
rule = Rule()

# Word width (in bits) used for every symbolic value in this proof.
n_bits = 64

# Symbolic operand X and symbolic shift amounts A (used by SHL, the inner
# shift) and B (used by SHR, the outer shift).
X, A, B = (BitVec(name, n_bits) for name in ('X', 'A', 'B'))

# The word width as a bit-vector constant, used to bound the shift amounts.
BitWidth = BitVecVal(n_bits, n_bits)

# Both shift amounts must be strictly below the word width (unsigned compare).
for shift_amount in (A, B):
	rule.require(ULT(shift_amount, BitWidth))

# Left-hand side of the rule: the expression before optimization.
nonopt = SHR(B, SHL(A, X))

# An all-ones word sent through the same SHL-by-A / SHR-by-B pair marks the
# bit positions that can still be set after the two original shifts.
all_ones = Int2BV(IntVal(-1), n_bits)
Mask = SHR(B, SHL(A, all_ones))

# Right-hand side: one shift by the difference of the amounts, then the mask.
#   A > B  -> shift left by A - B
#   B > A  -> shift right by B - A
#   A == B -> no shift at all
shift_left_case = AND(SHL(A - B, X), Mask)
shift_right_case = AND(SHR(B - A, X), Mask)
no_shift_case = AND(X, Mask)
opt = If(
	UGT(A, B),
	shift_left_case,
	If(UGT(B, A), shift_right_case, no_shift_case),
)

# Hand both sides to the rule checker, which is expected to verify that they
# agree whenever the requirements above hold (see Rule.check in rule.py).
rule.check(nonopt, opt)
46