Skip to content

Commit c06b0b6

Browse files
Use recursive apply
1 parent c097bf9 commit c06b0b6

File tree

9 files changed

+63
-71
lines changed

9 files changed

+63
-71
lines changed

GPU/Common/MemLayout.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ struct interface<S, F, Flag::soa> { using type = wrapper<S, F>; };
296296
#define MEMLAYOUT_EXPAND(m) f(m, other.m)
297297

298298
#define MEMLAYOUT_APPLY_BINARY(STRUCT_NAME, ...)\
299+
template <template <class> class F_other, class Function>\
300+
constexpr STRUCT_NAME apply(STRUCT_NAME<F_other>& other, Function&& f) { return {__VA_ARGS__}; }\
301+
template <template <class> class F_other, class Function>\
302+
constexpr STRUCT_NAME apply(STRUCT_NAME<F_other>& other, Function&& f) const { return {__VA_ARGS__}; }\
299303
template <template <class> class F_other, class Function>\
300304
constexpr STRUCT_NAME apply(const STRUCT_NAME<F_other>& other, Function&& f) { return {__VA_ARGS__}; }\
301305
template <template <class> class F_other, class Function>\

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct GPUTrackingInOutPointers {
225225
const AliHLTTPCRawCluster* rawClusters[NSECTORS] = {nullptr};
226226
uint32_t nRawClusters[NSECTORS] = {0};
227227
const o2::tpc::ClusterNativeAccess* clustersNative = nullptr;
228-
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, MemLayout::Flag::aos>::type sectorTracks[NSECTORS];
228+
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, GPUTPCTrackLayout>::type sectorTracks[NSECTORS];
229229
uint32_t nSectorTracks[NSECTORS] = {0};
230230
const GPUTPCHitId* sectorClusters[NSECTORS] = {nullptr};
231231
uint32_t nSectorClusters[NSECTORS] = {0};

GPU/GPUTracking/Global/GPUChainTracking.cxx

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -549,35 +549,33 @@ void GPUChainTracking::ClearIOPointers()
549549
new (&mIOMem) InOutMemory;
550550
}
551551

552-
void GPUChainTracking::AllocateIOMemorySectorTracks(
553-
uint32_t nSectorTracks,
554-
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, MemLayout::Flag::aos>::type& IOPtrSectorTrack,
555-
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, MemLayout::Flag::aos>::type& IOMemSectorTrack
556-
) {
557-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack, IOMemSectorTrack);
558-
}
559-
void GPUChainTracking::AllocateIOMemorySectorTracks(
560-
uint32_t nSectorTracks,
561-
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, MemLayout::Flag::soa>::type& IOPtrSectorTrack,
562-
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, MemLayout::Flag::soa>::type& IOMemSectorTrack
563-
) {
564-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mFirstHitID, IOMemSectorTrack.mFirstHitID);
565-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mNHits, IOMemSectorTrack.mNHits);
566-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mLocalTrackId, IOMemSectorTrack.mLocalTrackId);
567-
568-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mParam.mX, IOMemSectorTrack.mParam.mX);
569-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mParam.mC, IOMemSectorTrack.mParam.mC);
570-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mParam.mZOffset, IOMemSectorTrack.mParam.mZOffset);
571-
AllocateIOMemoryHelper(nSectorTracks, IOPtrSectorTrack.mParam.mP, IOMemSectorTrack.mParam.mP);
552+
namespace {
553+
554+
template <class Function>
555+
struct ApplyRecursive {
556+
Function f;
557+
558+
template <class T>
559+
const T * operator()(const T * & aosIOPtr, GPUChainTracking::unique_ptr_array<T>& aosIOMem) const { return f(aosIOPtr, aosIOMem); }
560+
561+
template <template <template <class> class> class S>
562+
S<MemLayout::const_pointer> operator()(S<MemLayout::const_pointer>& soaIOPtr, S<GPUChainTracking::unique_ptr_array>& soaIOMem) const {
563+
return soaIOPtr.apply(soaIOMem, ApplyRecursive{f});
564+
}
565+
};
566+
572567
}
573568

