Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mrmd/action/LangevinThermostat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class LangevinThermostat

void apply(data::Atoms& atoms);

template <std::predicate<const real_t, const real_t, const real_t> UnaryPred>
void apply_if(data::Atoms& atoms, const UnaryPred& pred);
template <OnePositionPredicate Pred>
void apply_if(data::Atoms& atoms, const Pred& pred);

void set(const real_t gamma, const real_t temperature, const real_t timestep)
{
Expand All @@ -53,8 +53,8 @@ class LangevinThermostat
}
};

template <std::predicate<const real_t, const real_t, const real_t> UnaryPred>
void LangevinThermostat::apply_if(data::Atoms& atoms, const UnaryPred& pred)
template <OnePositionPredicate Pred>
void LangevinThermostat::apply_if(data::Atoms& atoms, const Pred& pred)
{
auto RNG = randPool_;
auto pos = atoms.getPos();
Expand Down
61 changes: 7 additions & 54 deletions mrmd/action/LennardJones.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,60 +22,13 @@ namespace mrmd::action
{
void LennardJones::apply(data::Atoms& atoms, HalfVerletList& verletList)
{
energyAndVirial_ = data::EnergyAndVirialReducer();

pos_ = atoms.getPos();
force_ = atoms.getForce();
type_ = atoms.getType();
verletList_ = verletList;

auto policy = Kokkos::RangePolicy<>(0, atoms.numLocalAtoms);
Kokkos::parallel_reduce("LennardJones::applyForces", policy, *this, energyAndVirial_);
Kokkos::fence();
}

KOKKOS_FUNCTION
void LennardJones::operator()(const idx_t& idx, data::EnergyAndVirialReducer& energyAndVirial) const
{
real_t posTmp[3];
posTmp[0] = pos_(idx, 0);
posTmp[1] = pos_(idx, 1);
posTmp[2] = pos_(idx, 2);

real_t forceTmp[3] = {0_r, 0_r, 0_r};

const auto numNeighbors = idx_c(HalfNeighborList::numNeighbor(verletList_, idx));
for (idx_t n = 0; n < numNeighbors; ++n)
{
idx_t jdx = idx_c(HalfNeighborList::getNeighbor(verletList_, idx, n));
assert(0 <= jdx);

auto dx = posTmp[0] - pos_(jdx, 0);
auto dy = posTmp[1] - pos_(jdx, 1);
auto dz = posTmp[2] - pos_(jdx, 2);

auto distSqr = dx * dx + dy * dy + dz * dz;

if (distSqr > rcSqr_) continue;

auto typeIdx = type_(idx) * numTypes_ + type_(jdx);
auto forceAndEnergy = LJ_.computeForceAndEnergy(distSqr, typeIdx);
assert(!std::isnan(forceAndEnergy.forceFactor));
energyAndVirial.energy += forceAndEnergy.energy;
energyAndVirial.virial -= 0.5_r * forceAndEnergy.forceFactor * distSqr;

forceTmp[0] += dx * forceAndEnergy.forceFactor;
forceTmp[1] += dy * forceAndEnergy.forceFactor;
forceTmp[2] += dz * forceAndEnergy.forceFactor;

force_(jdx, 0) -= dx * forceAndEnergy.forceFactor;
force_(jdx, 1) -= dy * forceAndEnergy.forceFactor;
force_(jdx, 2) -= dz * forceAndEnergy.forceFactor;
}

force_(idx, 0) += forceTmp[0];
force_(idx, 1) += forceTmp[1];
force_(idx, 2) += forceTmp[2];
apply_if(
atoms,
verletList,
KOKKOS_LAMBDA(
const real_t, const real_t, const real_t, const real_t, const real_t, const real_t) {
return true;
});
}

real_t LennardJones::getEnergy() const { return energyAndVirial_.energy; }
Expand Down
80 changes: 77 additions & 3 deletions mrmd/action/LennardJones.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ class LennardJones
data::EnergyAndVirialReducer energyAndVirial_;

public:
KOKKOS_FUNCTION
void operator()(const idx_t& idx, data::EnergyAndVirialReducer& energyAndVirial) const;

real_t getEnergy() const;
real_t getVirial() const;

void apply(data::Atoms& atoms, HalfVerletList& verletList);

template <TwoPositionsPredicate Pred>
void apply_if(const data::Atoms& atoms, const HalfVerletList& verletList, const Pred& pred);

LennardJones(const real_t rc,
const real_t& sigma,
const real_t& epsilon,
Expand All @@ -105,4 +105,78 @@ class LennardJones
const idx_t& numTypes,
const bool isShifted);
};

template <TwoPositionsPredicate Pred>
void LennardJones::apply_if(const data::Atoms& atoms,
const HalfVerletList& verletList,
const Pred& pred)
{
energyAndVirial_ = data::EnergyAndVirialReducer();
pos_ = atoms.getPos();
force_ = atoms.getForce();
type_ = atoms.getType();
verletList_ = verletList;

auto policy = Kokkos::RangePolicy<>(0, atoms.numLocalAtoms);

// avoid capturing this pointer
auto pos = pos_;
auto force = force_;
auto type = type_;
auto verletListLocal = verletList_;
auto rcSqr = rcSqr_;
auto LJ = LJ_;
auto numTypes = numTypes_;
auto predLocal = pred;

auto kernel = KOKKOS_LAMBDA(const idx_t idx, data::EnergyAndVirialReducer& energyAndVirial)
{
real_t posTmp[3];
posTmp[0] = pos(idx, 0);
posTmp[1] = pos(idx, 1);
posTmp[2] = pos(idx, 2);

real_t forceTmp[3] = {0_r, 0_r, 0_r};

const auto numNeighbors = idx_c(HalfNeighborList::numNeighbor(verletListLocal, idx));
for (idx_t n = 0; n < numNeighbors; ++n)
{
idx_t jdx = idx_c(HalfNeighborList::getNeighbor(verletListLocal, idx, n));
assert(0 <= jdx);

if (!predLocal(posTmp[0], posTmp[1], posTmp[2], pos(jdx, 0), pos(jdx, 1), pos(jdx, 2)))
continue;

auto dx = posTmp[0] - pos(jdx, 0);
auto dy = posTmp[1] - pos(jdx, 1);
auto dz = posTmp[2] - pos(jdx, 2);

auto distSqr = dx * dx + dy * dy + dz * dz;

if (distSqr > rcSqr) continue;

auto typeIdx = type(idx) * numTypes + type(jdx);
auto forceAndEnergy = LJ.computeForceAndEnergy(distSqr, typeIdx);
assert(!std::isnan(forceAndEnergy.forceFactor));
energyAndVirial.energy += forceAndEnergy.energy;
energyAndVirial.virial -= 0.5_r * forceAndEnergy.forceFactor * distSqr;

forceTmp[0] += dx * forceAndEnergy.forceFactor;
forceTmp[1] += dy * forceAndEnergy.forceFactor;
forceTmp[2] += dz * forceAndEnergy.forceFactor;

force(jdx, 0) -= dx * forceAndEnergy.forceFactor;
force(jdx, 1) -= dy * forceAndEnergy.forceFactor;
force(jdx, 2) -= dz * forceAndEnergy.forceFactor;
}

force(idx, 0) += forceTmp[0];
force(idx, 1) += forceTmp[1];
force(idx, 2) += forceTmp[2];
};

Kokkos::parallel_reduce("LennardJones::apply_if", policy, kernel, energyAndVirial_);
Kokkos::fence();
}

} // namespace mrmd::action
8 changes: 4 additions & 4 deletions mrmd/action/ThermodynamicForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class ThermodynamicForce
void update(const real_t& smoothingSigma, const real_t& smoothingIntensity);
void apply(const data::Atoms& atoms) const;

