From 545c6bcc3bb8de1e422a29bdb4fcfae9276652f0 Mon Sep 17 00:00:00 2001 From: Hong Xu Date: Wed, 12 Nov 2025 20:02:09 -0800 Subject: [PATCH] Split up SignedConstantDomain and other cleanups (#968) Summary: The class is now doing too many things. Split up for better testability and easier implementation for future optimizations. Put internal implementations under an internal namespace rather than nested (easier to test some logically complicated part as individual components later). Additionally, moved some big function definitions out of the class for better readability. Added and used `Bounds::from_range`. This is a no-op diff. Differential Revision: D86894728 --- .../SignedConstantDomain.cpp | 5 +- .../SignedConstantDomain.h | 957 +++++++++--------- test/common/RedexTest.h | 8 +- tools/redex-all/main.cpp | 8 +- 4 files changed, 499 insertions(+), 479 deletions(-) diff --git a/service/constant-propagation/SignedConstantDomain.cpp b/service/constant-propagation/SignedConstantDomain.cpp index 8ae48de573..eb2e43ba56 100644 --- a/service/constant-propagation/SignedConstantDomain.cpp +++ b/service/constant-propagation/SignedConstantDomain.cpp @@ -21,12 +21,13 @@ enum class SignedConstantDomain::BitShiftMask : int32_t { Long = 0x3f, }; -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. bool enable_bitset = true; // TODO(T236830337): Remove this. bool enable_low6bits = false; -} // namespace signed_constant_domain + +} // namespace signed_constant_domain_internal SignedConstantDomain& SignedConstantDomain::left_shift_bits_int(int32_t shift) { return left_shift_bits(shift, BitShiftMask::Int); diff --git a/service/constant-propagation/SignedConstantDomain.h b/service/constant-propagation/SignedConstantDomain.h index 7c94575b90..dec253108e 100644 --- a/service/constant-propagation/SignedConstantDomain.h +++ b/service/constant-propagation/SignedConstantDomain.h @@ -45,215 +45,368 @@ inline NumericIntervalDomain numeric_interval_domain_from_int(int64_t min, } } -class SignedConstantDomain; -std::ostream& operator<<(std::ostream& o, const SignedConstantDomain& scd); - -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. extern bool enable_bitset; // TODO(T236830337): Remove this. extern bool enable_low6bits; -} // namespace signed_constant_domain -// This class effectively implements a -// ReducedProductAbstractDomain -class SignedConstantDomain final - : public sparta::AbstractDomain { +struct Bounds final { + Bounds() = delete; // Use bottom() or top() instead. + private: static constexpr int64_t MIN = std::numeric_limits::min(); static constexpr int64_t MAX = std::numeric_limits::max(); - struct Bounds final { - bool is_nez; - int64_t l; - int64_t u; - bool operator==(const Bounds& other) const { - return is_nez == other.is_nez && l == other.l && u == other.u; - } - bool operator<=(const Bounds& other) const { - return this->is_bottom() || - (other.l <= l && u <= other.u && - static_cast(other.is_nez) <= static_cast(is_nez)); - } - // Partial order only. Use <= instead. - bool operator>(const Bounds&) = delete; - bool operator<(const Bounds&) = delete; - bool operator>=(const Bounds&) = delete; - - bool is_constant() const { return l == u; } - bool is_top() const { return *this == top(); } - bool is_bottom() const { return *this == bottom(); } - void normalize() { - if (is_nez) { - if (l == 0) { - l++; - } - if (u == 0) { - u--; - } + constexpr Bounds(bool is_nez, int64_t l, int64_t u) + : is_nez(is_nez), l(l), u(u) {} + + public: + bool is_nez; + int64_t l; + int64_t u; + bool operator==(const Bounds& other) const { + return is_nez == other.is_nez && l == other.l && u == other.u; + } + bool operator<=(const Bounds& other) const { + return this->is_bottom() || + (other.l <= l && u <= other.u && + static_cast(other.is_nez) <= static_cast(is_nez)); + } + // Partial order only. Use <= instead. + bool operator>(const Bounds&) = delete; + bool operator<(const Bounds&) = delete; + bool operator>=(const Bounds&) = delete; + + bool is_constant() const { return l == u; } + bool is_top() const { return *this == top(); } + bool is_bottom() const { return *this == bottom(); } + void normalize() { + if (is_nez) { + if (l == 0) { + l++; } - if (u < l) { - this->set_to_bottom(); + if (u == 0) { + u--; } - always_assert(is_normalized()); } - // Is the constant not within the bounds? - bool unequals_constant(int64_t integer) const { - return (integer == 0 && is_nez) || integer < l || u < integer; + if (u < l) { + this->set_to_bottom(); } - bool is_normalized() const { - // bottom has a particular shape - if (u < l) { - return this->is_bottom(); - } - // nez cannot be set if 0 is a lower or upper bound - if (l == 0 || u == 0) { - return !is_nez; - } - // nez must be set if 0 is not in range - return (l <= 0 && u >= 0) || is_nez; + always_assert(is_normalized()); + } + // Is the constant not within the bounds? + bool unequals_constant(int64_t integer) const { + return (integer == 0 && is_nez) || integer < l || u < integer; + } + bool is_normalized() const { + // bottom has a particular shape + if (u < l) { + return this->is_bottom(); + } + // nez cannot be set if 0 is a lower or upper bound + if (l == 0 || u == 0) { + return !is_nez; } + // nez must be set if 0 is not in range + return (l <= 0 && u >= 0) || is_nez; + } - // Is the bounds known to be NEZ and nothing more? - bool is_nez_only() const { return is_nez && l == MIN && u == MAX; } + // Is the bounds known to be NEZ and nothing more? + bool is_nez_only() const { return is_nez && l == MIN && u == MAX; } - Bounds& set_to_top() { - *this = top(); - return *this; - } + Bounds& set_to_top() { + *this = top(); + return *this; + } - Bounds& set_to_bottom() { - *this = bottom(); - return *this; - } + Bounds& set_to_bottom() { + *this = bottom(); + return *this; + } - Bounds& join_with(const Bounds& that) { - l = std::min(l, that.l); - u = std::max(u, that.u); - is_nez &= that.is_nez; - always_assert(is_normalized()); - return *this; - } + Bounds& join_with(const Bounds& that) { + l = std::min(l, that.l); + u = std::max(u, that.u); + is_nez &= that.is_nez; + always_assert(is_normalized()); + return *this; + } - Bounds& meet_with(const Bounds& that) { - l = std::max(l, that.l); - u = std::min(u, that.u); - is_nez |= that.is_nez; - normalize(); - return *this; - } + Bounds& meet_with(const Bounds& that) { + l = std::max(l, that.l); + u = std::min(u, that.u); + is_nez |= that.is_nez; + normalize(); + return *this; + } - static Bounds from_interval(sign_domain::Interval interval) { - switch (interval) { - case sign_domain::Interval::EMPTY: - return bottom(); - case sign_domain::Interval::EQZ: - return {false, 0, 0}; - case sign_domain::Interval::LEZ: - return {false, MIN, 0}; - case sign_domain::Interval::LTZ: - return {true, MIN, -1}; - case sign_domain::Interval::GEZ: - return {false, 0, MAX}; - case sign_domain::Interval::GTZ: - return {true, 1, MAX}; - case sign_domain::Interval::ALL: - return top(); - case sign_domain::Interval::NEZ: - return nez(); - case sign_domain::Interval::SIZE: - break; - } - not_reached(); - } + static Bounds from_interval(sign_domain::Interval interval) { + switch (interval) { + case sign_domain::Interval::EMPTY: + return bottom(); + case sign_domain::Interval::EQZ: + return {false, 0, 0}; + case sign_domain::Interval::LEZ: + return {false, MIN, 0}; + case sign_domain::Interval::LTZ: + return {true, MIN, -1}; + case sign_domain::Interval::GEZ: + return {false, 0, MAX}; + case sign_domain::Interval::GTZ: + return {true, 1, MAX}; + case sign_domain::Interval::ALL: + return top(); + case sign_domain::Interval::NEZ: + return nez(); + case sign_domain::Interval::SIZE: + break; + } + not_reached(); + } - static Bounds from_integer(int64_t integer) { - return Bounds{integer != 0, integer, integer}; - } + static Bounds from_range(int64_t min, int64_t max) { + always_assert(min <= max); + return Bounds{min > 0 || max < 0, min, max}; + } - static constexpr Bounds top() { - constexpr Bounds res{false, MIN, MAX}; - return res; - } - static constexpr Bounds bottom() { - constexpr Bounds res{true, MAX, MIN}; - return res; - } - static constexpr Bounds nez() { - constexpr Bounds res{true, MIN, MAX}; - return res; - } - }; - Bounds m_bounds; + static Bounds from_integer(int64_t integer) { + return from_range(integer, integer); + } - /** Encodes the combination the lowest 6 bits of an integer. - * - * This class contains a 64-bit integer, low6bits_state, that tracks possible - * states of the lowest 6 bits. There are 64 bits in low6bits_state, each - * represents a value of the lowest 6 bits of the integer represented by the - * domain. The n'th bit in low6bits_state being 1 implies that the lowest 6 - * bits of the integer (or %64 in arithmetic terms) may still be possibly n - * conservatively, i.e., redex can't prove that the lowest 64 bits unequals n. - * On the other hand, the n'th bit being 0 implies the lowest 6 bits of the - * integer can be proven to unequal n. - * - * For example, if low6bits_state is 0b101, it means that the lowest 6 bits of - * the represented integer can be either decimal 0 or 2. - * - * This class costs 8 bytes of RAM. - */ - class Low6Bits final { - private: - uint64_t low6bits_state{std::numeric_limits::max()}; + static constexpr Bounds top() { + constexpr Bounds res{false, MIN, MAX}; + return res; + } + static constexpr Bounds bottom() { + constexpr Bounds res{true, MAX, MIN}; + return res; + } + static constexpr Bounds nez() { + constexpr Bounds res{true, MIN, MAX}; + return res; + } +}; - public: - Low6Bits() = default; - Low6Bits(const Low6Bits&) = default; - Low6Bits& operator=(const Low6Bits&) = default; - Low6Bits(Low6Bits&&) = default; - Low6Bits& operator=(Low6Bits&&) = default; - explicit Low6Bits(int64_t value) - : low6bits_state(static_cast(1u) << (value & 63)) {} - - bool operator==(const Low6Bits& that) const { - return low6bits_state == that.low6bits_state; - } +/** Encodes the combination the lowest 6 bits of an integer. + * + * This class contains a 64-bit integer, low6bits_state, that tracks possible + * states of the lowest 6 bits. There are 64 bits in low6bits_state, each + * represents a value of the lowest 6 bits of the integer represented by the + * domain. The n'th bit in low6bits_state being 1 implies that the lowest 6 + * bits of the integer (or %64 in arithmetic terms) may still be possibly n + * conservatively, i.e., redex can't prove that the lowest 64 bits unequals n. + * On the other hand, the n'th bit being 0 implies the lowest 6 bits of the + * integer can be proven to unequal n. + * + * For example, if low6bits_state is 0b101, it means that the lowest 6 bits of + * the represented integer can be either decimal 0 or 2. + * + * This class costs 8 bytes of RAM. + */ +class Low6Bits final { + private: + uint64_t low6bits_state{std::numeric_limits::max()}; - bool operator<=(const Low6Bits& that) const { - return (low6bits_state & ~that.low6bits_state) == 0; - } + public: + Low6Bits() = default; + Low6Bits(const Low6Bits&) = default; + Low6Bits& operator=(const Low6Bits&) = default; + Low6Bits(Low6Bits&&) = default; + Low6Bits& operator=(Low6Bits&&) = default; + explicit Low6Bits(int64_t value) + : low6bits_state(static_cast(1u) << (value & 63)) {} + + bool operator==(const Low6Bits& that) const { + return low6bits_state == that.low6bits_state; + } - bool is_bottom() const { return low6bits_state == 0u; } - bool is_top() const { return *this == top(); } + bool operator<=(const Low6Bits& that) const { + return (low6bits_state & ~that.low6bits_state) == 0; + } - constexpr void set_to_bottom() { low6bits_state = 0u; } - void set_to_top() { *this = top(); } + bool is_bottom() const { return low6bits_state == 0u; } + bool is_top() const { return *this == top(); } - Low6Bits& join_with(const Low6Bits& that) { - low6bits_state |= that.low6bits_state; - return *this; - } + constexpr void set_to_bottom() { low6bits_state = 0u; } + void set_to_top() { *this = top(); } - Low6Bits& meet_with(const Low6Bits& that) { - low6bits_state &= that.low6bits_state; - return *this; + Low6Bits& join_with(const Low6Bits& that) { + low6bits_state |= that.low6bits_state; + return *this; + } + + Low6Bits& meet_with(const Low6Bits& that) { + low6bits_state &= that.low6bits_state; + return *this; + } + + bool unequals_constant(int64_t integer) const { + return (low6bits_state & (static_cast(1u) << (integer & 63))) == + 0u; + } + + uint64_t get_low6bits_state() const { return low6bits_state; } + + static constexpr Low6Bits top() { return Low6Bits(); } + + static constexpr Low6Bits bottom() { + Low6Bits res; + res.set_to_bottom(); + return res; + } +}; + +class Bitset final { + // We use two integers to represent the state of each bit. A bit of + // one_bit_states/zero_bit_states being one means that the corresponding bit + // of the integer can possibly be one/zero. Hence, if the same bits of both + // are one, it means that the bit can be either one or zero, i.e., top for + // that bit. If any bit is zero in both integers, then the bitset is bottom. + // Use uint64_t instead of int64_t to avoid undefined behavior with >>. + // + // For 32-bit integers, the high 32 bits should be the same as the highest + // bit of the lower 32 bits, i.e., the sign bit of the integer. In this way, + // we achieve consistency with SignedConstantDomain initialized from a + // constant. + uint64_t one_bit_states{std::numeric_limits::max()}; + uint64_t zero_bit_states{std::numeric_limits::max()}; + + constexpr void set_all_to(bool zero, bool one) { + zero_bit_states = zero ? std::numeric_limits::max() : 0u; + one_bit_states = one ? std::numeric_limits::max() : 0u; + } + + // Construct with all bits set to a given bit state. + constexpr Bitset(bool zero, bool one) { set_all_to(zero, one); } + + public: + Bitset() = default; + Bitset(const Bitset& that) = default; + Bitset(Bitset&& that) = default; + Bitset& operator=(const Bitset& that) = default; + Bitset& operator=(Bitset&& that) = default; + + bool operator==(const Bitset& that) const { + return (one_bit_states == that.one_bit_states && + zero_bit_states == that.zero_bit_states) || + (is_bottom() && that.is_bottom()); + } + + bool operator<=(const Bitset& that) const { + if (is_bottom()) { + return true; } + return ((one_bit_states | that.one_bit_states) == that.one_bit_states && + (zero_bit_states | that.zero_bit_states) == that.zero_bit_states); + } - bool unequals_constant(int64_t integer) const { - return (low6bits_state & (static_cast(1u) << (integer & 63))) == - 0u; + // Partial order only. Use <= instead. + bool operator>(const Bitset&) = delete; + bool operator<(const Bitset&) = delete; + bool operator>=(const Bitset&) = delete; + + uint64_t get_one_bit_states() const { return one_bit_states; } + uint64_t get_zero_bit_states() const { return zero_bit_states; } + + // Construct from a constant. + explicit Bitset(int64_t value) { + one_bit_states = static_cast(value); + zero_bit_states = ~one_bit_states; + } + + bool is_constant() const { return get_constant().has_value(); } + + std::optional get_constant() const { + if (~one_bit_states == zero_bit_states) { + return static_cast(one_bit_states); + } else { + return std::nullopt; } + } - uint64_t get_low6bits_state() const { return low6bits_state; } + void set_to_bottom() { set_all_to(false, false); } + void set_to_top() { set_all_to(true, true); } + Bitset& join_with(const Bitset& that) { + one_bit_states |= that.one_bit_states; + zero_bit_states |= that.zero_bit_states; + return *this; + } - static constexpr Low6Bits top() { return Low6Bits(); } + Bitset& meet_with(const Bitset& that) { + one_bit_states &= that.one_bit_states; + zero_bit_states &= that.zero_bit_states; + return *this; + } - static constexpr Low6Bits bottom() { - Low6Bits res; - res.set_to_bottom(); - return res; + bool is_bottom() const { + // We don't use a single representation for bottom. Always use this + // function to check if it's bottom. It's bottom if any bit is zero in + // both integers. + return (one_bit_states | zero_bit_states) != + std::numeric_limits::max(); + } + + bool is_top() const { return *this == top(); } + + // Get the bit that can be determined to be one or zero. The returned + // integer hosts all bits that are determined to be zero/one. + uint64_t get_determined_zero_bits() const { + return zero_bit_states & ~one_bit_states; + } + uint64_t get_determined_one_bits() const { + return one_bit_states & ~zero_bit_states; + } + + // Set particular bits to be known to be zero/one. + Bitset& set_determined_zero_bits(uint64_t bits) { + one_bit_states &= ~bits; + zero_bit_states |= bits; + return *this; + } + Bitset& set_determined_one_bits(uint64_t bits) { + one_bit_states |= bits; + zero_bit_states &= ~bits; + return *this; + } + + // Is the constant unrepresentable by the bitset? + bool unequals_constant(int64_t integer) const { + const auto determinable_one_bits = get_determined_one_bits(); + if ((determinable_one_bits & integer) != determinable_one_bits) { + return true; } - }; + + const auto determinable_zero_bits = get_determined_zero_bits(); + return (determinable_zero_bits & ~integer) != determinable_zero_bits; + } + + static constexpr Bitset bottom() { + constexpr Bitset res(false, false); + return res; + } + static constexpr Bitset top() { + constexpr Bitset res(true, true); + return res; + } +}; + +} // namespace signed_constant_domain_internal + +class SignedConstantDomain; +std::ostream& operator<<(std::ostream& o, const SignedConstantDomain& scd); + +// This class effectively implements a +// ReducedProductAbstractDomain and bitset +// and low6bit domains. +class SignedConstantDomain final + : public sparta::AbstractDomain { + private: + using Bounds = signed_constant_domain_internal::Bounds; + using Low6Bits = signed_constant_domain_internal::Low6Bits; + using Bitset = signed_constant_domain_internal::Bitset; + + Bounds m_bounds; // TODO(T236830337): Remove OptionalLow6Bits. // Not using the abstract class/inheritence pattern to avoid heap allocation. @@ -271,15 +424,17 @@ class SignedConstantDomain final OptionalLow6Bits& operator=(OptionalLow6Bits&&) = default; explicit OptionalLow6Bits(const Low6Bits& l6b = Low6Bits::top()) - : low6bits(signed_constant_domain::enable_low6bits ? Low6BitsType(l6b) - : std::nullopt) {} + : low6bits(signed_constant_domain_internal::enable_low6bits + ? Low6BitsType(l6b) + : std::nullopt) {} explicit OptionalLow6Bits(int64_t value) - : low6bits(signed_constant_domain::enable_low6bits ? Low6BitsType(value) - : std::nullopt) {} + : low6bits(signed_constant_domain_internal::enable_low6bits + ? Low6BitsType(value) + : std::nullopt) {} OptionalLow6Bits& operator=(const Low6Bits& bs) { - if (signed_constant_domain::enable_low6bits) { + if (signed_constant_domain_internal::enable_low6bits) { low6bits = bs; } return *this; @@ -364,149 +519,6 @@ class SignedConstantDomain final OptionalLow6Bits m_low6bits; - class Bitset final { - // Albeit unusual, make sure the compiler uses 2's complement to represent - // int64, upon which much of this class relies. - static_assert( - static_cast(static_cast(-1)) == - std::numeric_limits::max() && - static_cast(std::numeric_limits::max()) == - static_cast(-1), - "Unsupported compiler: int64_t is not represented as 2's complement"); - - // We use two integers to represent the state of each bit. A bit of - // one_bit_states/zero_bit_states being one means that the corresponding bit - // of the integer can possibly be one/zero. Hence, if the same bits of both - // are one, it means that the bit can be either one or zero, i.e., top for - // that bit. If any bit is zero in both integers, then the bitset is bottom. - // Use uint64_t instead of int64_t to avoid undefined behavior with >>. - // - // For 32-bit integers, the high 32 bits should be the same as the highest - // bit of the lower 32 bits, i.e., the sign bit of the integer. In this way, - // we achieve consistency with SignedConstantDomain initialized from a - // constant. - uint64_t one_bit_states{std::numeric_limits::max()}; - uint64_t zero_bit_states{std::numeric_limits::max()}; - - constexpr void set_all_to(bool zero, bool one) { - zero_bit_states = zero ? std::numeric_limits::max() : 0u; - one_bit_states = one ? std::numeric_limits::max() : 0u; - } - - // Construct with all bits set to a given bit state. - constexpr Bitset(bool zero, bool one) { set_all_to(zero, one); } - - public: - Bitset() = default; - Bitset(const Bitset& that) = default; - Bitset(Bitset&& that) = default; - Bitset& operator=(const Bitset& that) = default; - Bitset& operator=(Bitset&& that) = default; - - bool operator==(const Bitset& that) const { - return (one_bit_states == that.one_bit_states && - zero_bit_states == that.zero_bit_states) || - (is_bottom() && that.is_bottom()); - } - - bool operator<=(const Bitset& that) const { - if (is_bottom()) { - return true; - } - return ((one_bit_states | that.one_bit_states) == that.one_bit_states && - (zero_bit_states | that.zero_bit_states) == that.zero_bit_states); - } - - // Partial order only. Use <= instead. - bool operator>(const Bitset&) = delete; - bool operator<(const Bitset&) = delete; - bool operator>=(const Bitset&) = delete; - - uint64_t get_one_bit_states() const { return one_bit_states; } - uint64_t get_zero_bit_states() const { return zero_bit_states; } - - // Construct from a constant. - explicit Bitset(int64_t value) { - one_bit_states = static_cast(value); - zero_bit_states = ~one_bit_states; - } - - bool is_constant() const { return get_constant().has_value(); } - - std::optional get_constant() const { - if (~one_bit_states == zero_bit_states) { - return static_cast(one_bit_states); - } else { - return std::nullopt; - } - } - - void set_to_bottom() { set_all_to(false, false); } - void set_to_top() { set_all_to(true, true); } - Bitset& join_with(const Bitset& that) { - one_bit_states |= that.one_bit_states; - zero_bit_states |= that.zero_bit_states; - return *this; - } - - Bitset& meet_with(const Bitset& that) { - one_bit_states &= that.one_bit_states; - zero_bit_states &= that.zero_bit_states; - return *this; - } - - bool is_bottom() const { - // We don't use a single representation for bottom. Always use this - // function to check if it's bottom. It's bottom if any bit is zero in - // both integers. - return (one_bit_states | zero_bit_states) != - std::numeric_limits::max(); - } - - bool is_top() const { return *this == top(); } - - // Get the bit that can be determined to be one or zero. The returned - // integer hosts all bits that are determined to be zero/one. - uint64_t get_determined_zero_bits() const { - return zero_bit_states & ~one_bit_states; - } - uint64_t get_determined_one_bits() const { - return one_bit_states & ~zero_bit_states; - } - - // Set particular bits to be known to be zero/one. - Bitset& set_determined_zero_bits(uint64_t bits) { - one_bit_states &= ~bits; - zero_bit_states |= bits; - return *this; - } - Bitset& set_determined_one_bits(uint64_t bits) { - one_bit_states |= bits; - zero_bit_states &= ~bits; - return *this; - } - - // Is the constant unrepresentable by the bitset? - bool unequals_constant(int64_t integer) const { - const auto determinable_one_bits = get_determined_one_bits(); - if ((determinable_one_bits & integer) != determinable_one_bits) { - return true; - } - - const auto determinable_zero_bits = get_determined_zero_bits(); - return (determinable_zero_bits & ~integer) != determinable_zero_bits; - } - - static constexpr Bitset bottom() { - constexpr Bitset res(false, false); - return res; - } - static constexpr Bitset top() { - constexpr Bitset res(true, true); - return res; - } - }; - // TODO(TT222824773): Remove OptionalBitset. // Not using the abstract class/inheritence pattern to avoid heap allocation. class OptionalBitset final { @@ -523,15 +535,17 @@ class SignedConstantDomain final OptionalBitset& operator=(OptionalBitset&&) = default; explicit OptionalBitset(const Bitset& bs = Bitset::top()) - : bitset(signed_constant_domain::enable_bitset ? BitsetType(bs) - : std::nullopt) {} + : bitset(signed_constant_domain_internal::enable_bitset + ? BitsetType(bs) + : std::nullopt) {} explicit OptionalBitset(int64_t value) - : bitset(signed_constant_domain::enable_bitset ? BitsetType(value) - : std::nullopt) {} + : bitset(signed_constant_domain_internal::enable_bitset + ? BitsetType(value) + : std::nullopt) {} OptionalBitset& operator=(const Bitset& bs) { - if (signed_constant_domain::enable_bitset) { + if (signed_constant_domain_internal::enable_bitset) { bitset = bs; } return *this; @@ -673,77 +687,11 @@ class SignedConstantDomain final SignedConstantDomain(Bounds bounds, Low6Bits low6bits, Bitset bitset) : m_bounds(bounds), m_low6bits(low6bits), m_bitset(bitset) {} - // When either bounds or bitset meets (become narrower), we can possibly - // infer the other one with some info. - void cross_infer_meet_from_bounds() { - if (m_bitset.is_bottom() || m_low6bits.is_bottom()) { - always_assert(m_bounds.is_bottom()); - return; - } - - // Constant inference - if (m_bounds.is_constant()) { - if (m_bitset.unequals_constant(m_bounds.l) || - m_low6bits.unequals_constant(m_bounds.l)) { - set_to_bottom(); - return; - } - m_low6bits = Low6Bits(m_bounds.l); - m_bitset = Bitset(m_bounds.l); - return; - } - - // One is bottom, then all is bottom. - if (m_bounds.is_bottom()) { - set_to_bottom(); - return; - } - - // More cross inference can be added here... - } - - void cross_infer_meet_from_bitset() { - if (m_bounds.is_bottom() || m_low6bits.is_bottom()) { - always_assert(!signed_constant_domain::enable_bitset || - m_bitset.is_bottom()); - return; - } - - const auto bitset_constant = m_bitset.get_constant(); - - if (bitset_constant.has_value()) { - if (m_bounds.unequals_constant(*bitset_constant) || - m_low6bits.unequals_constant(*bitset_constant)) { - set_to_bottom(); - return; - } - m_bounds = Bounds::from_integer(bitset_constant.value()); - m_low6bits = Low6Bits(bitset_constant.value()); - return; - } - - // One is bottom, then all is bottom. - if (m_bitset.is_bottom()) { - set_to_bottom(); - return; - } - - // More cross inference can be added here... - } - - void cross_infer_meet_from_low6bits() { - if (m_bounds.is_bottom() || m_bitset.is_bottom()) { - always_assert(!signed_constant_domain::enable_low6bits || - m_low6bits.is_bottom()); - return; - } - - // One is bottom, then all is bottom. - if (m_low6bits.is_bottom()) { - set_to_bottom(); - return; - } - } + // When any of the domain meets (become narrower), we can possibly + // infer the other ones to some extent. + inline void cross_infer_meet_from_bounds(); + inline void cross_infer_meet_from_bitset(); + inline void cross_infer_meet_from_low6bits(); public: SignedConstantDomain() @@ -762,10 +710,9 @@ class SignedConstantDomain final } SignedConstantDomain(int64_t min, int64_t max) - : m_bounds({min > 0 || max < 0, min, max}), + : m_bounds(Bounds::from_range(min, max)), m_low6bits(Low6Bits::top()), m_bitset(Bitset::top()) { - always_assert(min <= max); cross_infer_meet_from_bounds(); } @@ -790,15 +737,16 @@ class SignedConstantDomain final return SignedConstantDomain(Bounds::nez(), Low6Bits::top(), Bitset::top()); } bool is_bottom() const { - const bool res = - m_bounds.is_bottom() || - (!signed_constant_domain::enable_low6bits && m_low6bits.is_bottom()) || - (!signed_constant_domain::enable_bitset && m_bitset.is_bottom()); + const bool res = m_bounds.is_bottom() || + (!signed_constant_domain_internal::enable_low6bits && + m_low6bits.is_bottom()) || + (!signed_constant_domain_internal::enable_bitset && + m_bitset.is_bottom()); if (res) { always_assert(m_bounds.is_bottom()); - always_assert(!signed_constant_domain::enable_low6bits || + always_assert(!signed_constant_domain_internal::enable_low6bits || m_low6bits.is_bottom()); - always_assert(!signed_constant_domain::enable_bitset || + always_assert(!signed_constant_domain_internal::enable_bitset || m_bitset.is_bottom()); } return res; @@ -861,30 +809,7 @@ class SignedConstantDomain final return sign_domain::Domain(interval()); } - sign_domain::Interval interval() const { - if (m_bounds.is_bottom()) { - return sign_domain::Interval::EMPTY; - } - if (m_bounds.l > 0) { - return sign_domain::Interval::GTZ; - } - if (m_bounds.u < 0) { - return sign_domain::Interval::LTZ; - } - if (m_bounds.l == 0) { - if (m_bounds.u == 0) { - return sign_domain::Interval::EQZ; - } - return sign_domain::Interval::GEZ; - } - if (m_bounds.u == 0) { - return sign_domain::Interval::LEZ; - } - if (m_bounds.is_nez) { - return sign_domain::Interval::NEZ; - } - return sign_domain::Interval::ALL; - } + inline sign_domain::Interval interval() const; ConstantDomain constant_domain() const { if (const auto constant = get_constant(); constant.has_value()) { @@ -910,7 +835,7 @@ class SignedConstantDomain final if (!m_bounds.is_constant()) { return boost::none; } - if (signed_constant_domain::enable_bitset) { + if (signed_constant_domain_internal::enable_bitset) { always_assert(m_bitset.is_constant() && *m_bitset.get_constant() == m_bounds.l); } @@ -955,44 +880,8 @@ class SignedConstantDomain final // setting bounds to top if either zeros or ones is provided. Useful in // inferring results of bitwise ops, which usually invalidate any existing // inferences on Bounds. - SignedConstantDomain& set_determined_bits_erasing_bounds( - std::optional zeros, std::optional ones, bool bit32) { - // No bit can be 1 in both zeros and ones - always_assert(!zeros.has_value() || !ones.has_value() || - (*ones & *zeros) == 0u); - - if (!zeros.has_value() && !ones.has_value()) { - return *this; - } - - if (zeros.has_value()) { - uint64_t new_zeros = *zeros; - if (bit32) { - if ((*zeros & 0x80000000) != 0) { // sign bit is 1 - new_zeros |= static_cast(0xffffffff00000000ul); - } else { - new_zeros &= static_cast(0x7ffffffful); - } - } - m_bitset.set_determined_zero_bits(new_zeros); - } - if (ones.has_value()) { - uint64_t new_ones = *ones; - if (bit32) { - if ((*ones & 0x80000000) != 0) { // sign bit is 1 - new_ones |= static_cast(0xffffffff00000000ul); - } else { - new_ones &= static_cast(0x7ffffffful); - } - } - m_bitset.set_determined_one_bits(new_ones); - } - - m_bounds.set_to_top(); - m_low6bits.set_to_top(); - cross_infer_meet_from_bitset(); - return *this; - } + inline SignedConstantDomain& set_determined_bits_erasing_bounds( + std::optional zeros, std::optional ones, bool bit32); uint64_t get_one_bit_states() const { return m_bitset.get_one_bit_states(); } uint64_t get_zero_bit_states() const { @@ -1026,3 +915,133 @@ class SignedConstantDomain final static_cast(std::numeric_limits::min()))); } }; + +void SignedConstantDomain::cross_infer_meet_from_bounds() { + if (m_bitset.is_bottom() || m_low6bits.is_bottom()) { + always_assert(m_bounds.is_bottom()); + return; + } + + // Constant inference + if (m_bounds.is_constant()) { + if (m_bitset.unequals_constant(m_bounds.l) || + m_low6bits.unequals_constant(m_bounds.l)) { + set_to_bottom(); + return; + } + m_low6bits = Low6Bits(m_bounds.l); + m_bitset = Bitset(m_bounds.l); + return; + } + + // One is bottom, then all is bottom. + if (m_bounds.is_bottom()) { + set_to_bottom(); + return; + } +} + +void SignedConstantDomain::cross_infer_meet_from_bitset() { + if (m_bounds.is_bottom() || m_low6bits.is_bottom()) { + always_assert(!signed_constant_domain_internal::enable_bitset || + m_bitset.is_bottom()); + return; + } + + const auto bitset_constant = m_bitset.get_constant(); + + if (bitset_constant.has_value()) { + if (m_bounds.unequals_constant(*bitset_constant) || + m_low6bits.unequals_constant(*bitset_constant)) { + set_to_bottom(); + return; + } + m_bounds = Bounds::from_integer(bitset_constant.value()); + m_low6bits = Low6Bits(bitset_constant.value()); + return; + } + + // One is bottom, then all is bottom. + if (m_bitset.is_bottom()) { + set_to_bottom(); + return; + } +} + +void SignedConstantDomain::cross_infer_meet_from_low6bits() { + if (m_bounds.is_bottom() || m_bitset.is_bottom()) { + always_assert(!signed_constant_domain_internal::enable_low6bits || + m_low6bits.is_bottom()); + return; + } + + // One is bottom, then all is bottom. + if (m_low6bits.is_bottom()) { + set_to_bottom(); + return; + } +} + +SignedConstantDomain& SignedConstantDomain::set_determined_bits_erasing_bounds( + std::optional zeros, std::optional ones, bool bit32) { + // No bit can be 1 in both zeros and ones + always_assert(!zeros.has_value() || !ones.has_value() || + (*ones & *zeros) == 0u); + + if (!zeros.has_value() && !ones.has_value()) { + return *this; + } + + if (zeros.has_value()) { + uint64_t new_zeros = *zeros; + if (bit32) { + if ((*zeros & 0x80000000) != 0) { // sign bit is 1 + new_zeros |= static_cast(0xffffffff00000000ul); + } else { + new_zeros &= static_cast(0x7ffffffful); + } + } + m_bitset.set_determined_zero_bits(new_zeros); + } + if (ones.has_value()) { + uint64_t new_ones = *ones; + if (bit32) { + if ((*ones & 0x80000000) != 0) { // sign bit is 1 + new_ones |= static_cast(0xffffffff00000000ul); + } else { + new_ones &= static_cast(0x7ffffffful); + } + } + m_bitset.set_determined_one_bits(new_ones); + } + + m_bounds.set_to_top(); + m_low6bits.set_to_top(); + cross_infer_meet_from_bitset(); + return *this; +} + +sign_domain::Interval SignedConstantDomain::interval() const { + if (m_bounds.is_bottom()) { + return sign_domain::Interval::EMPTY; + } + if (m_bounds.l > 0) { + return sign_domain::Interval::GTZ; + } + if (m_bounds.u < 0) { + return sign_domain::Interval::LTZ; + } + if (m_bounds.l == 0) { + if (m_bounds.u == 0) { + return sign_domain::Interval::EQZ; + } + return sign_domain::Interval::GEZ; + } + if (m_bounds.u == 0) { + return sign_domain::Interval::LEZ; + } + if (m_bounds.is_nez) { + return sign_domain::Interval::NEZ; + } + return sign_domain::Interval::ALL; +} diff --git a/test/common/RedexTest.h b/test/common/RedexTest.h index 088f079627..33bfd8e9f1 100644 --- a/test/common/RedexTest.h +++ b/test/common/RedexTest.h @@ -53,21 +53,21 @@ inline std::string get_env(const char* name) { return env_file; } -namespace signed_constant_domain { +namespace signed_constant_domain_internal { // TODO(T222824773): Remove this. extern bool enable_bitset; // TODO(T236830337): Remove this. extern bool enable_low6bits; -} // namespace signed_constant_domain +} // namespace signed_constant_domain_internal struct RedexTest : public testing::Test { public: RedexTest() { g_redex = new RedexContext(); // TODO(TT222824773): Remove this. - signed_constant_domain::enable_bitset = true; + signed_constant_domain_internal::enable_bitset = true; // TODO(T236830337): Remove this. - signed_constant_domain::enable_low6bits = true; + signed_constant_domain_internal::enable_low6bits = true; } ~RedexTest() { delete g_redex; } diff --git a/tools/redex-all/main.cpp b/tools/redex-all/main.cpp index bba9b8d834..908a771b46 100644 --- a/tools/redex-all/main.cpp +++ b/tools/redex-all/main.cpp @@ -1964,13 +1964,13 @@ int main(int argc, char* argv[]) { keep_reason::Reason::set_record_keep_reasons( args.config.get("record_keep_reasons", false).asBool()); - signed_constant_domain::enable_bitset = + signed_constant_domain_internal::enable_bitset = args.config.get("enable_bitset_constant_propagation", true).asBool(); - signed_constant_domain::enable_low6bits = + signed_constant_domain_internal::enable_low6bits = args.config.get("enable_low6bits_constant_propagation", false).asBool(); - always_assert_log(!signed_constant_domain::enable_low6bits || - signed_constant_domain::enable_bitset, + always_assert_log(!signed_constant_domain_internal::enable_low6bits || + signed_constant_domain_internal::enable_bitset, "enable_bitset_constant_propagation must be turned on if " "enable_low6bits_constant_propagation is turned on.");