diff --git a/service/constant-propagation/SignedConstantDomain.cpp b/service/constant-propagation/SignedConstantDomain.cpp index 8ae48de573..eb2e43ba56 100644 --- a/service/constant-propagation/SignedConstantDomain.cpp +++ b/service/constant-propagation/SignedConstantDomain.cpp @@ -21,12 +21,13 @@ enum class SignedConstantDomain::BitShiftMask : int32_t { Long = 0x3f, }; -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. bool enable_bitset = true; // TODO(T236830337): Remove this. bool enable_low6bits = false; -} // namespace signed_constant_domain + +} // namespace signed_constant_domain_internal SignedConstantDomain& SignedConstantDomain::left_shift_bits_int(int32_t shift) { return left_shift_bits(shift, BitShiftMask::Int); diff --git a/service/constant-propagation/SignedConstantDomain.h b/service/constant-propagation/SignedConstantDomain.h index 7c94575b90..dec253108e 100644 --- a/service/constant-propagation/SignedConstantDomain.h +++ b/service/constant-propagation/SignedConstantDomain.h @@ -45,215 +45,368 @@ inline NumericIntervalDomain numeric_interval_domain_from_int(int64_t min, } } -class SignedConstantDomain; -std::ostream& operator<<(std::ostream& o, const SignedConstantDomain& scd); - -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. extern bool enable_bitset; // TODO(T236830337): Remove this. extern bool enable_low6bits; -} // namespace signed_constant_domain -// This class effectively implements a -// ReducedProductAbstractDomain -class SignedConstantDomain final - : public sparta::AbstractDomain { +struct Bounds final { + Bounds() = delete; // Use bottom() or top() instead. + private: static constexpr int64_t MIN = std::numeric_limits::min(); static constexpr int64_t MAX = std::numeric_limits::max(); - struct Bounds final { - bool is_nez; - int64_t l; - int64_t u; - bool operator==(const Bounds& other) const { - return is_nez == other.is_nez && l == other.l && u == other.u; - } - bool operator<=(const Bounds& other) const { - return this->is_bottom() || - (other.l <= l && u <= other.u && - static_cast(other.is_nez) <= static_cast(is_nez)); - } - // Partial order only. Use <= instead. - bool operator>(const Bounds&) = delete; - bool operator<(const Bounds&) = delete; - bool operator>=(const Bounds&) = delete; - - bool is_constant() const { return l == u; } - bool is_top() const { return *this == top(); } - bool is_bottom() const { return *this == bottom(); } - void normalize() { - if (is_nez) { - if (l == 0) { - l++; - } - if (u == 0) { - u--; - } + constexpr Bounds(bool is_nez, int64_t l, int64_t u) + : is_nez(is_nez), l(l), u(u) {} + + public: + bool is_nez; + int64_t l; + int64_t u; + bool operator==(const Bounds& other) const { + return is_nez == other.is_nez && l == other.l && u == other.u; + } + bool operator<=(const Bounds& other) const { + return this->is_bottom() || + (other.l <= l && u <= other.u && + static_cast(other.is_nez) <= static_cast(is_nez)); + } + // Partial order only. Use <= instead. + bool operator>(const Bounds&) = delete; + bool operator<(const Bounds&) = delete; + bool operator>=(const Bounds&) = delete; + + bool is_constant() const { return l == u; } + bool is_top() const { return *this == top(); } + bool is_bottom() const { return *this == bottom(); } + void normalize() { + if (is_nez) { + if (l == 0) { + l++; } - if (u < l) { - this->set_to_bottom(); + if (u == 0) { + u--; } - always_assert(is_normalized()); } - // Is the constant not within the bounds? - bool unequals_constant(int64_t integer) const { - return (integer == 0 && is_nez) || integer < l || u < integer; + if (u < l) { + this->set_to_bottom(); } - bool is_normalized() const { - // bottom has a particular shape - if (u < l) { - return this->is_bottom(); - } - // nez cannot be set if 0 is a lower or upper bound - if (l == 0 || u == 0) { - return !is_nez; - } - // nez must be set if 0 is not in range - return (l <= 0 && u >= 0) || is_nez; + always_assert(is_normalized()); + } + // Is the constant not within the bounds? + bool unequals_constant(int64_t integer) const { + return (integer == 0 && is_nez) || integer < l || u < integer; + } + bool is_normalized() const { + // bottom has a particular shape + if (u < l) { + return this->is_bottom(); + } + // nez cannot be set if 0 is a lower or upper bound + if (l == 0 || u == 0) { + return !is_nez; } + // nez must be set if 0 is not in range + return (l <= 0 && u >= 0) || is_nez; + } - // Is the bounds known to be NEZ and nothing more? - bool is_nez_only() const { return is_nez && l == MIN && u == MAX; } + // Is the bounds known to be NEZ and nothing more? + bool is_nez_only() const { return is_nez && l == MIN && u == MAX; } - Bounds& set_to_top() { - *this = top(); - return *this; - } + Bounds& set_to_top() { + *this = top(); + return *this; + } - Bounds& set_to_bottom() { - *this = bottom(); - return *this; - } + Bounds& set_to_bottom() { + *this = bottom(); + return *this; + } - Bounds& join_with(const Bounds& that) { - l = std::min(l, that.l); - u = std::max(u, that.u); - is_nez &= that.is_nez; - always_assert(is_normalized()); - return *this; - } + Bounds& join_with(const Bounds& that) { + l = std::min(l, that.l); + u = std::max(u, that.u); + is_nez &= that.is_nez; + always_assert(is_normalized()); + return *this; + } - Bounds& meet_with(const Bounds& that) { - l = std::max(l, that.l); - u = std::min(u, that.u); - is_nez |= that.is_nez; - normalize(); - return *this; - } + Bounds& meet_with(const Bounds& that) { + l = std::max(l, that.l); + u = std::min(u, that.u); + is_nez |= that.is_nez; + normalize(); + return *this; + } - static Bounds from_interval(sign_domain::Interval interval) { - switch (interval) { - case sign_domain::Interval::EMPTY: - return bottom(); - case sign_domain::Interval::EQZ: - return {false, 0, 0}; - case sign_domain::Interval::LEZ: - return {false, MIN, 0}; - case sign_domain::Interval::LTZ: - return {true, MIN, -1}; - case sign_domain::Interval::GEZ: - return {false, 0, MAX}; - case sign_domain::Interval::GTZ: - return {true, 1, MAX}; - case sign_domain::Interval::ALL: - return top(); - case sign_domain::Interval::NEZ: - return nez(); - case sign_domain::Interval::SIZE: - break; - } - not_reached(); - } + static Bounds from_interval(sign_domain::Interval interval) { + switch (interval) { + case sign_domain::Interval::EMPTY: + return bottom(); + case sign_domain::Interval::EQZ: + return {false, 0, 0}; + case sign_domain::Interval::LEZ: + return {false, MIN, 0}; + case sign_domain::Interval::LTZ: + return {true, MIN, -1}; + case sign_domain::Interval::GEZ: + return {false, 0, MAX}; + case sign_domain::Interval::GTZ: + return {true, 1, MAX}; + case sign_domain::Interval::ALL: + return top(); + case sign_domain::Interval::NEZ: + return nez(); + case sign_domain::Interval::SIZE: + break; + } + not_reached(); + } - static Bounds from_integer(int64_t integer) { - return Bounds{integer != 0, integer, integer}; - } + static Bounds from_range(int64_t min, int64_t max) { + always_assert(min <= max); + return Bounds{min > 0 || max < 0, min, max}; + } - static constexpr Bounds top() { - constexpr Bounds res{false, MIN, MAX}; - return res; - } - static constexpr Bounds bottom() { - constexpr Bounds res{true, MAX, MIN}; - return res; - } - static constexpr Bounds nez() { - constexpr Bounds res{true, MIN, MAX}; - return res; - } - }; - Bounds m_bounds; + static Bounds from_integer(int64_t integer) { + return from_range(integer, integer); + } - /** Encodes the combination the lowest 6 bits of an integer. - * - * This class contains a 64-bit integer, low6bits_state, that tracks possible - * states of the lowest 6 bits. There are 64 bits in low6bits_state, each - * represents a value of the lowest 6 bits of the integer represented by the - * domain. The n'th bit in low6bits_state being 1 implies that the lowest 6 - * bits of the integer (or %64 in arithmetic terms) may still be possibly n - * conservatively, i.e., redex can't prove that the lowest 64 bits unequals n. - * On the other hand, the n'th bit being 0 implies the lowest 6 bits of the - * integer can be proven to unequal n. - * - * For example, if low6bits_state is 0b101, it means that the lowest 6 bits of - * the represented integer can be either decimal 0 or 2. - * - * This class costs 8 bytes of RAM. - */ - class Low6Bits final { - private: - uint64_t low6bits_state{std::numeric_limits::max()}; + static constexpr Bounds top() { + constexpr Bounds res{false, MIN, MAX}; + return res; + } + static constexpr Bounds bottom() { + constexpr Bounds res{true, MAX, MIN}; + return res; + } + static constexpr Bounds nez() { + constexpr Bounds res{true, MIN, MAX}; + return res; + } +}; - public: - Low6Bits() = default; - Low6Bits(const Low6Bits&) = default; - Low6Bits& operator=(const Low6Bits&) = default; - Low6Bits(Low6Bits&&) = default; - Low6Bits& operator=(Low6Bits&&) = default; - explicit Low6Bits(int64_t value) - : low6bits_state(static_cast(1u) << (value & 63)) {} - - bool operator==(const Low6Bits& that) const { - return low6bits_state == that.low6bits_state; - } +/** Encodes the combination the lowest 6 bits of an integer. + * + * This class contains a 64-bit integer, low6bits_state, that tracks possible + * states of the lowest 6 bits. There are 64 bits in low6bits_state, each + * represents a value of the lowest 6 bits of the integer represented by the + * domain. The n'th bit in low6bits_state being 1 implies that the lowest 6 + * bits of the integer (or %64 in arithmetic terms) may still be possibly n + * conservatively, i.e., redex can't prove that the lowest 64 bits unequals n. + * On the other hand, the n'th bit being 0 implies the lowest 6 bits of the + * integer can be proven to unequal n. + * + * For example, if low6bits_state is 0b101, it means that the lowest 6 bits of + * the represented integer can be either decimal 0 or 2. + * + * This class costs 8 bytes of RAM. + */ +class Low6Bits final { + private: + uint64_t low6bits_state{std::numeric_limits::max()}; - bool operator<=(const Low6Bits& that) const { - return (low6bits_state & ~that.low6bits_state) == 0; - } + public: + Low6Bits() = default; + Low6Bits(const Low6Bits&) = default; + Low6Bits& operator=(const Low6Bits&) = default; + Low6Bits(Low6Bits&&) = default; + Low6Bits& operator=(Low6Bits&&) = default; + explicit Low6Bits(int64_t value) + : low6bits_state(static_cast(1u) << (value & 63)) {} + + bool operator==(const Low6Bits& that) const { + return low6bits_state == that.low6bits_state; + } - bool is_bottom() const { return low6bits_state == 0u; } - bool is_top() const { return *this == top(); } + bool operator<=(const Low6Bits& that) const { + return (low6bits_state & ~that.low6bits_state) == 0; + } - constexpr void set_to_bottom() { low6bits_state = 0u; } - void set_to_top() { *this = top(); } + bool is_bottom() const { return low6bits_state == 0u; } + bool is_top() const { return *this == top(); } - Low6Bits& join_with(const Low6Bits& that) { - low6bits_state |= that.low6bits_state; - return *this; - } + constexpr void set_to_bottom() { low6bits_state = 0u; } + void set_to_top() { *this = top(); } - Low6Bits& meet_with(const Low6Bits& that) { - low6bits_state &= that.low6bits_state; - return *this; + Low6Bits& join_with(const Low6Bits& that) { + low6bits_state |= that.low6bits_state; + return *this; + } + + Low6Bits& meet_with(const Low6Bits& that) { + low6bits_state &= that.low6bits_state; + return *this; + } + + bool unequals_constant(int64_t integer) const { + return (low6bits_state & (static_cast(1u) << (integer & 63))) == + 0u; + } + + uint64_t get_low6bits_state() const { return low6bits_state; } + + static constexpr Low6Bits top() { return Low6Bits(); } + + static constexpr Low6Bits bottom() { + Low6Bits res; + res.set_to_bottom(); + return res; + } +}; + +class Bitset final { + // We use two integers to represent the state of each bit. A bit of + // one_bit_states/zero_bit_states being one means that the corresponding bit + // of the integer can possibly be one/zero. Hence, if the same bits of both + // are one, it means that the bit can be either one or zero, i.e., top for + // that bit. If any bit is zero in both integers, then the bitset is bottom. + // Use uint64_t instead of int64_t to avoid undefined behavior with >>. + // + // For 32-bit integers, the high 32 bits should be the same as the highest + // bit of the lower 32 bits, i.e., the sign bit of the integer. In this way, + // we achieve consistency with SignedConstantDomain initialized from a + // constant. + uint64_t one_bit_states{std::numeric_limits::max()}; + uint64_t zero_bit_states{std::numeric_limits::max()}; + + constexpr void set_all_to(bool zero, bool one) { + zero_bit_states = zero ? std::numeric_limits::max() : 0u; + one_bit_states = one ? std::numeric_limits::max() : 0u; + } + + // Construct with all bits set to a given bit state. + constexpr Bitset(bool zero, bool one) { set_all_to(zero, one); } + + public: + Bitset() = default; + Bitset(const Bitset& that) = default; + Bitset(Bitset&& that) = default; + Bitset& operator=(const Bitset& that) = default; + Bitset& operator=(Bitset&& that) = default; + + bool operator==(const Bitset& that) const { + return (one_bit_states == that.one_bit_states && + zero_bit_states == that.zero_bit_states) || + (is_bottom() && that.is_bottom()); + } + + bool operator<=(const Bitset& that) const { + if (is_bottom()) { + return true; } + return ((one_bit_states | that.one_bit_states) == that.one_bit_states && + (zero_bit_states | that.zero_bit_states) == that.zero_bit_states); + } - bool unequals_constant(int64_t integer) const { - return (low6bits_state & (static_cast(1u) << (integer & 63))) == - 0u; + // Partial order only. Use <= instead. + bool operator>(const Bitset&) = delete; + bool operator<(const Bitset&) = delete; + bool operator>=(const Bitset&) = delete; + + uint64_t get_one_bit_states() const { return one_bit_states; } + uint64_t get_zero_bit_states() const { return zero_bit_states; } + + // Construct from a constant. + explicit Bitset(int64_t value) { + one_bit_states = static_cast(value); + zero_bit_states = ~one_bit_states; + } + + bool is_constant() const { return get_constant().has_value(); } + + std::optional get_constant() const { + if (~one_bit_states == zero_bit_states) { + return static_cast(one_bit_states); + } else { + return std::nullopt; } + } - uint64_t get_low6bits_state() const { return low6bits_state; } + void set_to_bottom() { set_all_to(false, false); } + void set_to_top() { set_all_to(true, true); } + Bitset& join_with(const Bitset& that) { + one_bit_states |= that.one_bit_states; + zero_bit_states |= that.zero_bit_states; + return *this; + } - static constexpr Low6Bits top() { return Low6Bits(); } + Bitset& meet_with(const Bitset& that) { + one_bit_states &= that.one_bit_states; + zero_bit_states &= that.zero_bit_states; + return *this; + } - static constexpr Low6Bits bottom() { - Low6Bits res; - res.set_to_bottom(); - return res; + bool is_bottom() const { + // We don't use a single representation for bottom. Always use this + // function to check if it's bottom. It's bottom if any bit is zero in + // both integers. + return (one_bit_states | zero_bit_states) != + std::numeric_limits::max(); + } + + bool is_top() const { return *this == top(); } + + // Get the bit that can be determined to be one or zero. The returned + // integer hosts all bits that are determined to be zero/one. + uint64_t get_determined_zero_bits() const { + return zero_bit_states & ~one_bit_states; + } + uint64_t get_determined_one_bits() const { + return one_bit_states & ~zero_bit_states; + } + + // Set particular bits to be known to be zero/one. + Bitset& set_determined_zero_bits(uint64_t bits) { + one_bit_states &= ~bits; + zero_bit_states |= bits; + return *this; + } + Bitset& set_determined_one_bits(uint64_t bits) { + one_bit_states |= bits; + zero_bit_states &= ~bits; + return *this; + } + + // Is the constant unrepresentable by the bitset? + bool unequals_constant(int64_t integer) const { + const auto determinable_one_bits = get_determined_one_bits(); + if ((determinable_one_bits & integer) != determinable_one_bits) { + return true; } - }; + + const auto determinable_zero_bits = get_determined_zero_bits(); + return (determinable_zero_bits & ~integer) != determinable_zero_bits; + } + + static constexpr Bitset bottom() { + constexpr Bitset res(false, false); + return res; + } + static constexpr Bitset top() { + constexpr Bitset res(true, true); + return res; + } +}; + +} // namespace signed_constant_domain_internal + +class SignedConstantDomain; +std::ostream& operator<<(std::ostream& o, const SignedConstantDomain& scd); + +// This class effectively implements a +// ReducedProductAbstractDomain and bitset +// and low6bit domains. +class SignedConstantDomain final + : public sparta::AbstractDomain { + private: + using Bounds = signed_constant_domain_internal::Bounds; + using Low6Bits = signed_constant_domain_internal::Low6Bits; + using Bitset = signed_constant_domain_internal::Bitset; + + Bounds m_bounds; // TODO(T236830337): Remove OptionalLow6Bits. // Not using the abstract class/inheritence pattern to avoid heap allocation. @@ -271,15 +424,17 @@ class SignedConstantDomain final OptionalLow6Bits& operator=(OptionalLow6Bits&&) = default; explicit OptionalLow6Bits(const Low6Bits& l6b = Low6Bits::top()) - : low6bits(signed_constant_domain::enable_low6bits ? Low6BitsType(l6b) - : std::nullopt) {} + : low6bits(signed_constant_domain_internal::enable_low6bits + ? Low6BitsType(l6b) + : std::nullopt) {} explicit OptionalLow6Bits(int64_t value) - : low6bits(signed_constant_domain::enable_low6bits ? Low6BitsType(value) - : std::nullopt) {} + : low6bits(signed_constant_domain_internal::enable_low6bits + ? Low6BitsType(value) + : std::nullopt) {} OptionalLow6Bits& operator=(const Low6Bits& bs) { - if (signed_constant_domain::enable_low6bits) { + if (signed_constant_domain_internal::enable_low6bits) { low6bits = bs; } return *this; @@ -364,149 +519,6 @@ class SignedConstantDomain final OptionalLow6Bits m_low6bits; - class Bitset final { - // Albeit unusual, make sure the compiler uses 2's complement to represent - // int64, upon which much of this class relies. - static_assert( - static_cast(static_cast(-1)) == - std::numeric_limits::max() && - static_cast(std::numeric_limits::max()) == - static_cast(-1), - "Unsupported compiler: int64_t is not represented as 2's complement"); - - // We use two integers to represent the state of each bit. A bit of - // one_bit_states/zero_bit_states being one means that the corresponding bit - // of the integer can possibly be one/zero. Hence, if the same bits of both - // are one, it means that the bit can be either one or zero, i.e., top for - // that bit. If any bit is zero in both integers, then the bitset is bottom. - // Use uint64_t instead of int64_t to avoid undefined behavior with >>. - // - // For 32-bit integers, the high 32 bits should be the same as the highest - // bit of the lower 32 bits, i.e., the sign bit of the integer. In this way, - // we achieve consistency with SignedConstantDomain initialized from a - // constant. - uint64_t one_bit_states{std::numeric_limits::max()}; - uint64_t zero_bit_states{std::numeric_limits::max()}; - - constexpr void set_all_to(bool zero, bool one) { - zero_bit_states = zero ? std::numeric_limits::max() : 0u; - one_bit_states = one ? std::numeric_limits::max() : 0u; - } - - // Construct with all bits set to a given bit state. - constexpr Bitset(bool zero, bool one) { set_all_to(zero, one); } - - public: - Bitset() = default; - Bitset(const Bitset& that) = default; - Bitset(Bitset&& that) = default; - Bitset& operator=(const Bitset& that) = default; - Bitset& operator=(Bitset&& that) = default; - - bool operator==(const Bitset& that) const { - return (one_bit_states == that.one_bit_states && - zero_bit_states == that.zero_bit_states) || - (is_bottom() && that.is_bottom()); - } - - bool operator<=(const Bitset& that) const { - if (is_bottom()) { - return true; - } - return ((one_bit_states | that.one_bit_states) == that.one_bit_states && - (zero_bit_states | that.zero_bit_states) == that.zero_bit_states); - } - - // Partial order only. Use <= instead. - bool operator>(const Bitset&) = delete; - bool operator<(const Bitset&) = delete; - bool operator>=(const Bitset&) = delete; - - uint64_t get_one_bit_states() const { return one_bit_states; } - uint64_t get_zero_bit_states() const { return zero_bit_states; } - - // Construct from a constant. - explicit Bitset(int64_t value) { - one_bit_states = static_cast(value); - zero_bit_states = ~one_bit_states; - } - - bool is_constant() const { return get_constant().has_value(); } - - std::optional get_constant() const { - if (~one_bit_states == zero_bit_states) { - return static_cast(one_bit_states); - } else { - return std::nullopt; - } - } - - void set_to_bottom() { set_all_to(false, false); } - void set_to_top() { set_all_to(true, true); } - Bitset& join_with(const Bitset& that) { - one_bit_states |= that.one_bit_states; - zero_bit_states |= that.zero_bit_states; - return *this; - } - - Bitset& meet_with(const Bitset& that) { - one_bit_states &= that.one_bit_states; - zero_bit_states &= that.zero_bit_states; - return *this; - } - - bool is_bottom() const { - // We don't use a single representation for bottom. Always use this - // function to check if it's bottom. It's bottom if any bit is zero in - // both integers. - return (one_bit_states | zero_bit_states) != - std::numeric_limits::max(); - } - - bool is_top() const { return *this == top(); } - - // Get the bit that can be determined to be one or zero. The returned - // integer hosts all bits that are determined to be zero/one. - uint64_t get_determined_zero_bits() const { - return zero_bit_states & ~one_bit_states; - } - uint64_t get_determined_one_bits() const { - return one_bit_states & ~zero_bit_states; - } - - // Set particular bits to be known to be zero/one. - Bitset& set_determined_zero_bits(uint64_t bits) { - one_bit_states &= ~bits; - zero_bit_states |= bits; - return *this; - } - Bitset& set_determined_one_bits(uint64_t bits) { - one_bit_states |= bits; - zero_bit_states &= ~bits; - return *this; - } - - // Is the constant unrepresentable by the bitset? - bool unequals_constant(int64_t integer) const { - const auto determinable_one_bits = get_determined_one_bits(); - if ((determinable_one_bits & integer) != determinable_one_bits) { - return true; - } - - const auto determinable_zero_bits = get_determined_zero_bits(); - return (determinable_zero_bits & ~integer) != determinable_zero_bits; - } - - static constexpr Bitset bottom() { - constexpr Bitset res(false, false); - return res; - } - static constexpr Bitset top() { - constexpr Bitset res(true, true); - return res; - } - }; - // TODO(TT222824773): Remove OptionalBitset. // Not using the abstract class/inheritence pattern to avoid heap allocation. class OptionalBitset final { @@ -523,15 +535,17 @@ class SignedConstantDomain final OptionalBitset& operator=(OptionalBitset&&) = default; explicit OptionalBitset(const Bitset& bs = Bitset::top()) - : bitset(signed_constant_domain::enable_bitset ? BitsetType(bs) - : std::nullopt) {} + : bitset(signed_constant_domain_internal::enable_bitset + ? BitsetType(bs) + : std::nullopt) {} explicit OptionalBitset(int64_t value) - : bitset(signed_constant_domain::enable_bitset ? BitsetType(value) - : std::nullopt) {} + : bitset(signed_constant_domain_internal::enable_bitset + ? BitsetType(value) + : std::nullopt) {} OptionalBitset& operator=(const Bitset& bs) { - if (signed_constant_domain::enable_bitset) { + if (signed_constant_domain_internal::enable_bitset) { bitset = bs; } return *this; @@ -673,77 +687,11 @@ class SignedConstantDomain final SignedConstantDomain(Bounds bounds, Low6Bits low6bits, Bitset bitset) : m_bounds(bounds), m_low6bits(low6bits), m_bitset(bitset) {} - // When either bounds or bitset meets (become narrower), we can possibly - // infer the other one with some info. - void cross_infer_meet_from_bounds() { - if (m_bitset.is_bottom() || m_low6bits.is_bottom()) { - always_assert(m_bounds.is_bottom()); - return; - } - - // Constant inference - if (m_bounds.is_constant()) { - if (m_bitset.unequals_constant(m_bounds.l) || - m_low6bits.unequals_constant(m_bounds.l)) { - set_to_bottom(); - return; - } - m_low6bits = Low6Bits(m_bounds.l); - m_bitset = Bitset(m_bounds.l); - return; - } - - // One is bottom, then all is bottom. - if (m_bounds.is_bottom()) { - set_to_bottom(); - return; - } - - // More cross inference can be added here... - } - - void cross_infer_meet_from_bitset() { - if (m_bounds.is_bottom() || m_low6bits.is_bottom()) { - always_assert(!signed_constant_domain::enable_bitset || - m_bitset.is_bottom()); - return; - } - - const auto bitset_constant = m_bitset.get_constant(); - - if (bitset_constant.has_value()) { - if (m_bounds.unequals_constant(*bitset_constant) || - m_low6bits.unequals_constant(*bitset_constant)) { - set_to_bottom(); - return; - } - m_bounds = Bounds::from_integer(bitset_constant.value()); - m_low6bits = Low6Bits(bitset_constant.value()); - return; - } - - // One is bottom, then all is bottom. - if (m_bitset.is_bottom()) { - set_to_bottom(); - return; - } - - // More cross inference can be added here... - } - - void cross_infer_meet_from_low6bits() { - if (m_bounds.is_bottom() || m_bitset.is_bottom()) { - always_assert(!signed_constant_domain::enable_low6bits || - m_low6bits.is_bottom()); - return; - } - - // One is bottom, then all is bottom. - if (m_low6bits.is_bottom()) { - set_to_bottom(); - return; - } - } + // When any of the domain meets (become narrower), we can possibly + // infer the other ones to some extent. + inline void cross_infer_meet_from_bounds(); + inline void cross_infer_meet_from_bitset(); + inline void cross_infer_meet_from_low6bits(); public: SignedConstantDomain() @@ -762,10 +710,9 @@ class SignedConstantDomain final } SignedConstantDomain(int64_t min, int64_t max) - : m_bounds({min > 0 || max < 0, min, max}), + : m_bounds(Bounds::from_range(min, max)), m_low6bits(Low6Bits::top()), m_bitset(Bitset::top()) { - always_assert(min <= max); cross_infer_meet_from_bounds(); } @@ -790,15 +737,16 @@ class SignedConstantDomain final return SignedConstantDomain(Bounds::nez(), Low6Bits::top(), Bitset::top()); } bool is_bottom() const { - const bool res = - m_bounds.is_bottom() || - (!signed_constant_domain::enable_low6bits && m_low6bits.is_bottom()) || - (!signed_constant_domain::enable_bitset && m_bitset.is_bottom()); + const bool res = m_bounds.is_bottom() || + (!signed_constant_domain_internal::enable_low6bits && + m_low6bits.is_bottom()) || + (!signed_constant_domain_internal::enable_bitset && + m_bitset.is_bottom()); if (res) { always_assert(m_bounds.is_bottom()); - always_assert(!signed_constant_domain::enable_low6bits || + always_assert(!signed_constant_domain_internal::enable_low6bits || m_low6bits.is_bottom()); - always_assert(!signed_constant_domain::enable_bitset || + always_assert(!signed_constant_domain_internal::enable_bitset || m_bitset.is_bottom()); } return res; @@ -861,30 +809,7 @@ class SignedConstantDomain final return sign_domain::Domain(interval()); } - sign_domain::Interval interval() const { - if (m_bounds.is_bottom()) { - return sign_domain::Interval::EMPTY; - } - if (m_bounds.l > 0) { - return sign_domain::Interval::GTZ; - } - if (m_bounds.u < 0) { - return sign_domain::Interval::LTZ; - } - if (m_bounds.l == 0) { - if (m_bounds.u == 0) { - return sign_domain::Interval::EQZ; - } - return sign_domain::Interval::GEZ; - } - if (m_bounds.u == 0) { - return sign_domain::Interval::LEZ; - } - if (m_bounds.is_nez) { - return sign_domain::Interval::NEZ; - } - return sign_domain::Interval::ALL; - } + inline sign_domain::Interval interval() const; ConstantDomain constant_domain() const { if (const auto constant = get_constant(); constant.has_value()) { @@ -910,7 +835,7 @@ class SignedConstantDomain final if (!m_bounds.is_constant()) { return boost::none; } - if (signed_constant_domain::enable_bitset) { + if (signed_constant_domain_internal::enable_bitset) { always_assert(m_bitset.is_constant() && *m_bitset.get_constant() == m_bounds.l); } @@ -955,44 +880,8 @@ class SignedConstantDomain final // setting bounds to top if either zeros or ones is provided. Useful in // inferring results of bitwise ops, which usually invalidate any existing // inferences on Bounds. - SignedConstantDomain& set_determined_bits_erasing_bounds( - std::optional zeros, std::optional ones, bool bit32) { - // No bit can be 1 in both zeros and ones - always_assert(!zeros.has_value() || !ones.has_value() || - (*ones & *zeros) == 0u); - - if (!zeros.has_value() && !ones.has_value()) { - return *this; - } - - if (zeros.has_value()) { - uint64_t new_zeros = *zeros; - if (bit32) { - if ((*zeros & 0x80000000) != 0) { // sign bit is 1 - new_zeros |= static_cast(0xffffffff00000000ul); - } else { - new_zeros &= static_cast(0x7ffffffful); - } - } - m_bitset.set_determined_zero_bits(new_zeros); - } - if (ones.has_value()) { - uint64_t new_ones = *ones; - if (bit32) { - if ((*ones & 0x80000000) != 0) { // sign bit is 1 - new_ones |= static_cast(0xffffffff00000000ul); - } else { - new_ones &= static_cast(0x7ffffffful); - } - } - m_bitset.set_determined_one_bits(new_ones); - } - - m_bounds.set_to_top(); - m_low6bits.set_to_top(); - cross_infer_meet_from_bitset(); - return *this; - } + inline SignedConstantDomain& set_determined_bits_erasing_bounds( + std::optional zeros, std::optional ones, bool bit32); uint64_t get_one_bit_states() const { return m_bitset.get_one_bit_states(); } uint64_t get_zero_bit_states() const { @@ -1026,3 +915,133 @@ class SignedConstantDomain final static_cast(std::numeric_limits::min()))); } }; + +void SignedConstantDomain::cross_infer_meet_from_bounds() { + if (m_bitset.is_bottom() || m_low6bits.is_bottom()) { + always_assert(m_bounds.is_bottom()); + return; + } + + // Constant inference + if (m_bounds.is_constant()) { + if (m_bitset.unequals_constant(m_bounds.l) || + m_low6bits.unequals_constant(m_bounds.l)) { + set_to_bottom(); + return; + } + m_low6bits = Low6Bits(m_bounds.l); + m_bitset = Bitset(m_bounds.l); + return; + } + + // One is bottom, then all is bottom. + if (m_bounds.is_bottom()) { + set_to_bottom(); + return; + } +} + +void SignedConstantDomain::cross_infer_meet_from_bitset() { + if (m_bounds.is_bottom() || m_low6bits.is_bottom()) { + always_assert(!signed_constant_domain_internal::enable_bitset || + m_bitset.is_bottom()); + return; + } + + const auto bitset_constant = m_bitset.get_constant(); + + if (bitset_constant.has_value()) { + if (m_bounds.unequals_constant(*bitset_constant) || + m_low6bits.unequals_constant(*bitset_constant)) { + set_to_bottom(); + return; + } + m_bounds = Bounds::from_integer(bitset_constant.value()); + m_low6bits = Low6Bits(bitset_constant.value()); + return; + } + + // One is bottom, then all is bottom. + if (m_bitset.is_bottom()) { + set_to_bottom(); + return; + } +} + +void SignedConstantDomain::cross_infer_meet_from_low6bits() { + if (m_bounds.is_bottom() || m_bitset.is_bottom()) { + always_assert(!signed_constant_domain_internal::enable_low6bits || + m_low6bits.is_bottom()); + return; + } + + // One is bottom, then all is bottom. + if (m_low6bits.is_bottom()) { + set_to_bottom(); + return; + } +} + +SignedConstantDomain& SignedConstantDomain::set_determined_bits_erasing_bounds( + std::optional zeros, std::optional ones, bool bit32) { + // No bit can be 1 in both zeros and ones + always_assert(!zeros.has_value() || !ones.has_value() || + (*ones & *zeros) == 0u); + + if (!zeros.has_value() && !ones.has_value()) { + return *this; + } + + if (zeros.has_value()) { + uint64_t new_zeros = *zeros; + if (bit32) { + if ((*zeros & 0x80000000) != 0) { // sign bit is 1 + new_zeros |= static_cast(0xffffffff00000000ul); + } else { + new_zeros &= static_cast(0x7ffffffful); + } + } + m_bitset.set_determined_zero_bits(new_zeros); + } + if (ones.has_value()) { + uint64_t new_ones = *ones; + if (bit32) { + if ((*ones & 0x80000000) != 0) { // sign bit is 1 + new_ones |= static_cast(0xffffffff00000000ul); + } else { + new_ones &= static_cast(0x7ffffffful); + } + } + m_bitset.set_determined_one_bits(new_ones); + } + + m_bounds.set_to_top(); + m_low6bits.set_to_top(); + cross_infer_meet_from_bitset(); + return *this; +} + +sign_domain::Interval SignedConstantDomain::interval() const { + if (m_bounds.is_bottom()) { + return sign_domain::Interval::EMPTY; + } + if (m_bounds.l > 0) { + return sign_domain::Interval::GTZ; + } + if (m_bounds.u < 0) { + return sign_domain::Interval::LTZ; + } + if (m_bounds.l == 0) { + if (m_bounds.u == 0) { + return sign_domain::Interval::EQZ; + } + return sign_domain::Interval::GEZ; + } + if (m_bounds.u == 0) { + return sign_domain::Interval::LEZ; + } + if (m_bounds.is_nez) { + return sign_domain::Interval::NEZ; + } + return sign_domain::Interval::ALL; +} diff --git a/test/common/RedexTest.h b/test/common/RedexTest.h index 088f079627..33bfd8e9f1 100644 --- a/test/common/RedexTest.h +++ b/test/common/RedexTest.h @@ -53,21 +53,21 @@ inline std::string get_env(const char* name) { return env_file; } -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. extern bool enable_bitset; // TODO(T236830337): Remove this. extern bool enable_low6bits; -} // namespace signed_constant_domain +} // namespace signed_constant_domain_internal struct RedexTest : public testing::Test { public: RedexTest() { g_redex = new RedexContext(); // TODO(TT222824773): Remove this. - signed_constant_domain::enable_bitset = true; + signed_constant_domain_internal::enable_bitset = true; // TODO(T236830337): Remove this. - signed_constant_domain::enable_low6bits = true; + signed_constant_domain_internal::enable_low6bits = true; } ~RedexTest() { delete g_redex; } diff --git a/tools/redex-all/main.cpp b/tools/redex-all/main.cpp index bba9b8d834..908a771b46 100644 --- a/tools/redex-all/main.cpp +++ b/tools/redex-all/main.cpp @@ -1964,13 +1964,13 @@ int main(int argc, char* argv[]) { keep_reason::Reason::set_record_keep_reasons( args.config.get("record_keep_reasons", false).asBool()); - signed_constant_domain::enable_bitset = + signed_constant_domain_internal::enable_bitset = args.config.get("enable_bitset_constant_propagation", true).asBool(); - signed_constant_domain::enable_low6bits = + signed_constant_domain_internal::enable_low6bits = args.config.get("enable_low6bits_constant_propagation", false).asBool(); - always_assert_log(!signed_constant_domain::enable_low6bits || - signed_constant_domain::enable_bitset, + always_assert_log(!signed_constant_domain_internal::enable_low6bits || + signed_constant_domain_internal::enable_bitset, "enable_bitset_constant_propagation must be turned on if " "enable_low6bits_constant_propagation is turned on.");