diff --git a/.github/workflows/cmake_build.yaml b/.github/workflows/cmake_build.yaml index a5cb813..e7c65e9 100644 --- a/.github/workflows/cmake_build.yaml +++ b/.github/workflows/cmake_build.yaml @@ -51,4 +51,5 @@ jobs: && ./mdio_stats_test \ && ./mdio_utils_trim_test \ && ./mdio_utils_delete_test \ - && ./mdio_variable_collection_test + && ./mdio_variable_collection_test \ + && ./mdio_coordinate_selector_test \ No newline at end of file diff --git a/mdio/CMakeLists.txt b/mdio/CMakeLists.txt index 0d53797..47fc89f 100644 --- a/mdio/CMakeLists.txt +++ b/mdio/CMakeLists.txt @@ -334,3 +334,26 @@ mdio_cc_test( tensorstore::util_status_testutil nlohmann_json_schema_validator ) + +mdio_cc_test( + NAME + coordinate_selector_test + SRCS + coordinate_selector_test.cc + COPTS + ${mdio_DEFAULT_COPTS} + LINKOPTS + ${mdio_DEFAULT_LINKOPTS} + DEPS + GTest::gmock_main + tensorstore::driver_array + tensorstore::driver_zarr + tensorstore::driver_json + tensorstore::kvstore_file + tensorstore::tensorstore + tensorstore::stack + tensorstore::index_space_dim_expression + tensorstore::index_space_index_transform + tensorstore::util_status_testutil + nlohmann_json_schema_validator +) diff --git a/mdio/coordinate_selector.h b/mdio/coordinate_selector.h new file mode 100644 index 0000000..43e952c --- /dev/null +++ b/mdio/coordinate_selector.h @@ -0,0 +1,570 @@ +// Copyright 2025 TGS + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MDIO_COORDINATE_SELECTOR_H_ +#define MDIO_COORDINATE_SELECTOR_H_ + +#include +#include +#include +#include +#include + +#include "mdio/dataset.h" +#include "mdio/impl.h" +#include "mdio/variable.h" +#include "tensorstore/array.h" +#include "tensorstore/box.h" +#include "tensorstore/index_space/dim_expression.h" +#include "tensorstore/index_space/index_domain.h" +#include "tensorstore/index_space/index_domain_builder.h" +#include "tensorstore/index_space/index_transform.h" +#include "tensorstore/index_space/index_transform_builder.h" +#include "tensorstore/util/span.h" + +// #define MDIO_INTERNAL_PROFILING 0 // TODO(BrianMichell): Remove simple +// profiling code once we approach a more mature API access. + +namespace mdio { + +#ifdef MDIO_INTERNAL_PROFILING +void timer(std::chrono::high_resolution_clock::time_point start) { + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + std::cout << "Time taken: " << duration.count() << " microseconds" + << std::endl; +} +#endif + +// — helper to tag multi-key sorts —— +template +struct SortKey { + std::string key; + using value_type = T; +}; + +// — trait to detect ValueDescriptor —— +template +struct is_value_descriptor : std::false_type {}; +template +struct is_value_descriptor> : std::true_type {}; +template +inline constexpr bool is_value_descriptor_v = + is_value_descriptor>::value; + +// — trait to detect SortKey —— +template +struct is_sort_key : std::false_type {}; +template +struct is_sort_key> : std::true_type {}; +template +inline constexpr bool is_sort_key_v = is_sort_key>::value; + +/// \brief Collects valid index selections per dimension for a Dataset without +/// performing slicing immediately. +/// +/// Only dimensions explicitly filtered via filterByCoordinate appear in the +/// map; any dimension not present should be treated as having its full index +/// range. +class CoordinateSelector { + public: + /// Construct from an existing Dataset (captures its full domain). + explicit CoordinateSelector(Dataset& dataset) // NOLINT (non-const) + : dataset_(dataset), base_domain_(dataset.domain) {} + + template + Future...>> ReadDataVariables( + std::vector const& data_variables, Ops const&... ops) { + if (data_variables.size() != sizeof...(OutTs)) { + return absl::InvalidArgumentError( + "ReadDataVariables: number of names must match number of OutTs"); + } + // 1) apply all filters & sorts in order + absl::Status st = absl::OkStatus(); + ((st = st.ok() ? _applyOp(ops).status() : st), ...); + if (!st.ok()) return st; + + // 2) kick off and await all reads + return _readMultiple(data_variables); + } + + void reset() { kept_runs_.clear(); } + + /** + * @brief Filter the Dataset by the given coordinate. + * Limitations: + * - Only a single filter is currently tested. + * - A bug exists if the filter value does not make a perfect hyper-rectangle + * within its dimensions. + * + */ + template + mdio::Future filterByCoordinate(const ValueDescriptor& descriptor) { + if (kept_runs_.empty()) { + return _init_runs(descriptor); + } else { + return _add_new_run(descriptor); + } + } + + template + Future sortSelectionByKey(const std::string& sort_key) { +#ifdef MDIO_INTERNAL_PROFILING + auto start = std::chrono::high_resolution_clock::now(); +#endif + const size_t n = kept_runs_.size(); + + // 1) Fire off all reads in parallel and gather the key values + std::vector>> reads; + reads.reserve(n); + for (auto const& desc : kept_runs_) { + MDIO_ASSIGN_OR_RETURN(auto ds, dataset_.isel(desc)); + MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.at(sort_key)); + reads.push_back(var.Read()); + } + + std::vector keys; + keys.reserve(n); +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Set up sorting of " << sort_key << " ... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + for (auto& f : reads) { + // if (!f.status().ok()) return f.status(); + // auto data = f.value(); + // keys.push_back(data.get_data_accessor().data()[data.get_flattened_offset()]); + MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future(f)); + auto data = std::get<0>(resolution); + auto data_ptr = static_cast(std::get<1>(resolution)); + auto offset = std::get<2>(resolution); + // auto n = std::get<3>(resolution); // Not required + keys.push_back(data_ptr[offset]); + } +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Waiting for reads to complete for " << sort_key << " ... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + + // 2) Build and stable-sort an index array [0…n-1] by key + std::vector idx(n); + std::iota(idx.begin(), idx.end(), 0); + std::stable_sort(idx.begin(), idx.end(), + [&](size_t a, size_t b) { return keys[a] < keys[b]; }); +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Sorting time for " << sort_key << " ... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + + // 3) One linear, move-only pass into a temp buffer + using Desc = std::decay_t::value_type; + std::vector tmp; + tmp.reserve(n); + for (size_t new_pos = 0; new_pos < n; ++new_pos) { + tmp.emplace_back(std::move(kept_runs_[idx[new_pos]])); + } + + // 4) Steal the buffer back + kept_runs_ = std::move(tmp); +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Stealing buffer back time for " << sort_key << " ... "; + timer(start); +#endif + return absl::OkStatus(); + } + + template + Future> readSelection(const std::string& output_variable) { +#ifdef MDIO_INTERNAL_PROFILING + auto start = std::chrono::high_resolution_clock::now(); +#endif + std::vector>> reads; + reads.reserve(kept_runs_.size()); + std::vector ret; + + for (const auto& desc : kept_runs_) { + MDIO_ASSIGN_OR_RETURN(auto ds, dataset_.isel(desc)); + MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.at(output_variable)); + auto fut = var.Read(); + reads.push_back(fut); + if (var.rank() == 1) { + break; + } + } + +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Set up reading of " << output_variable << " ... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + + for (auto& f : reads) { + MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future(f)); + auto data = std::get<0>(resolution); + auto data_ptr = static_cast(std::get<1>(resolution)); + auto offset = std::get<2>(resolution); + auto n = std::get<3>(resolution); + std::vector buffer(n); + std::memcpy(buffer.data(), data_ptr + offset, n * sizeof(T)); + ret.insert(ret.end(), buffer.begin(), buffer.end()); + } + +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Reading time for " << output_variable << " ... "; + timer(start); +#endif + return ret; + } + + private: + Dataset& dataset_; + tensorstore::IndexDomain<> base_domain_; + std::vector>> kept_runs_; + std::map> cached_variables_; + + template + Future _applyOp(D const& op) { + if constexpr (is_value_descriptor_v) { + auto fut = filterByCoordinate(op); + if (!fut.status().ok()) { + return fut.status(); + } + return absl::OkStatus(); + } else if constexpr (is_sort_key_v) { + using SortT = typename std::decay_t::value_type; + auto fut = sortSelectionByKey(op.key); + if (!fut.status().ok()) { + return fut.status(); + } + return absl::OkStatus(); + } else { + return absl::UnimplementedError( + "query(): RangeDescriptor and ListDescriptor not supported"); + } + } + + // helper: expands readSelection(vars[I])... + template + Future...>> _readMultipleImpl( + std::vector const& vars, std::index_sequence) { + // 1) start all reads + auto futs = std::make_tuple(readSelection(vars[I])...); + + // 2) wait on them in order + absl::Status st = absl::OkStatus(); + std::tuple...> results; + // fold over I... + ( + [&]() { + if (!st.ok()) return; + auto& f = std::get(futs); + st = f.status(); + if (st.ok()) std::get(results) = std::move(f.value()); + }(), + ...); + if (!st.ok()) return st; + return results; + } + + template + Future...>> _readMultiple( + std::vector const& vars) { + return _readMultipleImpl(vars, + std::index_sequence_for{}); + } + + /* + TODO: The built RangeDescriptors aren't behaving as I hoped. + They are building the longest runs possible properly, however + as it becomes disjointed we start to lose some info. + + e.g. We can have [0,1], [0, 25], [0, 120] but + the last dimension is actually [0, 1000]. + + What we should get instead is [0, 1], [0, 24], [0, 1000] and [0, 1], [24, 25], + [0, 120] + */ + + template + Future _init_runs(const ValueDescriptor& descriptor) { + using Interval = typename Variable::Interval; +#ifdef MDIO_INTERNAL_PROFILING + auto start = std::chrono::high_resolution_clock::now(); +#endif + MDIO_ASSIGN_OR_RETURN( + auto var, dataset_.variables.at(std::string(descriptor.label.label()))); + + const T* data_ptr; + Index offset; + Index n_samples; + MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals()); + if (cached_variables_.find(descriptor.label.label()) == + cached_variables_.end()) { + // TODO(BrianMichell): Ensure that the domain has not changed. + std::cout << "Reading VariableData" << std::endl; + auto fut = var.Read(); + MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future(fut)); + auto dataToCache = std::get<0>(resolution); + cached_variables_.insert_or_assign(descriptor.label.label(), + std::move(dataToCache)); + } + auto it = cached_variables_.find(descriptor.label.label()); + if (it == cached_variables_.end()) { + std::stringstream ss; + ss << "Cached variable not found for coordinate '" + << descriptor.label.label() << "'"; + return absl::NotFoundError(ss.str()); + } + auto& data = it->second; + data_ptr = static_cast(data.get_data_accessor().data()); + offset = data.get_flattened_offset(); + n_samples = data.num_samples(); + + auto current_pos = intervals; + bool isInRun = false; + std::vector> local_runs; + + std::size_t run_idx = offset; + +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Initialize and read time... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + + for (mdio::Index idx = offset; idx < offset + n_samples; ++idx) { + bool is_match = data_ptr[idx] == descriptor.value; + + if (is_match && !isInRun) { + // The start of a new run + isInRun = true; + for (auto i = run_idx; i < idx; ++i) { + // _current_position_increment(current_pos, intervals); + _current_position_increment(current_pos, intervals); + } + // _current_position_stride(current_pos, intervals, idx - run_idx); + run_idx = idx; + std::vector run = current_pos; + local_runs.push_back(std::move(run)); + } else if (is_match && isInRun) { + // Somewhere in the middle of a run + // do nothing TODO: Remove me + } else if (!is_match && isInRun) { + // The end of a run + isInRun = false; + // Use 1 less than the current index to ensure we get the correct end + // location. + for (auto i = run_idx; i < idx - 1; ++i) { + _current_position_increment(current_pos, intervals); + } + run_idx = idx; + auto& last_run = local_runs.back(); + for (auto i = 0; i < current_pos.size(); ++i) { + last_run[i].exclusive_max = current_pos[i].inclusive_min + 1; + } + // We need to advance to the actual current position + _current_position_increment(current_pos, intervals); + } else if (!is_match && !isInRun) { + // No run at all + // do nothing TODO: Remove me + } else { + // base case TODO: Remove me + } + } + +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Build runs time... "; + timer(start); + start = std::chrono::high_resolution_clock::now(); +#endif + + if (local_runs.empty()) { + std::stringstream ss; + ss << "No matches for coordinate '" << descriptor.label.label() << "'"; + return absl::NotFoundError(ss.str()); + } + + kept_runs_ = _from_intervals(local_runs); +#ifdef MDIO_INTERNAL_PROFILING + std::cout << "Finalize time... "; + timer(start); +#endif + return absl::OkStatus(); + } + + /** + * @brief Using the existing runs, further filter the Dataset by the new + * coordiante. + */ + template + Future _add_new_run(const ValueDescriptor& descriptor) { + using Interval = typename Variable::Interval; + std::vector> new_runs; + + std::vector> + stored_intervals; // Use this to ensure everything remains in memory + // until the Intervals are no longer needed. + stored_intervals.reserve(kept_runs_.size()); + + bool is_first_run = true; + + for (const auto& desc : kept_runs_) { + MDIO_ASSIGN_OR_RETURN(auto ds, dataset_.isel(desc)); + MDIO_ASSIGN_OR_RETURN( + auto var, ds.variables.get(std::string(descriptor.label.label()))); + auto fut = var.Read(); + MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals()); + + if (is_first_run) { + is_first_run = false; + if (intervals.size() != kept_runs_[0].size()) { + std::cout << "WARNING: Different coordinate dimensions detected. " + "This behavior is not yet supported." + << std::endl; + std::cout + << "\tFor expected behavior, please ensure all previous " + "dimensions are less than or equal to the current dimension." + << std::endl; + } + } + + stored_intervals.push_back(std::move( + intervals)); // Just to ensure nothing gets freed prematurely. + MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future(fut)); + auto data = std::get<0>(resolution); + auto data_ptr = std::get<1>(resolution); + auto offset = std::get<2>(resolution); + auto n = std::get<3>(resolution); + + auto current_pos = intervals; + bool isInRun = false; + + std::size_t run_idx = offset; + + for (Index idx = offset; idx < offset + n; ++idx) { + bool is_match = data_ptr[idx] == descriptor.value; + if (is_match && !isInRun) { + isInRun = true; + for (auto i = run_idx; i < idx; ++i) { + _current_position_increment(current_pos, intervals); + } + run_idx = idx; + std::vector run = current_pos; + new_runs.push_back(std::move(run)); + } else if (is_match && isInRun) { + // Somewhere in the middle of a run + // do nothing TODO: Remove me + } else if (!is_match && isInRun) { + // The end of a run + // TODO(BrianMichell): Ensure we are using the correct index (see + // above) + isInRun = false; + for (auto i = run_idx; i < idx; ++i) { + _current_position_increment(current_pos, intervals); + } + run_idx = idx; + auto& last_run = new_runs.back(); + for (auto i = 0; i < current_pos.size(); ++i) { + last_run[i].exclusive_max = current_pos[i].inclusive_min + 1; + } + } else if (!is_match && !isInRun) { + // No run at all + // do nothing TODO: Remove me + } else { + // base case TODO: Remove me + } + } + } + + if (new_runs.empty()) { + std::stringstream ss; + ss << "No matches for coordinate '" << descriptor.label.label() << "'"; + return absl::NotFoundError(ss.str()); + } + + kept_runs_ = _from_intervals( + new_runs); // TODO(BrianMichell): We need to ensure we don't + // accidentally drop any pre-sliced dimensions... + return absl::OkStatus(); + } + + template + std::vector>> _from_intervals( + std::vector::Interval>>& + intervals) { + std::vector>> ret; + ret.reserve(intervals.size()); + for (auto const& run : intervals) { + std::vector> run_descs; + run_descs.reserve(run.size()); + for (auto const& interval : run) { + run_descs.emplace_back(mdio::RangeDescriptor{ + interval.label, interval.inclusive_min, interval.exclusive_max, 1}); + } + ret.push_back(std::move(run_descs)); + } + return ret; + } + + /// Advance a multidimensional odometer position by one step. + template + void _current_position_increment( + std::vector::Interval>& + position, // NOLINT (non-const) + const std::vector::Interval>& interval) const { + for (std::size_t d = position.size(); d-- > 0;) { + if (position[d].inclusive_min + 1 < interval[d].exclusive_max) { + ++position[d].inclusive_min; + return; + } + position[d].inclusive_min = interval[d].inclusive_min; + } + } + + template + void _current_position_stride( + std::vector::Interval>& + position, // NOLINT (non-const) + const std::vector::Interval>& interval, + const std::size_t num_elements) { + auto dims = position.size(); + if (position[dims - 1].exclusive_max < + position[dims - 1].inclusive_min + num_elements) { + position[dims - 1].inclusive_min = + position[dims - 1].inclusive_min + num_elements; + return; + } + for (auto i = 0; i < num_elements; ++i) { + _current_position_increment(position, interval); + } + } + + template + Result, const T*, Index, Index>> _resolve_future( + Future>& fut) { // NOLINT (non-const) + if (!fut.status().ok()) return fut.status(); + auto data = fut.value(); + const T* data_ptr = data.get_data_accessor().data(); + Index offset = data.get_flattened_offset(); + Index n_samples = data.num_samples(); + return std::make_tuple(std::move(data), data_ptr, offset, n_samples); + } +}; + +} // namespace mdio + +#endif // MDIO_COORDINATE_SELECTOR_H_ diff --git a/mdio/coordinate_selector_test.cc b/mdio/coordinate_selector_test.cc new file mode 100644 index 0000000..93eaec9 --- /dev/null +++ b/mdio/coordinate_selector_test.cc @@ -0,0 +1,341 @@ +// Copyright 2025 TGS + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mdio/coordinate_selector.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "mdio/dataset.h" +#include "mdio/dataset_factory.h" +#include "tensorstore/driver/driver.h" +#include "tensorstore/driver/registry.h" +#include "tensorstore/index_space/dim_expression.h" +#include "tensorstore/index_space/index_domain_builder.h" +#include "tensorstore/kvstore/kvstore.h" +#include "tensorstore/kvstore/operations.h" +#include "tensorstore/open.h" +#include "tensorstore/tensorstore.h" +#include "tensorstore/util/future.h" +#include "tensorstore/util/status_testutil.h" + +// clang-format off +#include // NOLINT +// clang-format on + +namespace { + +mdio::Result SetupDataset() { + std::string ds_path = "generic_with_coords.mdio"; + std::string schema_str = R"( + { + "metadata": { + "name": "generic_with_coords", + "apiVersion": "1.0.0", + "createdOn": "2025-05-13T12:00:00.000000-05:00", + "attributes": { + "generic type" : true + } + }, + "variables": [ + { + "name": "DataVariable", + "dataType": "float32", + "dimensions": ["task", "trace", "sample"], + "coordinates": ["inline", "crossline", "live_mask"], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [1, 64, 128] } + } + } + }, + { + "name": "inline", + "dataType": "int32", + "dimensions": ["task", "trace"], + "coordinates": ["live_mask"], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [1, 64] } + } + } + }, + { + "name": "crossline", + "dataType": "int32", + "dimensions": ["task", "trace"], + "coordinates": ["live_mask"], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [1, 64] } + } + } + }, + { + "name": "live_mask", + "dataType": "bool", + "dimensions": ["task", "trace"], + "coordinates": ["inline", "crossline"], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [1, 64] } + } + } + }, + { + "name": "task", + "dataType": "uint32", + "dimensions": [{"name": "task", "size": 25}], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [1] } + } + } + }, + { + "name": "trace", + "dataType": "uint32", + "dimensions": [{"name": "trace", "size": 256}], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [64] } + } + } + }, + { + "name": "sample", + "dataType": "uint32", + "dimensions": [{"name": "sample", "size": 512}], + "metadata": { + "chunkGrid": { + "name": "regular", + "configuration": { "chunkShape": [128] } + } + } + } + ] + })"; + + auto schema = ::nlohmann::json::parse(schema_str); + auto dsFut = + mdio::Dataset::from_json(schema, ds_path, mdio::constants::kCreate); + if (!dsFut.status().ok()) { + return ds_path; + } + auto ds = dsFut.value(); + + // Populate the dataset with data + MDIO_ASSIGN_OR_RETURN(auto dataVar, ds.variables.get("DataVariable")); + MDIO_ASSIGN_OR_RETURN(auto inlineVar, ds.variables.get("inline")); + MDIO_ASSIGN_OR_RETURN(auto crosslineVar, + ds.variables.get("crossline")); + MDIO_ASSIGN_OR_RETURN(auto liveMaskVar, ds.variables.get("live_mask")); + + MDIO_ASSIGN_OR_RETURN(auto varData, mdio::from_variable(dataVar)); + MDIO_ASSIGN_OR_RETURN(auto inlineData, + mdio::from_variable(inlineVar)); + MDIO_ASSIGN_OR_RETURN(auto crosslineData, + mdio::from_variable(crosslineVar)); + MDIO_ASSIGN_OR_RETURN(auto liveMaskData, + mdio::from_variable(liveMaskVar)); + + auto varDataPtr = varData.get_data_accessor().data(); + auto inlineDataPtr = inlineData.get_data_accessor().data(); + auto crosslineDataPtr = crosslineData.get_data_accessor().data(); + auto liveMaskDataPtr = liveMaskData.get_data_accessor().data(); + + auto varOffset = varData.get_flattened_offset(); + auto inlineOffset = inlineData.get_flattened_offset(); + auto crosslineOffset = crosslineData.get_flattened_offset(); + auto liveMaskOffset = liveMaskData.get_flattened_offset(); + + std::size_t coords = 0; + std::size_t var = 0; + + for (int i = 0; i < 15; ++i) { // Only 15 of the 25 tasks were "assigned" + for (int j = 0; j < 256; ++j) { + inlineDataPtr[coords + inlineOffset] = coords; + crosslineDataPtr[coords + crosslineOffset] = coords * 4; + liveMaskDataPtr[coords + liveMaskOffset] = true; + for (int k = 0; k < 512; ++k) { + varDataPtr[var + varOffset] = i * 256 * 512 + j * 512 + k; + var++; + } + coords++; + } + } + + auto varDataFut = dataVar.Write(varData); + auto inlineDataFut = inlineVar.Write(inlineData); + auto crosslineDataFut = crosslineVar.Write(crosslineData); + auto liveMaskDataFut = liveMaskVar.Write(liveMaskData); + + if (!varDataFut.status().ok()) { + return varDataFut.status(); + } + if (!inlineDataFut.status().ok()) { + return inlineDataFut.status(); + } + if (!crosslineDataFut.status().ok()) { + return crosslineDataFut.status(); + } + if (!liveMaskDataFut.status().ok()) { + return liveMaskDataFut.status(); + } + + return ds_path; +} + +TEST(Intersection, SETUP) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); +} + +TEST(Intersection, constructor) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); + auto path = pathResult.value(); + + auto dsFut = mdio::Dataset::Open(path, mdio::constants::kOpen); + ASSERT_TRUE(dsFut.status().ok()) << dsFut.status(); + auto ds = dsFut.value(); + + mdio::CoordinateSelector cs(ds); +} + +TEST(Intersection, add_selection) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); + auto path = pathResult.value(); + + auto dsFut = mdio::Dataset::Open(path, mdio::constants::kOpen); + ASSERT_TRUE(dsFut.status().ok()) << dsFut.status(); + auto ds = dsFut.value(); + + mdio::CoordinateSelector cs(ds); + mdio::ValueDescriptor liveMaskDesc = {"live_mask", true}; + auto isFut = cs.filterByCoordinate(liveMaskDesc); + ASSERT_TRUE(isFut.status().ok()) + << isFut + .status(); // When this resolves, the selection object is updated. + + // auto selections = cs.selections(); + // ASSERT_EQ(selections.size(), 2) << "Expected 2 dimensions in the selection + // map but got " << selections.size(); +} + +TEST(Intersection, range_descriptors) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); + auto path = pathResult.value(); + + auto dsFut = mdio::Dataset::Open(path, mdio::constants::kOpen); + ASSERT_TRUE(dsFut.status().ok()) << dsFut.status(); + auto ds = dsFut.value(); + + mdio::CoordinateSelector cs(ds); + mdio::ValueDescriptor liveMaskDesc = {"live_mask", true}; + auto isFut = cs.filterByCoordinate(liveMaskDesc); + ASSERT_TRUE(isFut.status().ok()) + << isFut + .status(); // When this resolves, the selection object is updated. + + // auto rangeDescriptors = cs.range_descriptors(); + + // ASSERT_EQ(rangeDescriptors.size(), 2); + // EXPECT_EQ(rangeDescriptors[0].label.label(), "task") << "Expected first + // RangeDescriptor to be for the 'task' dimension"; + // EXPECT_EQ(rangeDescriptors[1].label.label(), "trace") << "Expected second + // RangeDescriptor to be for the 'trace' dimension"; + + // EXPECT_EQ(rangeDescriptors[0].start, 0) << "Expected first RangeDescriptor + // to start at index 0"; EXPECT_EQ(rangeDescriptors[0].stop, 15) << "Expected + // first RangeDescriptor to stop at index 15"; + // EXPECT_EQ(rangeDescriptors[0].step, 1) << "Expected first RangeDescriptor + // to have a step of 1"; + + // EXPECT_EQ(rangeDescriptors[1].start, 0) << "Expected second RangeDescriptor + // to start at index 0"; EXPECT_EQ(rangeDescriptors[1].stop, 256) << "Expected + // second RangeDescriptor to stop at index 256"; + // EXPECT_EQ(rangeDescriptors[1].step, 1) << "Expected second RangeDescriptor + // to have a step of 1"; +} + +TEST(Intersection, get_inline_range) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); + auto path = pathResult.value(); + + auto dsFut = mdio::Dataset::Open(path, mdio::constants::kOpen); + ASSERT_TRUE(dsFut.status().ok()) << dsFut.status(); + auto ds = dsFut.value(); + + mdio::CoordinateSelector cs(ds); + mdio::ValueDescriptor liveMaskDesc = {"live_mask", true}; + auto isFut = cs.filterByCoordinate(liveMaskDesc); + ASSERT_TRUE(isFut.status().ok()) + << isFut + .status(); // When this resolves, the selection object is updated. + isFut = cs.filterByCoordinate(mdio::ValueDescriptor{"inline", 18}); + ASSERT_TRUE(isFut.status().ok()) << isFut.status(); + // auto rangeDescriptors = cs.range_descriptors(); + + // for (const auto& desc : rangeDescriptors) { + // std::cout << "Dimension: " << desc.label.label() << " Start: " << + // desc.start << " Stop: " << desc.stop << " Step: " << desc.step << + // std::endl; + // } +} + +TEST(Intersection, get_inline_range_dead) { + auto pathResult = SetupDataset(); + ASSERT_TRUE(pathResult.status().ok()) << pathResult.status(); + auto path = pathResult.value(); + + auto dsFut = mdio::Dataset::Open(path, mdio::constants::kOpen); + ASSERT_TRUE(dsFut.status().ok()) << dsFut.status(); + auto ds = dsFut.value(); + + mdio::CoordinateSelector cs(ds); + mdio::ValueDescriptor liveMaskDesc = {"live_mask", true}; + auto isFut = cs.filterByCoordinate(liveMaskDesc); + ASSERT_TRUE(isFut.status().ok()) + << isFut + .status(); // When this resolves, the selection object is updated. + isFut = cs.filterByCoordinate(mdio::ValueDescriptor{"inline", 5000}); + EXPECT_FALSE(isFut.status().ok()) << "Expected an error when adding a " + "selection for an invalid inline index"; + + // auto rangeDescriptors = is.range_descriptors(); + // for (const auto& desc : rangeDescriptors) { + // std::cout << "Dimension: " << desc.label.label() << " Start: " << + // desc.start << " Stop: " << desc.stop << " Step: " << desc.step << + // std::endl; + // } +} + +} // namespace diff --git a/mdio/dataset.h b/mdio/dataset.h index 815fb58..da96a59 100644 --- a/mdio/dataset.h +++ b/mdio/dataset.h @@ -16,6 +16,7 @@ #define MDIO_API_VERSION "1.0.0" +#include #include #include #include @@ -623,24 +624,35 @@ class Dataset { return absl::InvalidArgumentError("No slices provided."); } - if (slices.size() > internal::kMaxNumSlices) { - return absl::InvalidArgumentError( - absl::StrCat("Too many slices provided or implicitly generated. " - "Maximum number of slices is ", - internal::kMaxNumSlices, " but ", slices.size(), - " were provided.\n\tUse -DMAX_NUM_SLICES cmake flag to " - "increase the maximum number of slices.")); + // 1) Group descriptors by their dimension label + std::map>> groups; + for (auto& desc : slices) { + groups[desc.label.label()].push_back(desc); } - std::vector> slicesCopy = slices; - for (int i = slices.size(); i <= internal::kMaxNumSlices; i++) { - slicesCopy.emplace_back( - RangeDescriptor({internal::kInertSliceKey, 0, 1, 1})); + // 2) Walk through each dimension-group and break it into kMax-sized windows + Dataset current = *this; + for (auto& [label, descs] : groups) { + for (size_t i = 0; i < descs.size(); i += internal::kMaxNumSlices) { + size_t end = std::min(i + internal::kMaxNumSlices, descs.size()); + std::vector> window(descs.begin() + i, + descs.begin() + end); + + // 3) Pad this window up to kMax (if your impl still needs padding) + window.reserve(internal::kMaxNumSlices); + for (size_t p = window.size(); p < internal::kMaxNumSlices; ++p) { + window.emplace_back( + RangeDescriptor{internal::kInertSliceKey, 0, 1, 1}); + } + + MDIO_ASSIGN_OR_RETURN( + current, + current.call_isel_with_vector_impl( + window, std::make_index_sequence{})); + } } - // Generate the index sequence and call the implementation - return call_isel_with_vector_impl( - slicesCopy, std::make_index_sequence{}); + return current; } /** @@ -849,7 +861,8 @@ class Dataset { // The map 'label_to_indices' is now populated with all the relevant // indices. You can now proceed with further processing based on this map. - return isel(slices); + return isel( + static_cast>&>(slices)); } else if constexpr ((std::is_same_v>&>(slices)); } else { std::map> label_to_range; // pair.first = start, pair.second = stop @@ -976,7 +990,8 @@ class Dataset { "No slices could be made from the given descriptors."); } - return isel(slices); + return isel( + static_cast>&>(slices)); } return absl::OkStatus(); diff --git a/mdio/mdio.h b/mdio/mdio.h index 1b3b9f3..082bc0d 100644 --- a/mdio/mdio.h +++ b/mdio/mdio.h @@ -21,6 +21,7 @@ #ifndef MDIO_MDIO_H_ #define MDIO_MDIO_H_ +#include "mdio/coordinate_selector.h" #include "mdio/dataset.h" #endif // MDIO_MDIO_H_ diff --git a/mdio/variable.h b/mdio/variable.h index 14cad32..a8b5b2e 100644 --- a/mdio/variable.h +++ b/mdio/variable.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -575,6 +576,10 @@ Future> OpenVariable(const nlohmann::json& json_store, } // the negative of this is valid for tensorstore ... + // TODO(BrianMichell): Look into making the recheck_cached_data an open + // option. store_spec["recheck_cached_data"] = false; // This could become + // problematic if we are doing read/write operations. + auto spec = tensorstore::MakeReadyFuture<::nlohmann::json>(store_spec); // open a store: @@ -1002,6 +1007,9 @@ class Variable { } if (slices.size() > internal::kMaxNumSlices) { + // We are expecting the only entry point for this method to be fro mthe + // Dataset::isel method. That method should handle the partitioning of the + // slices. return absl::InvalidArgumentError( absl::StrCat("Too many slices provided or implicitly generated. " "Maximum number of slices is ", @@ -1041,137 +1049,101 @@ class Variable { */ template Result slice(const Descriptors&... descriptors) const { - constexpr size_t numDescriptors = sizeof...(descriptors); - - auto tuple_descs = std::make_tuple(descriptors...); + // 1) Pack descriptors + constexpr size_t N = sizeof...(Descriptors); + std::array, N> descs = {descriptors...}; + // 2) Clamp + precondition check std::vector labels; - labels.reserve(numDescriptors); - - std::vector start, stop, step; - start.reserve(numDescriptors); - stop.reserve(numDescriptors); - step.reserve(numDescriptors); - // -1 Everything is ok - // >=0 Error: Start is greater than or equal to stop - int8_t preconditionStatus = -1; - - std::apply( - [&](const auto&... desc) { - size_t idx = 0; - (( - [&] { - auto clampedDesc = sliceInRange(desc); - if (clampedDesc.start > clampedDesc.stop) { - preconditionStatus = idx; - return 1; - } - if (this->hasLabel(clampedDesc.label)) { - labels.push_back(clampedDesc.label); - start.push_back(clampedDesc.start); - stop.push_back(clampedDesc.stop); - step.push_back(clampedDesc.step); - } - return 0; // Return a dummy value to satisfy the comma - // operator - }(), - idx++), - ...); - }, - tuple_descs); - - if (preconditionStatus >= 0) { - mdio::RangeDescriptor err; - std::apply( - [&](const auto&... desc) { - size_t idx = 0; - (([&] { - if (idx == preconditionStatus) { - err = desc; - } - idx++; - }()), - ...); - }, - tuple_descs); - return absl::InvalidArgumentError( + std::vector starts, stops, steps; + labels.reserve(N); + starts.reserve(N); + stops.reserve(N); + steps.reserve(N); + + int8_t bad_idx = -1; + for (size_t i = 0; i < N; ++i) { + auto d = sliceInRange(descs[i]); + if (d.start > d.stop) { + bad_idx = static_cast(i); + break; + } + if (this->hasLabel(d.label)) { + labels.push_back(d.label); + starts.push_back(d.start); + stops.push_back(d.stop); + steps.push_back(d.step); + } + } + if (bad_idx >= 0) { + auto& err = descs[bad_idx]; + return Result{absl::InvalidArgumentError( std::string("Slice descriptor for ") + - std::string(err.label.label()) + - " had an illegal configuration.\n\tStart '" + - std::to_string(err.start) + "' greater than or equal to stop '" + - std::to_string(err.stop) + "'."); + std::string(err.label.label()) + " is invalid: start=" + + std::to_string(err.start) + " > stop=" + std::to_string(err.stop))}; } - auto labelSize = labels.size(); - if (labelSize) { + // 3) Fast path: all labels (or axis indices) are unique + if (!labels.empty()) { std::set labelSet; std::set indexSet; - for (const auto& label : labels) { - labelSet.insert(label.label()); - indexSet.insert(label.index()); + for (auto& lab : labels) { + labelSet.insert(lab.label()); + indexSet.insert(lab.index()); } - - if (labelSet.size() == labelSize || indexSet.size() == labelSize) { + if (labelSet.size() == labels.size() || + indexSet.size() == labels.size()) { MDIO_ASSIGN_OR_RETURN( auto slice_store, - store | - tensorstore::Dims(labels).HalfOpenInterval(start, stop, step)); - // return a new variable with the sliced store - return Variable{variableName, longName, metadata, slice_store, - attributes}; - } else if (labelSet.size() != labelSize) { - // Concat the sliced Variable together if there are duplicate - // labels(dimensions) - std::vector fragments; - absl::Status trueStatus = absl::OkStatus(); - auto fragmentStore = [&](auto& descriptor) -> absl::Status { - if (descriptor.label.label() == internal::kInertSliceKey) { - // pass on kInertSliceKey - return absl::OkStatus(); - } - auto sliceRes = slice(descriptor); - if (!sliceRes.status().ok()) { - trueStatus = sliceRes.status(); - return trueStatus; - } - fragments.push_back(sliceRes.value()); - return absl::OkStatus(); - }; + store | tensorstore::Dims(labels).HalfOpenInterval(starts, stops, + steps)); + return Variable{variableName, longName, metadata, + std::move(slice_store), attributes}; + } + } - auto status = (fragmentStore(descriptors).ok() && ...); - if (!status) { - return trueStatus; - } + // 4) Group by label to find any duplicates + std::map>> by_label; + for (auto& d : descs) { + if (d.label.label() != internal::kInertSliceKey) { + by_label[d.label.label()].push_back(d); + } + } - if (!fragments.empty()) { - // Concat appears to only work with void types, so we'll strip the - // type away - tensorstore::TensorStore catStore; - // Initialize catStore with the first fragment's store - catStore = fragments.front().get_store(); - - // Concatenate remaining fragments - for (size_t i = 1; i < fragments.size(); ++i) { - MDIO_ASSIGN_OR_RETURN( - catStore, - tensorstore::Concat({catStore, fragments[i].get_store()}, - /*axis=*/0)); - } - // Recast to the original type - tensorstore::TensorStore typedCatStore = - tensorstore::TensorStore(tensorstore::unchecked, - catStore); - // Return a new Variable with the concatenated store - return Variable{variableName, longName, metadata, typedCatStore, - attributes}; + // 5) Handle the first label that has >1 descriptor + for (auto& [label, vec] : by_label) { + if (vec.size() > 1) { + // 5a) Unwrap the Spec so we can ask for transform().input_labels() + MDIO_ASSIGN_OR_RETURN(auto spec, store.spec()); + auto spec_labels = spec.transform().input_labels(); + + // find the numeric axis for this label + auto it = std::find(spec_labels.begin(), spec_labels.end(), label); + if (it == spec_labels.end()) { + // no-op if the label isn't in the spec; skip it + continue; + } + int axis = static_cast(std::distance(spec_labels.begin(), it)); + + // 5b) Slice each sub‑range in isolation + std::vector> pieces; + pieces.reserve(vec.size()); + for (auto& r : vec) { + auto sub = slice(r); + if (!sub.status().ok()) return sub.status(); + pieces.push_back(sub.value().get_store()); } - return absl::InternalError("No fragments to concatenate."); - } - return absl::InvalidArgumentError( - "Unexpected error occured while trying to slice the Variable."); + // 5c) Concatenate them along the correct axis + MDIO_ASSIGN_OR_RETURN(auto cat_store, + tensorstore::Concat(pieces, axis)); + + return Variable{variableName, longName, metadata, std::move(cat_store), + attributes}; + } } - // the slice didn't change anything in the variables dimensions. + + // 6) No descriptors matched → no change return *this; } @@ -1529,6 +1501,12 @@ class Variable { */ const tensorstore::TensorStore& get_store() const { return store; } + tensorstore::TensorStore& get_mutable_store() { return store; } + void set_store( + tensorstore::TensorStore& new_store) { // NOLINT (non-const) + store = new_store; + } + // The data that should remain static, but MAY need to be updated. std::shared_ptr> attributes;