template <std::predicate<const real_t, const real_t, const real_t> UnaryPred>
void apply_if(const data::Atoms& atoms, const UnaryPred& pred) const;
template <OnePositionPredicate Pred>
void apply_if(const data::Atoms& atoms, const Pred& pred) const;

std::vector<real_t> getMuLeft() const;
std::vector<real_t> getMuRight() const;
Expand All @@ -85,8 +85,8 @@ class ThermodynamicForce
const bool usePeriodicity = false);
};

template <std::predicate<const real_t, const real_t, const real_t> UnaryPred>
void ThermodynamicForce::apply_if(const data::Atoms& atoms, const UnaryPred& pred) const
template <OnePositionPredicate Pred>
void ThermodynamicForce::apply_if(const data::Atoms& atoms, const Pred& pred) const
{
auto atomsPos = atoms.getPos();
auto atomsForce = atoms.getForce();
Expand Down
7 changes: 2 additions & 5 deletions mrmd/data/MultiHistogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,15 @@ void transform(const MultiHistogram& input1,
/**
* Replaces histogram values with a new value if the bin position satisfies a predicate.
*
* @tparam UnaryPred A unary predicate type that takes a real_t value and returns a boolean.
* Must satisfy std::predicate<const real_t> concept.
*
* @param hist The MultiHistogram object whose values will be conditionally replaced.
* @param pred A unary predicate function that is evaluated for each bin position.
* If it returns true for a bin position, all histogram values at that bin
* are replaced with newValue.
* @param newValue The value to assign to histogram entries whose bin position satisfies
* the predicate.
*/
template <std::predicate<const real_t> UnaryPred>
void replace_if_bin_position(MultiHistogram& hist, const UnaryPred& pred, real_t newValue)
template <OneCoordinatePredicate Pred>
void replace_if_bin_position(MultiHistogram& hist, const Pred& pred, real_t newValue)
{
auto policy =
Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0, 0}, {hist.numBins, hist.numHistograms});
Expand Down
28 changes: 28 additions & 0 deletions mrmd/datatypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,32 @@ using NeighborList [[deprecated]] = Cabana::NeighborList<HalfVerletList>;
using HalfNeighborList = Cabana::NeighborList<HalfVerletList>;
using FullNeighborList = Cabana::NeighborList<FullVerletList>;

// Concept for predicates that take one coordinate as input, e.g. for spatially selective updating
// of histogram bins
template <typename F>
concept OneCoordinatePredicate = std::predicate<F,
const real_t // coordinate value
>;

// Concept for predicates that take one position as input, e.g. for spatially selective application
// of forces or thermostats
template <typename F>
concept OnePositionPredicate = std::predicate<F,
const real_t, // pos x
const real_t, // pos y
const real_t // pos z
>;

// Concept for predicates that take two positions as input, e.g. for conditions based on relative
// positions of two atoms
template <typename F>
concept TwoPositionsPredicate = std::predicate<F,
const real_t, // pos 1 x
const real_t, // pos 1 y
const real_t, // pos 1 z
const real_t, // pos 2 x
const real_t, // pos 2 y
const real_t // pos 2 z
>;

} // namespace mrmd