Skip to content
115 changes: 99 additions & 16 deletions rng/philox/philox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,46 @@ namespace detail {
{
static constexpr std::size_t value = Is * 2;
};

template <class _CharT, class _Traits>
class save_stream_flags
{
typedef ::std::basic_ios<_CharT, _Traits> stream_type;

public:
save_stream_flags(const save_stream_flags&) = delete;
save_stream_flags&
operator=(const save_stream_flags&) = delete;

explicit save_stream_flags(stream_type& stream)
: stream_(stream), fmtflags_(stream.flags()), fill_(stream.fill())
{
}
~save_stream_flags()
{
stream_.flags(fmtflags_);
stream_.fill(fill_);
}

private:
typename stream_type::fmtflags fmtflags_;
stream_type& stream_;
_CharT fill_;
};
} // namespace detail

template <typename UIntType, std::size_t w, std::size_t n, std::size_t r, UIntType... consts>
struct philox_engine;

template<class CharT, class Traits, typename UIntType_, std::size_t w_, std::size_t n_, std::size_t r_, UIntType_... consts_>
std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& os, const philox_engine<UIntType_, w_, n_, r_, consts_...>& x);

template<class CharT, class Traits, typename UIntType_, std::size_t w_, std::size_t n_, std::size_t r_, UIntType_... consts_>
std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& is, philox_engine<UIntType_, w_, n_, r_, consts_...>& x);


template <typename UIntType, std::size_t w, std::size_t n, std::size_t r, UIntType... consts>
struct philox_engine
{
Expand Down Expand Up @@ -73,7 +111,7 @@ struct philox_engine
static constexpr std::array<result_type, array_size> multipliers = extract_elements(even_indices_sequence{});
static constexpr std::array<result_type, array_size> round_consts = extract_elements(odd_indices_sequence{});
static constexpr result_type min() { return 0; }
static constexpr result_type max() { return max_impl(); }
static constexpr result_type max() { return result_mask; }
static constexpr result_type default_seed = 20111115u;
// constructors and seeding functions
philox_engine() : philox_engine(default_seed) {}
Expand Down Expand Up @@ -171,13 +209,13 @@ struct philox_engine
}

// inserters and extractors
template<class charT, class traits>
friend std::basic_ostream<charT, traits>&
operator<<(std::basic_ostream<charT, traits>& os, const philox_engine& x);
template<class CharT, class Traits, typename UIntType_, std::size_t w_, std::size_t n_, std::size_t r_, UIntType_... consts_>
friend std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& os, const philox_engine<UIntType_, w_, n_, r_, consts_...>& x);

template<class charT, class traits>
friend std::basic_istream<charT, traits>&
operator>>(std::basic_istream<charT, traits>& is, philox_engine& x);
template<class CharT, class Traits, typename UIntType_, std::size_t w_, std::size_t n_, std::size_t r_, UIntType_... consts_>
friend std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& is, philox_engine<UIntType_, w_, n_, r_, consts_...>& x);


private: // utilities
Expand All @@ -194,8 +232,8 @@ struct philox_engine
using counter_type = std::tuple_element_t<get_log_index(w), uint_types>;
using promotion_type = std::tuple_element_t<get_log_index(w), promotion_types>;

static constexpr counter_type counter_mask = ~counter_type(0) >> (sizeof(counter_type) * CHAR_BIT - w);
static constexpr result_type result_mask = ~result_type(0) >> (sizeof(result_type) * CHAR_BIT - w);
static constexpr counter_type counter_mask = static_cast<counter_type>(~counter_type(0)) >> (std::numeric_limits<counter_type>::digits - w);
static constexpr result_type result_mask = static_cast<result_type>(~result_type(0)) >> (std::numeric_limits<result_type>::digits - w);


private: // functions
Expand Down Expand Up @@ -274,13 +312,6 @@ struct philox_engine
state_i = n - 1;
}

static constexpr result_type max_impl()
{
return w == std::numeric_limits<result_type>::digits
? std::numeric_limits<result_type>::digits - 1
: (result_type(1) << w) - 1;
}

public: // state
std::array<counter_type, n> x;
std::array<result_type, array_size> k;
Expand All @@ -289,6 +320,58 @@ struct philox_engine
std::uint32_t state_i;
};

template<class CharT, class Traits, typename UIntType, std::size_t w, std::size_t n, std::size_t r, UIntType... consts>
std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& os, const philox_engine<UIntType, w, n, r, consts...>& engine) {

detail::save_stream_flags<CharT, Traits> flags(os);

os.setf(std::ios_base::dec | std::ios_base::left);
CharT sp = os.widen(' ');
os.fill(sp);

for (std::size_t i = 0; i < n / 2; ++i) {
os << engine.k[i] << sp;
}
for (std::size_t i = 0; i < n; ++i) {
os << (UIntType)engine.x[i] << sp;
}
for (std::size_t i = 0; i < n; ++i) {
os << engine.y[i] << sp;
}
os << engine.state_i;
return os;
}

template<class CharT, class Traits, typename UIntType, std::size_t w, std::size_t n, std::size_t r, UIntType... consts>
std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& is, philox_engine<UIntType, w, n, r, consts...>& engine) {
detail::save_stream_flags<CharT, Traits> flags(is);

is.setf(std::ios_base::dec | std::ios_base::skipws);

// need a check for the different types?
UIntType tmp[5 * n / 2 + 1];

for (std::size_t i = 0; i < 5 * n / 2 + 1; ++i) {
is >> tmp[i];
}
if(!is.fail()) {
std::size_t j = 0;
for (std::size_t i = 0; i < n / 2; ++i) {
engine.k[i] = tmp[j++];
}
for (std::size_t i = 0; i < n; ++i) {
engine.x[i] = tmp[j++];
}
for (std::size_t i = 0; i < n; ++i) {
engine.y[i] = tmp[j++];
}
engine.state_i = tmp[j]; // do we change state_i or set it to n - 1 ?
}
return is;
}

using philox4x32 = philox_engine<std::uint_fast32_t, 32, 4, 10, 0xCD9E8D57, 0x9E3779B9, 0xD2511F53, 0xBB67AE85>;
using philox4x64 = philox_engine<std::uint_fast64_t, 64, 4, 10, 0xCA5A826395121157, 0x9E3779B97F4A7C15, 0xD2E7470EE14C6C93, 0xBB67AE8584CAA73B>;

Expand Down
Loading