574569
void GPUChainTracking::AllocateIOMemory()
575570
{
576571
for (uint32_t i = 0; i < NSECTORS; i++) {
577572
AllocateIOMemoryHelper(mIOPtrs.nClusterData[i], mIOPtrs.clusterData[i], mIOMem.clusterData[i]);
578573
AllocateIOMemoryHelper(mIOPtrs.nRawClusters[i], mIOPtrs.rawClusters[i], mIOMem.rawClusters[i]);
579-
//AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i], mIOMem.sectorTracks[i]); // debug
580-
AllocateIOMemorySectorTracks(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i], mIOMem.sectorTracks[i]); // new
574+
auto sectorTrackAllocator = [this, nSectorTrack = this->mIOPtrs.nSectorTracks[i]](auto& IOPtrsTrack, auto& mIOMemTrack) {
575+
AllocateIOMemoryHelper(nSectorTrack, IOPtrsTrack, mIOMemTrack);
576+
return IOPtrsTrack;
577+
};
578+
ApplyRecursive{sectorTrackAllocator}(mIOPtrs.sectorTracks[i], mIOMem.sectorTracks[i]);
581579
AllocateIOMemoryHelper(mIOPtrs.nSectorClusters[i], mIOPtrs.sectorClusters[i], mIOMem.sectorClusters[i]);
582580
}
583581
mIOMem.clusterNativeAccess.reset(new ClusterNativeAccess);

GPU/GPUTracking/Global/GPUChainTracking.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
#include "GPUChain.h"
1919
#include "GPUDataTypes.h"
20+
#include "GPUTPCTrack.h"
2021
#include "MemLayout.h"
22+
2123
#include <atomic>
2224
#include <mutex>
2325
#include <functional>
@@ -111,7 +113,7 @@ class GPUChainTracking : public GPUChain
111113
std::unique_ptr<AliHLTTPCRawCluster[]> rawClusters[NSECTORS];
112114
std::unique_ptr<o2::tpc::ClusterNative[]> clustersNative;
113115
std::unique_ptr<o2::tpc::ClusterNativeAccess> clusterNativeAccess;
114-
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, MemLayout::Flag::aos>::type sectorTracks[NSECTORS];
116+
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, GPUTPCTrackLayout>::type sectorTracks[NSECTORS];
115117
std::unique_ptr<GPUTPCHitId[]> sectorClusters[NSECTORS];
116118
std::unique_ptr<AliHLTTPCClusterMCLabel[]> mcLabelsTPC;
117119
std::unique_ptr<GPUTPCMCInfo[]> mcInfosTPC;
@@ -136,16 +138,6 @@ class GPUChainTracking : public GPUChain
136138

137139
// Read / Dump / Clear Data
138140
void ClearIOPointers();
139-
void AllocateIOMemorySectorTracks(
140-
uint32_t nSectorTracks,
141-
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, MemLayout::Flag::aos>::type& IOPtrSectorTrack,
142-
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, MemLayout::Flag::aos>::type& IOMemSectorTrack
143-
);
144-
void AllocateIOMemorySectorTracks(
145-
uint32_t nSectorTracks,
146-
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, MemLayout::Flag::soa>::type& IOPtrSectorTrack,
147-
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, MemLayout::Flag::soa>::type& IOMemSectorTrack
148-
);
149141
void AllocateIOMemory();
150142
using GPUChain::DumpData;
151143
void DumpData(const char* filename, const GPUTrackingInOutPointers* ioPtrs = nullptr);

