diff --git a/rng/philox/philox.hpp b/rng/philox/philox.hpp index 51cdbe2..01cfcae 100644 --- a/rng/philox/philox.hpp +++ b/rng/philox/philox.hpp @@ -38,8 +38,46 @@ namespace detail { { static constexpr std::size_t value = Is * 2; }; + + template + class save_stream_flags + { + typedef ::std::basic_ios<_CharT, _Traits> stream_type; + + public: + save_stream_flags(const save_stream_flags&) = delete; + save_stream_flags& + operator=(const save_stream_flags&) = delete; + + explicit save_stream_flags(stream_type& stream) + : stream_(stream), fmtflags_(stream.flags()), fill_(stream.fill()) + { + } + ~save_stream_flags() + { + stream_.flags(fmtflags_); + stream_.fill(fill_); + } + + private: + typename stream_type::fmtflags fmtflags_; + stream_type& stream_; + _CharT fill_; + }; } // namespace detail +template +struct philox_engine; + +template +std::basic_ostream& +operator<<(std::basic_ostream& os, const philox_engine& x); + +template +std::basic_istream& +operator>>(std::basic_istream& is, philox_engine& x); + + template struct philox_engine { @@ -73,7 +111,7 @@ struct philox_engine static constexpr std::array multipliers = extract_elements(even_indices_sequence{}); static constexpr std::array round_consts = extract_elements(odd_indices_sequence{}); static constexpr result_type min() { return 0; } - static constexpr result_type max() { return max_impl(); } + static constexpr result_type max() { return result_mask; } static constexpr result_type default_seed = 20111115u; // constructors and seeding functions philox_engine() : philox_engine(default_seed) {} @@ -171,13 +209,13 @@ struct philox_engine } // inserters and extractors - template - friend std::basic_ostream& - operator<<(std::basic_ostream& os, const philox_engine& x); + template + friend std::basic_ostream& + operator<<(std::basic_ostream& os, const philox_engine& x); - template - friend std::basic_istream& - operator>>(std::basic_istream& is, philox_engine& x); + template + friend std::basic_istream& + operator>>(std::basic_istream& is, philox_engine& x); private: // utilities @@ -194,8 +232,8 @@ struct philox_engine using counter_type = std::tuple_element_t; using promotion_type = std::tuple_element_t; - static constexpr counter_type counter_mask = ~counter_type(0) >> (sizeof(counter_type) * CHAR_BIT - w); - static constexpr result_type result_mask = ~result_type(0) >> (sizeof(result_type) * CHAR_BIT - w); + static constexpr counter_type counter_mask = static_cast(~counter_type(0)) >> (std::numeric_limits::digits - w); + static constexpr result_type result_mask = static_cast(~result_type(0)) >> (std::numeric_limits::digits - w); private: // functions @@ -274,13 +312,6 @@ struct philox_engine state_i = n - 1; } - static constexpr result_type max_impl() - { - return w == std::numeric_limits::digits - ? std::numeric_limits::digits - 1 - : (result_type(1) << w) - 1; - } - public: // state std::array x; std::array k; @@ -289,6 +320,58 @@ struct philox_engine std::uint32_t state_i; }; +template +std::basic_ostream& +operator<<(std::basic_ostream& os, const philox_engine& engine) { + + detail::save_stream_flags flags(os); + + os.setf(std::ios_base::dec | std::ios_base::left); + CharT sp = os.widen(' '); + os.fill(sp); + + for (std::size_t i = 0; i < n / 2; ++i) { + os << engine.k[i] << sp; + } + for (std::size_t i = 0; i < n; ++i) { + os << (UIntType)engine.x[i] << sp; + } + for (std::size_t i = 0; i < n; ++i) { + os << engine.y[i] << sp; + } + os << engine.state_i; + return os; +} + +template +std::basic_istream& +operator>>(std::basic_istream& is, philox_engine& engine) { + detail::save_stream_flags flags(is); + + is.setf(std::ios_base::dec | std::ios_base::skipws); + + // need a check for the different types? + UIntType tmp[5 * n / 2 + 1]; + + for (std::size_t i = 0; i < 5 * n / 2 + 1; ++i) { + is >> tmp[i]; + } + if(!is.fail()) { + std::size_t j = 0; + for (std::size_t i = 0; i < n / 2; ++i) { + engine.k[i] = tmp[j++]; + } + for (std::size_t i = 0; i < n; ++i) { + engine.x[i] = tmp[j++]; + } + for (std::size_t i = 0; i < n; ++i) { + engine.y[i] = tmp[j++]; + } + engine.state_i = tmp[j]; // do we change state_i or set it to n - 1 ? + } + return is; +} + using philox4x32 = philox_engine; using philox4x64 = philox_engine; diff --git a/rng/philox/test.cpp b/rng/philox/test.cpp new file mode 100644 index 0000000..ba0337e --- /dev/null +++ b/rng/philox/test.cpp @@ -0,0 +1,308 @@ +#include +#include +#include +#include +#include + +#include "philox.hpp" + +using philox2x32_w5 = std::philox_engine; +using philox2x32_w30 = std::philox_engine; +using philox2x64_w15 = std::philox_engine; +using philox2x64_w49 = std::philox_engine; + +// Test the conformance of the implementation with the ISO C++ standard +template +void conformance_test() { + Engine engine; + for(int i = 0; i < 9999; i++) { + engine(); + } + typename Engine::result_type reference; + if(std::is_same_v) { + reference = 1955073260; + } + else { + reference = 3409172418970261260; + } + if(engine() == reference) { + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; + } else { + std::cout << __PRETTY_FUNCTION__ << " failed" << std::endl; + } +} + +// Test public API +template +void api_test() { + { + Engine engine; + engine.seed(); + } + { + Engine engine(1); + engine.seed(1); + } + { + std::seed_seq s; + Engine engine(s); + engine.seed(s); + } + { + Engine engine; + Engine engine2; + if(!(engine == engine2) || (engine != engine2)) { + std::cout << __PRETTY_FUNCTION__ << " failed !=, == for the same engines" << std::endl; + return; + } + engine2.seed(42); + if((engine == engine2) || !(engine != engine2)) { + std::cout << __PRETTY_FUNCTION__ << " failed !=, == for the different engines" << std::endl; + return; + } + } + { + std::ostringstream os; + Engine engine; + os << engine << std::endl; + Engine engine2; + engine2(); + std::istringstream in(os.str()); + in >> engine2; + if(engine != engine2) { + std::cout << __PRETTY_FUNCTION__ << " failed for >> << operators" << std::endl; + return; + } + } + { + Engine engine; + engine.min(); + engine.max(); + } + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; +} + +template +void seed_test() { + for(int i = 1; i < 5; i++) { // make sure that the state is reset properly for all idx positions + Engine engine; + typename Engine::result_type res; + for(int j = 0; j < i - 1; j++) { + engine(); + } + res = engine(); + engine.seed(); + for(int j = 0; j < i - 1; j++) { + engine(); + } + if(res != engine()) { + std::cout << __PRETTY_FUNCTION__ << " failed while generating " << i << " elements" << std::endl; + } + } + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; +} + +template +void discard_test() { + { + constexpr size_t n = 10; // arbitrary length we want to check + typename Engine::result_type reference[n]; + Engine engine; + for(int i = 0; i < n; i++) { + reference[i] = engine(); + } + for(int i = 0; i < n; i++) { + engine.seed(); + engine.discard(i); + for(size_t j = i; j < n; j++) { + if(reference[j] != engine()) { + std::cout << __PRETTY_FUNCTION__ << " failed with error in element " << j << " discard " << i << std::endl; + break; + } + } + } + std::cout << __PRETTY_FUNCTION__ << " passed step 1 discard from the intial state" << std::endl; + + for(int i = 1; i < n; i++) { + for(int j = 1; j < i; j++) { + engine.seed(); + for(size_t k = 0; k < i - j; k++) { + engine(); + } + engine.discard(j); + if(reference[i] != engine()) { + std::cout << __PRETTY_FUNCTION__ << " failed on step " << i << " " << j << std::endl; + break; + } + } + } + std::cout << __PRETTY_FUNCTION__ << " passed step 2 discard after generation" << std::endl; + } +} + +template +void set_counter_conformance_test() { + Engine engine; + std::array counter; + for(int i = 0; i < Engine::word_count - 1; i++) { + counter[i] = 0; + } + + counter[Engine::word_count - 1] = 2499; // to get 10'000 element + engine.set_counter(counter); + + for(int i = 0; i < Engine::word_count - 1; i++) { + engine(); + } + + typename Engine::result_type reference; + if(std::is_same_v) { + reference = 1955073260; + } + else { + reference = 3409172418970261260; + } + if(engine() == reference) { + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; + } else { + std::cout << __PRETTY_FUNCTION__ << " failed" << std::endl; + } +} + +template +void skip_test() { + using T = typename Engine::result_type; + for(T i = 1; i <= Engine::word_count + 1; i++) { + Engine engine1; + std::array counter = {0}; + counter[Engine::word_count - 1] = i / Engine::word_count; + engine1.set_counter(counter); + for(T j = 0; j < i % Engine::word_count; j++) { + engine1(); + } + + Engine engine2; + engine2.discard(i); + + if(engine1() != engine2()) { + std::cout << __PRETTY_FUNCTION__ << " failed for " << i << " skip" << std::endl; + return; + } + } + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; +} + +template +void counter_overflow_test() { + using T = typename Engine::result_type; + Engine engine1; + std::array counter; + for(int i = 0; i < Engine::word_count; i++) { + counter[i] = std::numeric_limits::max(); + } + + engine1.set_counter(counter); + for(int i = 0; i < Engine::word_count; i++) { + engine1(); + } // all counters overflowed == start from 0 0 0 0 + + Engine engine2; + + if(engine1() == engine2()) { + std::cout << __PRETTY_FUNCTION__ << " passed" << std::endl; + } else { + std::cout << __PRETTY_FUNCTION__ << " failed" << std::endl; + } +} + +template +void discard_overflow_test() { + using T = typename Engine::result_type; + for (int overflow_position = 0; overflow_position < Engine::word_count - 1; overflow_position++) { + Engine engine1; + std::array counter = {0}; + + int raw_counter_position = (Engine::word_count - overflow_position - 2) % Engine::word_count; + std::cout << "Testing discard overflow for position " << raw_counter_position << std::endl; + counter[raw_counter_position] = 1; + + engine1.set_counter(counter); + + Engine engine2; + + std::array counter2 = {0}; + for (int i = Engine::word_count - overflow_position - 1; i < Engine::word_count - 1; i++) { + counter2[i] = std::numeric_limits::max(); + } + + engine2.set_counter(counter2); + + for (int i = 0; i < Engine::word_count; i++) { + engine2(); + } + + for (int i = 0; i < Engine::word_count; i++) { + engine2.discard(engine2.max()); + } + + if (engine1() == engine2()) { + std::cout << __PRETTY_FUNCTION__ << " passed for overflow_position " << overflow_position << std::endl; + } + else { + std::cout << __PRETTY_FUNCTION__ << " failed for overflow_position " << overflow_position << std::endl; + break; + } + } +} + +int main() { + conformance_test(); + conformance_test(); + + set_counter_conformance_test(); + set_counter_conformance_test(); + + api_test(); + api_test(); + api_test(); + api_test(); + api_test(); + api_test(); + + seed_test(); + seed_test(); + seed_test(); + seed_test(); + seed_test(); + seed_test(); + + discard_test(); + discard_test(); + discard_test(); + discard_test(); + discard_test(); + discard_test(); + + skip_test(); + skip_test(); + skip_test(); + skip_test(); + skip_test(); + skip_test(); + + counter_overflow_test(); + counter_overflow_test(); + counter_overflow_test(); + counter_overflow_test(); + counter_overflow_test(); + counter_overflow_test(); + + discard_overflow_test(); + discard_overflow_test(); + discard_overflow_test(); + discard_overflow_test(); + discard_overflow_test(); + discard_overflow_test(); + + return 0; +}