GPU/GPUTracking/SectorTracker/GPUTPCTrack.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
namespace o2::gpu
2323
{
24+
constexpr MemLayout::Flag GPUTPCTrackLayout = MemLayout::Flag::aos;
25+
2426
/**
2527
* @class GPUTPCTrack
2628
*

GPU/GPUTracking/SectorTracker/GPUTPCTracker.cxx

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
#include "GPUTPCTracker.h"
1616
#include "GPUTPCRow.h"
17-
#include "GPUTPCTrack.h"
1817
#include "GPUCommonMath.h"
18+
#include "MemLayout.h"
1919

2020
#include "GPUTPCClusterData.h"
2121
#include "GPUO2DataTypes.h"
2222
#include "GPUTPCTrackParam.h"
23+
#include "GPUTPCTracklet.h"
24+
#include "GPUTPCTrack.h"
2325
#include "GPUParam.inc"
2426
#include "GPUTPCConvertImpl.h"
2527
#include "GPUDefParametersRuntime.h"
@@ -29,6 +31,7 @@
2931
#include <cmath>
3032
#include <algorithm>
3133
#include <stdexcept>
34+
#include <type_traits>
3235

3336
#include "GPUReconstruction.h"
3437
#include "GPUMemorySizeScalers.h"
@@ -108,47 +111,41 @@ void GPUTPCTracker::RegisterMemoryAllocation()
108111
mMemoryResOutput = mRec->RegisterMemoryAllocation(this, &GPUTPCTracker::SetPointersOutput, type, "TPCTrackerTracks");
109112
}
110113

111-
GPUhd() void GPUTPCTracker::SetPointersTrackletsHelper(void* & mem, MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type& tracklets) {
112-
computePointerWithAlignment(mem, tracklets, mNMaxTracklets);
113-
}
114+
namespace {
115+
116+
template <class Function>
117+
struct ApplyMemberwise {
118+
Function g;
119+
120+
template <class ...Args>
121+
void operator()(Args& ...args) const { (g(args), ...); }
122+
};
114123

115-
GPUhd() void GPUTPCTracker::SetPointersTrackletsHelper(void* & mem, MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, MemLayout::Flag::soa>::type& tracklets) {
116-
computePointerWithAlignment(mem, tracklets.mFirstRow, mNMaxTracklets);
117-
computePointerWithAlignment(mem, tracklets.mLastRow, mNMaxTracklets);
124+
template <class Function>
125+
struct ApplyRecursive {
126+
Function f;
118127

119-
computePointerWithAlignment(mem, tracklets.mParam.mX, mNMaxTracklets);
120-
computePointerWithAlignment(mem, tracklets.mParam.mC, mNMaxTracklets);
121-
computePointerWithAlignment(mem, tracklets.mParam.mZOffset, mNMaxTracklets);
122-
computePointerWithAlignment(mem, tracklets.mParam.mP, mNMaxTracklets);
128+
template <class T>
129+
void operator()(T * & aos) const { f(aos); }
130+
131+
template <template <template <class> class> class S>
132+
void operator()(S<MemLayout::pointer>& soa) const { soa.apply(ApplyMemberwise<ApplyRecursive>{f}); }
133+
};
123134

124-
computePointerWithAlignment(mem, tracklets.mHitWeight, mNMaxTracklets);
125-
computePointerWithAlignment(mem, tracklets.mFirstHit, mNMaxTracklets);
126135
}
127136

128137
GPUhd() void* GPUTPCTracker::SetPointersTracklets(void* mem)
129138
{
130-
SetPointersTrackletsHelper(mem, mTracklets);
139+
auto tracklet_helper = [&mem, this](auto& tracklets) -> void { computePointerWithAlignment(mem, tracklets, mNMaxTracklets); };
140+
ApplyRecursive{tracklet_helper}(mTracklets);
131141
computePointerWithAlignment(mem, mTrackletRowHits, mNMaxRowHits);
132142
return mem;
133143
}
134144

135-
GPUhd() void GPUTPCTracker::SetPointersTracksHelper(void* & mem, MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type& tracks) {
136-
computePointerWithAlignment(mem, tracks, mNMaxTracks);
137-
}
138-
139-
GPUhd() void GPUTPCTracker::SetPointersTracksHelper(void* & mem, MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, MemLayout::Flag::soa>::type& tracks) {
140-
computePointerWithAlignment(mem, tracks.mFirstHitID, mNMaxTracks);
141-
computePointerWithAlignment(mem, tracks.mNHits, mNMaxTracks);
142-
computePointerWithAlignment(mem, tracks.mLocalTrackId, mNMaxTracks);
143-
computePointerWithAlignment(mem, tracks.mParam.mX, mNMaxTracks);
144-
computePointerWithAlignment(mem, tracks.mParam.mC, mNMaxTracks);
145-
computePointerWithAlignment(mem, tracks.mParam.mZOffset, mNMaxTracks);
146-
computePointerWithAlignment(mem, tracks.mParam.mP, mNMaxTracks);
147-
}
148-
149145
GPUhd() void* GPUTPCTracker::SetPointersOutput(void* mem)
150146
{
151-
SetPointersTracksHelper(mem, mTracks);
147+
auto track_helper = [&mem, this](auto& tracks) -> void { computePointerWithAlignment(mem, tracks, mNMaxTracks); };
148+
ApplyRecursive{track_helper}(mTracks);
152149
computePointerWithAlignment(mem, mTrackHits, mNMaxTrackHits);
153150
return mem;
154151
}

GPU/GPUTracking/SectorTracker/GPUTPCTracker.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace o2::gpu
3434
struct GPUTPCClusterData;
3535
struct GPUParam;
3636
template <template <class> class F>
37-
class GPUTPCTrackSkeleton;
37+
class GPUTPCTrackParamSkeleton;
3838
class GPUTPCRow;
3939

4040
class GPUTPCTracker : public GPUProcessor
@@ -107,11 +107,7 @@ class GPUTPCTracker : public GPUProcessor
107107
void* SetPointersScratch(void* mem);
108108
void* SetPointersScratchHost(void* mem);
109109
void* SetPointersCommon(void* mem);
110-
void SetPointersTrackletsHelper(void* & mem, MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type& tracklets);
111-
void SetPointersTrackletsHelper(void* & mem, MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, MemLayout::Flag::soa>::type& tracklets);
112110
void* SetPointersTracklets(void* mem);
113-
void SetPointersTracksHelper(void* & mem, MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type& tracks);
114-
void SetPointersTracksHelper(void* & mem, MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, MemLayout::Flag::soa>::type& tracks);
115111
void* SetPointersTracks(void* mem);
116112
void* SetPointersOutput(void* mem);
117113
void RegisterMemoryAllocation();
@@ -247,9 +243,9 @@ class GPUTPCTracker : public GPUProcessor
247243
// event
248244
GPUglobalref() commonMemoryStruct* mCommonMem = nullptr; // common event memory
249245
GPUglobalref() GPUTPCHitId* mTrackletStartHits = nullptr; // start hits for the tracklets
250-
GPUglobalref() MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type mTracklets; // tracklets
246+
GPUglobalref() MemLayout::interface<GPUTPCTrackletSkeleton, MemLayout::pointer, GPUTPCTrackletLayout>::type mTracklets; // tracklets
251247
GPUglobalref() calink* mTrackletRowHits = nullptr; // Hits for each Tracklet in each row
252-
GPUglobalref() MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, MemLayout::Flag::aos>::type mTracks; // reconstructed tracks
248+
GPUglobalref() MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::pointer, GPUTPCTrackLayout>::type mTracks; // reconstructed tracks
253249
GPUglobalref() GPUTPCHitId* mTrackHits = nullptr; // array of track hit numbers
254250

255251
static int32_t StarthitSortComparison(const void* a, const void* b);

GPU/GPUTracking/SectorTracker/GPUTPCTrackerDump.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "GPUTPCTracker.h"
1616
#include "GPUReconstruction.h"
1717
#include "GPUTPCHitId.h"
18+
#include "GPUTPCTracklet.h"
1819
#include "GPUTPCTrack.h"
1920
#include "GPULogging.h"
2021
#include "MemLayout.h"
@@ -144,7 +145,7 @@ void GPUTPCTracker::DumpTrackletHits(std::ostream& out)
144145
}
145146
for (int32_t jj = 0; jj < nTracklets; jj++) {
146147
const int32_t j = Ids[jj];
147-
const auto& tracklet = Tracklets()[j];
148+
MemLayout::wrapper<GPUTPCTrackletSkeleton, MemLayout::const_reference> tracklet = Tracklets()[j];
148149
out << "Tracklet " << std::setw(4) << jj << " (Rows: " << Tracklets()[j].FirstRow() << " - " << tracklet.LastRow() << ", Weight " << Tracklets()[j].HitWeight() << ") ";
149150
if (tracklet.LastRow() > tracklet.FirstRow() && (tracklet.FirstRow() >= GPUCA_ROW_COUNT || tracklet.LastRow() >= GPUCA_ROW_COUNT)) {
150151
GPUError("Error: Tracklet %d First %d Last %d", j, tracklet.FirstRow(), tracklet.LastRow());

GPU/GPUTracking/SectorTracker/GPUTPCTracklet.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
namespace o2::gpu
2323
{
24+
constexpr MemLayout::Flag GPUTPCTrackletLayout = MemLayout::Flag::aos;
25+
2426
/**
2527
* @class GPUTPCTracklet
2628
*

0 commit comments

Comments
 (0)