diff --git a/tensorstore/driver/zarr/driver.cc b/tensorstore/driver/zarr/driver.cc
index 69164648e..8a0943ae5 100644
--- a/tensorstore/driver/zarr/driver.cc
+++ b/tensorstore/driver/zarr/driver.cc
@@ -29,6 +29,10 @@
 #include "absl/status/status.h"
 #include "absl/strings/cord.h"
 #include <nlohmann/json.hpp>
+#include "riegeli/bytes/cord_reader.h"
+#include "riegeli/bytes/cord_writer.h"
+#include "riegeli/bytes/read_all.h"
+#include "riegeli/bytes/write.h"
 #include "tensorstore/array.h"
 #include "tensorstore/array_storage_statistics.h"
 #include "tensorstore/box.h"
@@ -137,6 +141,20 @@ absl::Status ZarrDriverSpec::ApplyOptions(SpecOptions&& options) {
 }
 
 Result<SpecRankAndFieldInfo> ZarrDriverSpec::GetSpecInfo() const {
+  // For open_as_void, we don't use normal field resolution.
+  // Note: When opening an existing array, dtype may not be known yet,
+  // so we can't determine the exact rank until metadata is loaded.
+  if (open_as_void && partial_metadata.dtype) {
+    SpecRankAndFieldInfo info;
+    info.full_rank = schema.rank();
+    info.chunked_rank = partial_metadata.rank;
+    // For void access, add one dimension for the bytes
+    info.field_rank = 1;  // The bytes dimension
+    if (info.chunked_rank != dynamic_rank) {
+      info.full_rank = info.chunked_rank + 1;
+    }
+    return info;
+  }
   return GetSpecRankAndFieldInfo(partial_metadata, selected_field, schema);
 }
 
@@ -171,6 +189,10 @@ TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(
       jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>(
                               jb::DefaultValue(
                                   [](auto* obj) { *obj = std::string{}; }))),
+      jb::Member("open_as_void",
+                 jb::Projection<&ZarrDriverSpec::open_as_void>(
+                     jb::DefaultValue([](auto* v) { *v = false; }))),
       jb::Initialize([](auto* obj) {
         TENSORSTORE_ASSIGN_OR_RETURN(auto info, obj->GetSpecInfo());
         if (info.full_rank != dynamic_rank) {
@@ -210,8 +232,19 @@ Result<SharedArray<const void>> ZarrDriverSpec::GetFillValue(
   const auto& metadata = partial_metadata;
   if (metadata.dtype && metadata.fill_value) {
     TENSORSTORE_ASSIGN_OR_RETURN(
-        size_t field_index, GetFieldIndex(*metadata.dtype, selected_field));
-    fill_value = (*metadata.fill_value)[field_index];
+        size_t field_index,
+        GetFieldIndex(*metadata.dtype, selected_field, open_as_void));
+
+    // For void access, synthesize a byte-level fill value
+    if (field_index == kVoidFieldIndex) {
+      const Index nbytes = metadata.dtype->bytes_per_outer_element;
+      auto byte_arr = AllocateArray(
+          span<const Index>({nbytes}), c_order, value_init,
+          dtype_v<::tensorstore::dtypes::byte_t>);
+      fill_value = byte_arr;
+    } else {
+      fill_value = (*metadata.fill_value)[field_index];
+    }
   }
 
   if (!fill_value.valid() || !transform.valid()) {
@@ -238,13 +271,15 @@ Result<SharedArray<const void>> ZarrDriverSpec::GetFillValue(
 
 DataCache::DataCache(Initializer&& initializer, std::string key_prefix,
                      DimensionSeparator dimension_separator,
-                     std::string metadata_key)
+                     std::string metadata_key, bool open_as_void)
     : Base(std::move(initializer),
            GetChunkGridSpecification(
-               *static_cast<const ZarrMetadata*>(initializer.metadata.get()))),
+               *static_cast<const ZarrMetadata*>(initializer.metadata.get()),
+               open_as_void)),
       key_prefix_(std::move(key_prefix)),
       dimension_separator_(dimension_separator),
-      metadata_key_(std::move(metadata_key)) {}
+      metadata_key_(std::move(metadata_key)),
+      open_as_void_(open_as_void) {}
 
 absl::Status DataCache::ValidateMetadataCompatibility(
     const void* existing_metadata_ptr, const void* new_metadata_ptr) {
@@ -268,12 +303,40 @@ void DataCache::GetChunkGridBounds(const void* metadata_ptr,
                                    MutableBoxView<> bounds,
                                    DimensionSet& implicit_lower_bounds,
                                    DimensionSet& implicit_upper_bounds) {
   const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
-  assert(bounds.rank() ==
-         static_cast<DimensionIndex>(metadata.shape.size()));
-  std::fill(bounds.origin().begin(), bounds.origin().end(), Index(0));
+  // Use >= assertion like zarr3 to allow for extra dimensions
+  assert(bounds.rank() >= static_cast<DimensionIndex>(metadata.shape.size()));
+  std::fill(bounds.origin().begin(),
+            bounds.origin().begin() + metadata.shape.size(), Index(0));
   std::copy(metadata.shape.begin(), metadata.shape.end(),
             bounds.shape().begin());
   implicit_lower_bounds = false;
-  implicit_upper_bounds = true;
+  implicit_upper_bounds = false;
+  for (DimensionIndex i = 0;
+       i < static_cast<DimensionIndex>(metadata.shape.size()); ++i) {
+    implicit_upper_bounds[i] = true;
+  }
+  // Handle extra dimensions for void access or field shapes
+  if (bounds.rank() > static_cast<DimensionIndex>(metadata.shape.size())) {
+    if (open_as_void_) {
+      // For void access, the extra dimension is the bytes_per_outer_element
+      if (static_cast<DimensionIndex>(metadata.shape.size() + 1) ==
+          bounds.rank()) {
+        bounds.shape()[metadata.rank] = metadata.dtype.bytes_per_outer_element;
+        bounds.origin()[metadata.rank] = 0;
+      }
+    } else if (metadata.dtype.fields.size() == 1) {
+      // Handle single field with field_shape (like zarr3)
+      const auto& field = metadata.dtype.fields[0];
+      if (static_cast<DimensionIndex>(metadata.shape.size() +
+                                      field.field_shape.size()) ==
+          bounds.rank()) {
+        for (size_t i = 0; i < field.field_shape.size(); ++i) {
+          bounds.shape()[metadata.shape.size() + i] = field.field_shape[i];
+          bounds.origin()[metadata.shape.size() + i] = 0;
+        }
+      }
+    }
+  }
 }
 
 Result<std::shared_ptr<const void>> DataCache::GetResizedMetadata(
@@ -294,13 +357,61 @@ Result<std::shared_ptr<const void>> DataCache::GetResizedMetadata(
 }
 
 internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
-    const ZarrMetadata& metadata) {
+    const ZarrMetadata& metadata, bool open_as_void) {
   internal::ChunkGridSpecification::ComponentList components;
-  components.reserve(metadata.dtype.fields.size());
   std::vector<DimensionIndex> chunked_to_cell_dimensions(
       metadata.chunks.size());
   std::iota(chunked_to_cell_dimensions.begin(),
             chunked_to_cell_dimensions.end(), static_cast<DimensionIndex>(0));
+
+  // Special case: void access - create single component for raw bytes
+  if (open_as_void) {
+    const Index bytes_per_element = metadata.dtype.bytes_per_outer_element;
+
+    // Create a zero-filled byte array as the fill value
+    auto base_fill_value = AllocateArray(
+        span<const Index>({bytes_per_element}), c_order, value_init,
+        dtype_v<::tensorstore::dtypes::byte_t>);
+
+    // The full chunk shape includes the extra bytes dimension
+    std::vector<Index> chunk_shape_with_bytes = metadata.chunks;
+    chunk_shape_with_bytes.push_back(bytes_per_element);
+
+    const DimensionIndex cell_rank = metadata.rank + 1;
+
+    // Broadcast fill value to target shape [unbounded, ..., bytes_per_element]
+    // like zarr3 does
+    std::vector<Index> target_shape(metadata.rank, kInfIndex);
+    target_shape.push_back(bytes_per_element);
+    auto chunk_fill_value =
+        BroadcastArray(base_fill_value, BoxView<>(target_shape)).value();
+
+    // Create valid data bounds - unbounded for chunked dimensions,
+    // explicit for bytes dimension
+    Box<> valid_data_bounds(cell_rank);
+    for (DimensionIndex i = 0; i < metadata.rank; ++i) {
+      valid_data_bounds[i] = IndexInterval::Infinite();
+    }
+    valid_data_bounds[metadata.rank] =
+        IndexInterval::UncheckedSized(0, bytes_per_element);
+
+    // Create permutation: copy existing order and add the bytes dimension
+    DimensionIndex layout_order_buffer[kMaxRank];
+    GetChunkInnerOrder(metadata.rank, metadata.order,
+                       span(layout_order_buffer, metadata.rank));
+    layout_order_buffer[metadata.rank] = metadata.rank;  // Add bytes dimension
+
+    components.emplace_back(
+        internal::AsyncWriteArray::Spec{
+            std::move(chunk_fill_value), std::move(valid_data_bounds),
+            ContiguousLayoutPermutation<>(
+                span(layout_order_buffer, cell_rank))},
+        std::move(chunk_shape_with_bytes), chunked_to_cell_dimensions);
+
+    return internal::ChunkGridSpecification{std::move(components)};
+  }
+
+  // Normal field-based access
+  components.reserve(metadata.dtype.fields.size());
   for (size_t field_i = 0; field_i < metadata.dtype.fields.size(); ++field_i) {
     const auto& field = metadata.dtype.fields[field_i];
     const auto& field_layout = metadata.chunk_layout.fields[field_i];
@@ -335,12 +446,70 @@ internal::ChunkGridSpecification DataCache::GetChunkGridSpecification(
 
 Result<absl::InlinedVector<SharedArray<const void>, 1>> DataCache::DecodeChunk(
     span<const Index> chunk_indices, absl::Cord data) {
+  if (open_as_void_) {
+    // For void access, return raw bytes as a single component
+    const auto& md = metadata();
+
+    // Decompress the data first (if compressed)
+    absl::Cord decompressed = std::move(data);
+    if (md.compressor) {
+      riegeli::CordReader base_reader(std::move(decompressed));
+      auto compressed_reader = md.compressor->GetReader(
+          base_reader, md.dtype.bytes_per_outer_element);
+      absl::Cord uncompressed;
+      TENSORSTORE_RETURN_IF_ERROR(
+          riegeli::ReadAll(std::move(compressed_reader), uncompressed));
+      if (!base_reader.VerifyEndAndClose()) return base_reader.status();
+      decompressed = std::move(uncompressed);
+    }
+
+    // Build the shape: chunk_shape + bytes_per_element
+    std::vector<Index> shape = md.chunks;
+    shape.push_back(md.dtype.bytes_per_outer_element);
+
+    // Create a byte array from the decompressed data
+    auto flat_data = decompressed.Flatten();
+    auto byte_array = AllocateArray(shape, c_order, default_init,
+                                    dtype_v<::tensorstore::dtypes::byte_t>);
+    std::memcpy(byte_array.data(), flat_data.data(),
+                std::min(static_cast<size_t>(byte_array.num_elements()),
+                         flat_data.size()));
+
+    absl::InlinedVector<SharedArray<const void>, 1> result;
+    result.push_back(std::move(byte_array));
+    return result;
+  }
   return internal_zarr::DecodeChunk(metadata(), std::move(data));
 }
 
 Result<absl::Cord> DataCache::EncodeChunk(
     span<const Index> chunk_indices,
     span<const SharedArray<const void>> component_arrays) {
+  if (open_as_void_) {
+    // For void access, encode raw bytes directly
+    const auto& md = metadata();
+    if (component_arrays.size() != 1) {
+      return absl::InvalidArgumentError(
+          "Expected exactly one component array for void access");
+    }
+    const auto& byte_array = component_arrays[0];
+    absl::Cord uncompressed(
+        std::string_view(static_cast<const char*>(byte_array.data()),
+                         byte_array.num_elements()));
+
+    // Compress if needed
+    if (md.compressor) {
+      absl::Cord encoded;
+      riegeli::CordWriter base_writer(&encoded);
+      auto writer = md.compressor->GetWriter(
+          base_writer, md.dtype.bytes_per_outer_element);
+      TENSORSTORE_RETURN_IF_ERROR(
+          riegeli::Write(std::move(uncompressed), std::move(writer)));
+      if (!base_writer.Close()) return base_writer.status();
+      return encoded;
+    }
+    return uncompressed;
+  }
   return internal_zarr::EncodeChunk(metadata(), component_arrays);
 }
 
@@ -356,6 +525,7 @@ absl::Status DataCache::GetBoundSpecData(
   const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
   spec.selected_field = EncodeSelectedField(component_index, metadata.dtype);
   spec.metadata_key = metadata_key_;
+  spec.open_as_void = open_as_void_;
   auto& pm = spec.partial_metadata;
   pm.rank = metadata.rank;
   pm.zarr_format = metadata.zarr_format;
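
Note (review aid, not part of the patch): the void component above simply round-trips the packed zarr struct layout. A minimal standalone sketch of that layout, assuming the structured dtype [["x", "|u1"], ["y", "<i2"]] where bytes_per_outer_element is 3:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // One outer element {x = 0x01, y = 0x0203}, packed at its field offsets.
      uint8_t x = 0x01;
      int16_t y = 0x0203;
      unsigned char bytes[3];
      bytes[0] = x;                   // field "x" at byte_offset 0
      std::memcpy(bytes + 1, &y, 2);  // field "y" at byte_offset 1
      // Prints "01 03 02" on a little-endian host: exactly the three entries
      // of the trailing "bytes" dimension exposed by open_as_void.
      std::printf("%02x %02x %02x\n", bytes[0], bytes[1], bytes[2]);
      return 0;
    }

DecodeChunk/EncodeChunk pass these packed bytes through unchanged (apart from compression), which is what makes the byte view lossless.
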
with open_as_void specified"); + } TENSORSTORE_ASSIGN_OR_RETURN(auto base_url, store.ToUrl()); return tensorstore::StrCat(base_url, "|", kUrlScheme, ":"); } @@ -483,7 +657,8 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { TENSORSTORE_ASSIGN_OR_RETURN( auto metadata, internal_zarr::GetNewMetadata(spec().partial_metadata, - spec().selected_field, spec().schema), + spec().selected_field, spec().schema, + spec().open_as_void), tensorstore::MaybeAnnotateStatus( _, "Cannot create using specified \"metadata\" and schema")); return metadata; @@ -496,7 +671,8 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { internal::EncodeCacheKey( &result, spec.store.path, GetDimensionSeparator(spec.partial_metadata, zarr_metadata), - zarr_metadata, spec.metadata_key); + zarr_metadata, spec.metadata_key, + spec.open_as_void ? "void" : "normal"); return result; } @@ -507,7 +683,7 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { return std::make_unique( std::move(initializer), spec().store.path, GetDimensionSeparator(spec().partial_metadata, metadata), - spec().metadata_key); + spec().metadata_key, spec().open_as_void); } Result GetComponentIndex(const void* metadata_ptr, @@ -516,7 +692,14 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase { TENSORSTORE_RETURN_IF_ERROR( ValidateMetadata(metadata, spec().partial_metadata)); TENSORSTORE_ASSIGN_OR_RETURN( - auto field_index, GetFieldIndex(metadata.dtype, spec().selected_field)); + auto field_index, + GetFieldIndex(metadata.dtype, spec().selected_field, + spec().open_as_void)); + // For void access, map to component index 0 since we create a special + // component for raw byte access + if (field_index == kVoidFieldIndex) { + field_index = 0; + } TENSORSTORE_RETURN_IF_ERROR( ValidateMetadataSchema(metadata, field_index, spec().schema)); return field_index; diff --git a/tensorstore/driver/zarr/driver_impl.h b/tensorstore/driver/zarr/driver_impl.h index df3c3930f..c2933dd90 100644 --- a/tensorstore/driver/zarr/driver_impl.h +++ b/tensorstore/driver/zarr/driver_impl.h @@ -63,10 +63,11 @@ class ZarrDriverSpec ZarrPartialMetadata partial_metadata; SelectedField selected_field; std::string metadata_key; + bool open_as_void = false; constexpr static auto ApplyMembers = [](auto& x, auto f) { return f(internal::BaseCast(x), x.partial_metadata, - x.selected_field, x.metadata_key); + x.selected_field, x.metadata_key, x.open_as_void); }; absl::Status ApplyOptions(SpecOptions&& options) override; @@ -98,7 +99,7 @@ class DataCache : public internal_kvs_backed_chunk_driver::DataCache { public: explicit DataCache(Initializer&& initializer, std::string key_prefix, DimensionSeparator dimension_separator, - std::string metadata_key); + std::string metadata_key, bool open_as_void = false); const ZarrMetadata& metadata() { return *static_cast(initial_metadata().get()); @@ -117,7 +118,7 @@ class DataCache : public internal_kvs_backed_chunk_driver::DataCache { /// Returns the ChunkCache grid to use for the given metadata. 
   static internal::ChunkGridSpecification GetChunkGridSpecification(
-      const ZarrMetadata& metadata);
+      const ZarrMetadata& metadata, bool open_as_void = false);
 
   Result<absl::InlinedVector<SharedArray<const void>, 1>> DecodeChunk(
       span<const Index> chunk_indices, absl::Cord data) override;
@@ -140,6 +141,7 @@ class DataCache : public internal_kvs_backed_chunk_driver::DataCache {
   std::string key_prefix_;
   DimensionSeparator dimension_separator_;
   std::string metadata_key_;
+  bool open_as_void_;
 };
 
 class ZarrDriver;
diff --git a/tensorstore/driver/zarr/driver_test.cc b/tensorstore/driver/zarr/driver_test.cc
index 92c5be48a..a5014987d 100644
--- a/tensorstore/driver/zarr/driver_test.cc
+++ b/tensorstore/driver/zarr/driver_test.cc
@@ -3499,4 +3499,326 @@ TEST(DriverTest, UrlSchemeRoundtrip) {
        {"kvstore", {{"driver", "memory"}, {"path", "abc.zarr/def/"}}}});
 }
 
+// Tests for open_as_void functionality
+
+TEST(ZarrDriverTest, OpenAsVoidSimpleType) {
+  // Test open_as_void with a simple data type (int16)
+  auto context = Context::Default();
+
+  // First create a normal array
+  ::nlohmann::json create_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"metadata",
+       {
+           {"compressor", nullptr},
+           {"dtype", "<i2"},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto store, tensorstore::Open(create_spec, context,
+                                    tensorstore::OpenMode::create,
+                                    tensorstore::ReadWriteMode::read_write)
+                      .result());
+
+  // Write some data
+  auto data = tensorstore::MakeArray<int16_t>({{1, 2}, {3, 4}});
+  TENSORSTORE_EXPECT_OK(
+      tensorstore::Write(data, store | tensorstore::Dims(0, 1).SizedInterval(
+                                           {0, 0}, {2, 2}))
+          .result());
+
+  // Now open with open_as_void=true
+  ::nlohmann::json void_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto void_store,
+      tensorstore::Open(void_spec, context, tensorstore::OpenMode::open,
+                        tensorstore::ReadWriteMode::read)
+          .result());
+
+  // The void store should have rank = original_rank + 1 (for bytes dimension)
+  EXPECT_EQ(3, void_store.rank());
+
+  // The last dimension should be the size of the data type (2 bytes for int16)
+  EXPECT_EQ(2, void_store.domain().shape()[2]);
+
+  // The data type should be byte
+  EXPECT_EQ(tensorstore::dtype_v<tensorstore::dtypes::byte_t>,
+            void_store.dtype());
+}
+
+TEST(ZarrDriverTest, OpenAsVoidStructuredType) {
+  // Test open_as_void with a structured data type
+  auto context = Context::Default();
+
+  // Create an array with a structured dtype
+  ::nlohmann::json create_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"field", "y"},
+      {"metadata",
+       {
+           {"compressor", nullptr},
+           {"dtype", ::nlohmann::json::array_t{{"x", "|u1"}, {"y", "<i2"}}},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto store, tensorstore::Open(create_spec, context,
+                                    tensorstore::OpenMode::create,
+                                    tensorstore::ReadWriteMode::read_write)
+                      .result());
+
+  // Write some data to field "y"
+  auto data = tensorstore::MakeArray<int16_t>({{100, 200}, {300, 400}});
+  TENSORSTORE_EXPECT_OK(
+      tensorstore::Write(data, store | tensorstore::Dims(0, 1).SizedInterval(
+                                           {0, 0}, {2, 2}))
+          .result());
+
+  // Now open with open_as_void=true - this should give raw access to the
+  // entire struct
+  ::nlohmann::json void_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto void_store,
+      tensorstore::Open(void_spec, context, tensorstore::OpenMode::open,
+                        tensorstore::ReadWriteMode::read)
+          .result());
+
+  // The void store should have rank = original_rank + 1 (for bytes dimension)
+  EXPECT_EQ(3, void_store.rank());
+
+  // The last dimension should be 3 bytes (1 byte for u1 + 2 bytes for i2)
+  EXPECT_EQ(3, void_store.domain().shape()[2]);
+
+  // The data type should be byte
+  EXPECT_EQ(tensorstore::dtype_v<tensorstore::dtypes::byte_t>,
+            void_store.dtype());
+}
+
+TEST(ZarrDriverTest, OpenAsVoidWithCompression) {
+  // Test open_as_void with compression enabled
+  auto context = Context::Default();
+
+  // Create an array with blosc compression
+  ::nlohmann::json create_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"metadata",
+       {
+           {"compressor", {{"id", "blosc"}}},
+           {"dtype", "<i4"},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto store, tensorstore::Open(create_spec, context,
+                                    tensorstore::OpenMode::create,
+                                    tensorstore::ReadWriteMode::read_write)
+                      .result());
+
+  // Write some data
+  auto data = tensorstore::MakeArray<int32_t>({{0x01020304, 0x05060708},
+                                               {0x090a0b0c, 0x0d0e0f10}});
+  TENSORSTORE_EXPECT_OK(
+      tensorstore::Write(data, store | tensorstore::Dims(0, 1).SizedInterval(
+                                           {0, 0}, {2, 2}))
+          .result());
+
+  // Now open with open_as_void=true
+  ::nlohmann::json void_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto void_store,
+      tensorstore::Open(void_spec, context, tensorstore::OpenMode::open,
+                        tensorstore::ReadWriteMode::read)
+          .result());
+
+  // The void store should have rank = original_rank + 1 (for bytes dimension)
+  EXPECT_EQ(3, void_store.rank());
+
+  // The last dimension should be 4 bytes for int32
+  EXPECT_EQ(4, void_store.domain().shape()[2]);
+
+  // The data type should be byte
+  EXPECT_EQ(tensorstore::dtype_v<tensorstore::dtypes::byte_t>,
+            void_store.dtype());
+
+  // Read the raw bytes and verify decompression works
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto read_result,
+      tensorstore::Read(void_store | tensorstore::Dims(0, 1).SizedInterval(
+                                         {0, 0}, {2, 2}))
+          .result());
+  EXPECT_EQ(read_result.shape()[0], 2);
+  EXPECT_EQ(read_result.shape()[1], 2);
+  EXPECT_EQ(read_result.shape()[2], 4);
+}
+
+TEST(ZarrDriverTest, OpenAsVoidSpecRoundtrip) {
+  // Test that open_as_void is properly preserved in spec round-trips
+  ::nlohmann::json json_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+      {"metadata",
+       {
+           {"compressor", nullptr},
+           {"dtype", "<i2"},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  auto context = Context::Default();
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto void_store,
+      tensorstore::Open(json_spec, context, tensorstore::OpenMode::create)
+          .result());
+
+  // The bound spec should still carry open_as_void.
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto spec, void_store.spec());
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto spec_json, spec.ToJson());
+  EXPECT_EQ(true, spec_json["open_as_void"]);
+
+  // The opened store exposes raw bytes
+  EXPECT_EQ(3, void_store.rank());
+  EXPECT_EQ(tensorstore::dtype_v<tensorstore::dtypes::byte_t>,
+            void_store.dtype());
+}
+
+TEST(ZarrDriverTest, OpenAsVoidUrlNotSupported) {
+  // Test that open_as_void is not supported with URL syntax
+  ::nlohmann::json json_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+      {"metadata",
+       {
+           {"dtype", "<i2"},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto spec,
+                                   tensorstore::Spec::FromJson(json_spec));
+  EXPECT_THAT(spec.ToUrl(),
+              tensorstore::MatchesStatus(
+                  absl::StatusCode::kInvalidArgument,
+                  ".*zarr2 URL syntax not supported with open_as_void.*"));
+}
+
+TEST(ZarrDriverTest, OpenAsVoidReadWriteRoundTrip) {
+  // Write through a normal view, then read back raw bytes via open_as_void
+  auto context = Context::Default();
+
+  ::nlohmann::json create_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"metadata",
+       {
+           {"compressor", nullptr},
+           {"dtype", "<u2"},
+           {"shape", {2, 2}},
+           {"chunks", {2, 2}},
+       }},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto store, tensorstore::Open(create_spec, context,
+                                    tensorstore::OpenMode::create,
+                                    tensorstore::ReadWriteMode::read_write)
+                      .result());
+
+  // Write some data
+  auto data = tensorstore::MakeArray<uint16_t>({{0x0102, 0x0304},
+                                                {0x0506, 0x0708}});
+  TENSORSTORE_EXPECT_OK(tensorstore::Write(data, store).result());
+
+  // Open as void and read
+  ::nlohmann::json void_spec{
+      {"driver", "zarr"},
+      {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
+      {"open_as_void", true},
+  };
+
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+      auto void_store,
+      tensorstore::Open(void_spec, context, tensorstore::OpenMode::open,
+                        tensorstore::ReadWriteMode::read_write)
+          .result());
+
+  // Read the raw bytes
+  TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto bytes_read,
+                                   tensorstore::Read(void_store).result());
+
+  // Verify shape: [2, 2, 2] where last dim is 2 bytes per uint16
+  EXPECT_EQ(bytes_read.shape()[0], 2);
+  EXPECT_EQ(bytes_read.shape()[1], 2);
+  EXPECT_EQ(bytes_read.shape()[2], 2);
+
+  // Verify the raw bytes (little endian)
+  auto bytes_ptr = static_cast<const unsigned char*>(bytes_read.data());
+  // First element: 0x0102 -> bytes 0x02, 0x01 (little endian)
+  EXPECT_EQ(bytes_ptr[0], 0x02);
+  EXPECT_EQ(bytes_ptr[1], 0x01);
+}
+
 }  // namespace
diff --git a/tensorstore/driver/zarr/schema.yml b/tensorstore/driver/zarr/schema.yml
index 45711648c..a90fb7e3a 100644
--- a/tensorstore/driver/zarr/schema.yml
+++ b/tensorstore/driver/zarr/schema.yml
@@ -17,6 +17,14 @@ allOf:
           Must be specified if the `.metadata.dtype` specified in the array
           metadata has more than one field.
         default: null
+      open_as_void:
+        type: boolean
+        default: false
+        title: Raw byte access mode.
+        description: |
+          When true, opens the array as raw bytes instead of interpreting it
+          as structured data. The resulting array will have an additional
+          dimension representing the byte layout of each element.
       metadata:
         title: Zarr array metadata.
         description: |
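
Note (review aid, not part of the patch): the tests above exercise the intended usage. For readers skimming, the core pattern is the following minimal sketch, assuming an existing zarr v2 array at `prefix/` in a memory kvstore:

    #include "tensorstore/context.h"
    #include "tensorstore/open.h"

    // Open an existing array as raw bytes.  The resulting store has
    // rank = stored rank + 1; the trailing dimension indexes the bytes of
    // each (possibly structured) element, and the dtype is byte.
    auto void_store =
        tensorstore::Open(
            {{"driver", "zarr"},
             {"kvstore", {{"driver", "memory"}, {"path", "prefix/"}}},
             {"open_as_void", true}},
            tensorstore::Context::Default(), tensorstore::OpenMode::open,
            tensorstore::ReadWriteMode::read)
            .result();
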
diff --git a/tensorstore/driver/zarr/spec.cc b/tensorstore/driver/zarr/spec.cc
index 34a2825f9..4857d045b 100644
--- a/tensorstore/driver/zarr/spec.cc
+++ b/tensorstore/driver/zarr/spec.cc
@@ -151,7 +151,8 @@ absl::Status ValidateMetadata(const ZarrMetadata& metadata,
 
 Result<ZarrMetadataPtr> GetNewMetadata(
     const ZarrPartialMetadata& partial_metadata,
-    const SelectedField& selected_field, const Schema& schema) {
+    const SelectedField& selected_field, const Schema& schema,
+    bool open_as_void) {
   ZarrMetadataPtr metadata = std::make_shared<ZarrMetadata>();
   metadata->zarr_format = partial_metadata.zarr_format.value_or(2);
   metadata->dimension_separator = partial_metadata.dimension_separator.value_or(
@@ -172,7 +173,12 @@ Result<ZarrMetadataPtr> GetNewMetadata(
     // multi-field zarr dtype is desired, it must be specified explicitly.
     TENSORSTORE_ASSIGN_OR_RETURN(
         selected_field_index,
-        GetFieldIndex(*partial_metadata.dtype, selected_field));
+        GetFieldIndex(*partial_metadata.dtype, selected_field, open_as_void));
+    // For void access, use field 0 for metadata creation since we use all
+    // fields as raw bytes
+    if (selected_field_index == kVoidFieldIndex) {
+      selected_field_index = 0;
+    }
     metadata->dtype = *partial_metadata.dtype;
   } else {
     if (!selected_field.empty()) {
@@ -527,7 +533,17 @@ std::string GetFieldNames(const ZarrDType& dtype) {
 }  // namespace
 
 Result<size_t> GetFieldIndex(const ZarrDType& dtype,
-                             const SelectedField& selected_field) {
+                             const SelectedField& selected_field,
+                             bool open_as_void) {
+  // Special case: open_as_void requests raw byte access (works for any dtype)
+  if (open_as_void) {
+    if (dtype.fields.empty()) {
+      return absl::FailedPreconditionError(
+          "Requested void access but dtype has no fields");
+    }
+    return kVoidFieldIndex;
+  }
+
   if (selected_field.empty()) {
     if (dtype.fields.size() != 1) {
       return absl::FailedPreconditionError(tensorstore::StrCat(
diff --git a/tensorstore/driver/zarr/spec.h b/tensorstore/driver/zarr/spec.h
index 0ef3ab9d3..597fc32f0 100644
--- a/tensorstore/driver/zarr/spec.h
+++ b/tensorstore/driver/zarr/spec.h
@@ -70,9 +70,11 @@ using SelectedField = std::string;
 /// \param partial_metadata Constraints in the form of partial zarr metadata.
 /// \param selected_field The field to which `schema` applies.
 /// \param schema Schema constraints for the `selected_field`.
+/// \param open_as_void If true, opens the array as raw bytes.
 Result<ZarrMetadataPtr> GetNewMetadata(
     const ZarrPartialMetadata& partial_metadata,
-    const SelectedField& selected_field, const Schema& schema);
+    const SelectedField& selected_field, const Schema& schema,
+    bool open_as_void = false);
 
 struct SpecRankAndFieldInfo {
   /// Full rank of the TensorStore, if known.  Equal to the chunked rank plus
@@ -134,11 +136,16 @@ Result<SelectedField> ParseSelectedField(const ::nlohmann::json& value);
 /// \param dtype The parsed zarr "dtype" specification.
 /// \param selected_field The label of the field, or an empty string to
 ///     indicate that the zarr array must have only a single field.
-/// \returns The field index.
+/// \param open_as_void If true, returns `kVoidFieldIndex` for raw byte access.
+/// \returns The field index, or `kVoidFieldIndex` if `open_as_void` is true.
 /// \error `absl::StatusCode::kFailedPrecondition` if `selected_field` is not
 ///     valid.
 Result<size_t> GetFieldIndex(const ZarrDType& dtype,
                              const SelectedField& selected_field,
                             bool open_as_void = false);
+
+/// Special field index indicating void (raw byte) access.
+constexpr size_t kVoidFieldIndex = size_t(-1);
 
 /// Encodes a field index as a `SelectedField` JSON specification.
 ///
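
Note (review aid, not part of the patch): a condensed sketch of the resolution rules this sentinel introduces, written as a hypothetical test fragment and assuming `dtype` was parsed from [["x", "|u1"], ["y", "<i2"]]:

    using ::tensorstore::MatchesStatus;
    using ::tensorstore::internal_zarr::GetFieldIndex;
    using ::tensorstore::internal_zarr::kVoidFieldIndex;

    // Named field: resolves to its index, as before.
    EXPECT_EQ(1u, GetFieldIndex(dtype, "y").value());
    // No field selected on a multi-field dtype: still an error.
    EXPECT_THAT(GetFieldIndex(dtype, ""),
                MatchesStatus(absl::StatusCode::kFailedPrecondition));
    // open_as_void bypasses field selection entirely; callers later map the
    // sentinel to component index 0.
    EXPECT_EQ(kVoidFieldIndex,
              GetFieldIndex(dtype, "", /*open_as_void=*/true).value());
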
diff --git a/tensorstore/driver/zarr3/BUILD b/tensorstore/driver/zarr3/BUILD
index 72d2f2f71..b9e442bdf 100644
--- a/tensorstore/driver/zarr3/BUILD
+++ b/tensorstore/driver/zarr3/BUILD
@@ -94,8 +94,8 @@ tensorstore_cc_library(
 
 tensorstore_cc_library(
     name = "metadata",
-    srcs = ["metadata.cc"],
-    hdrs = ["metadata.h"],
+    srcs = ["metadata.cc", "dtype.cc"],
+    hdrs = ["metadata.h", "dtype.h"],
     deps = [
         ":default_nan",
         ":name_configuration_json_binder",
@@ -145,6 +145,23 @@ tensorstore_cc_library(
     ],
 )
 
+tensorstore_cc_test(
+    name = "dtype_test",
+    size = "small",
+    srcs = ["dtype_test.cc"],
+    deps = [
+        ":metadata",
+        "//tensorstore:data_type",
+        "//tensorstore:index",
+        "//tensorstore/internal/testing:json_gtest",
+        "//tensorstore/util:status_testutil",
+        "//tensorstore/util:str_cat",
+        "@abseil-cpp//absl/status",
+        "@googletest//:gtest_main",
+        "@nlohmann_json//:json",
+    ],
+)
+
 tensorstore_cc_test(
     name = "driver_test",
     size = "small",
diff --git a/tensorstore/driver/zarr3/chunk_cache.cc b/tensorstore/driver/zarr3/chunk_cache.cc
index ee1cba9c1..f14efd607 100644
--- a/tensorstore/driver/zarr3/chunk_cache.cc
+++ b/tensorstore/driver/zarr3/chunk_cache.cc
@@ -18,6 +18,8 @@
 #include <stddef.h>
 #include <stdint.h>
+#include <cstring>
+#include <vector>
 #include <memory>
 #include <string>
 #include <utility>
@@ -73,15 +75,19 @@ ZarrChunkCache::~ZarrChunkCache() = default;
 
 ZarrLeafChunkCache::ZarrLeafChunkCache(
     kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
-    internal::CachePool::WeakPtr /*data_cache_pool*/)
-    : Base(std::move(store)), codec_state_(std::move(codec_state)) {}
+    ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/,
+    bool open_as_void)
+    : Base(std::move(store)),
+      codec_state_(std::move(codec_state)),
+      dtype_(std::move(dtype)),
+      open_as_void_(open_as_void) {}
 
 void ZarrLeafChunkCache::Read(ZarrChunkCache::ReadRequest request,
                               AnyFlowReceiver<absl::Status, internal::ReadChunk,
                                               IndexTransform<>>&& receiver) {
   return internal::ChunkCache::Read(
       {static_cast<internal::ChunkCache::ReadRequest&&>(request),
-       /*component_index=*/0, request.staleness_bound,
+       request.component_index, request.staleness_bound,
        request.fill_missing_data_reads},
       std::move(receiver));
 }
@@ -92,7 +98,7 @@ void ZarrLeafChunkCache::Write(
     ZarrChunkCache::WriteRequest request,
    AnyFlowReceiver<absl::Status, internal::WriteChunk, IndexTransform<>>&&
         receiver) {
   return internal::ChunkCache::Write(
       {static_cast<internal::ChunkCache::WriteRequest&&>(request),
-       /*component_index=*/0, request.store_data_equal_to_fill_value},
+       request.component_index, request.store_data_equal_to_fill_value},
       std::move(receiver));
 }
 
@@ -149,12 +155,59 @@ std::string ZarrLeafChunkCache::GetChunkStorageKey(
 
 Result<absl::InlinedVector<SharedArray<const void>, 1>>
 ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
                                 absl::Cord data) {
+  const size_t num_fields = dtype_.fields.size();
+  absl::InlinedVector<SharedArray<const void>, 1> field_arrays(num_fields);
+
+  // Special case: void access - return raw bytes directly
+  if (open_as_void_) {
+    TENSORSTORE_ASSIGN_OR_RETURN(
+        field_arrays[0],
+        codec_state_->DecodeArray(grid().components[0].shape(),
+                                  std::move(data)));
+    return field_arrays;
+  }
+
+  // For single non-structured field, decode directly
+  if (num_fields == 1 && dtype_.fields[0].outer_shape.empty()) {
+    TENSORSTORE_ASSIGN_OR_RETURN(
+        field_arrays[0],
+        codec_state_->DecodeArray(grid().components[0].shape(),
+                                  std::move(data)));
+    return field_arrays;
+  }
+
+  // For structured types, decode byte array then extract fields.
+  // Build decode shape: [chunk_dims..., bytes_per_outer_element]
+  const auto& chunk_shape = grid().chunk_shape;
+  std::vector<Index> decode_shape(chunk_shape.begin(), chunk_shape.end());
+  decode_shape.push_back(dtype_.bytes_per_outer_element);
+
   TENSORSTORE_ASSIGN_OR_RETURN(
-      auto array,
-      codec_state_->DecodeArray(grid().components[0].shape(), std::move(data)));
-  absl::InlinedVector<SharedArray<const void>, 1> components;
-  components.push_back(std::move(array));
-  return components;
+      auto byte_array,
+      codec_state_->DecodeArray(decode_shape, std::move(data)));
+
+  // Extract each field from the byte array
+  const Index num_elements =
+      byte_array.num_elements() / dtype_.bytes_per_outer_element;
+  const auto* src_bytes = static_cast<const unsigned char*>(byte_array.data());
+
+  for (size_t field_i = 0; field_i < num_fields; ++field_i) {
+    const auto& field = dtype_.fields[field_i];
+    // Use the component's shape (from the grid) for the result array
+    const auto& component_shape = grid().components[field_i].shape();
+    auto result_array =
+        AllocateArray(component_shape, c_order, default_init, field.dtype);
+    auto* dst = static_cast<unsigned char*>(result_array.data());
+    const Index field_size = field.dtype->size;
+
+    // Copy field data from each struct element
+    for (Index i = 0; i < num_elements; ++i) {
+      std::memcpy(dst + i * field_size,
+                  src_bytes + i * dtype_.bytes_per_outer_element +
+                      field.byte_offset,
+                  field_size);
+    }
+    field_arrays[field_i] = std::move(result_array);
+  }
+
+  return field_arrays;
 }
 
 Result<absl::Cord> ZarrLeafChunkCache::EncodeChunk(
@@ -170,10 +223,13 @@ kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() {
 
 ZarrShardedChunkCache::ZarrShardedChunkCache(
     kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
-    internal::CachePool::WeakPtr data_cache_pool)
+    ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
+    bool open_as_void)
     : base_kvstore_(std::move(store)),
       codec_state_(std::move(codec_state)),
-      data_cache_pool_(std::move(data_cache_pool)) {}
+      dtype_(std::move(dtype)),
+      data_cache_pool_(std::move(data_cache_pool)),
+      open_as_void_(open_as_void) {}
 
 Result<IndexTransform<>> TranslateCellToSourceTransformForShard(
     IndexTransform<> transform, span<const Index> grid_cell_indices,
@@ -326,6 +382,7 @@ void ZarrShardedChunkCache::Read(
       *this, std::move(request.transform), std::move(receiver),
       [transaction = std::move(request.transaction),
        batch = std::move(request.batch),
+       component_index = request.component_index,
        staleness_bound = request.staleness_bound,
        fill_missing_data_reads = request.fill_missing_data_reads](auto entry) {
         Batch shard_batch = batch;
@@ -339,8 +396,7 @@ void ZarrShardedChunkCache::Read(
                        IndexTransform<>>&& receiver) {
           entry->sub_chunk_cache.get()->Read(
               {{transaction, std::move(transform), shard_batch},
-               staleness_bound,
-               fill_missing_data_reads},
+               component_index, staleness_bound, fill_missing_data_reads},
               std::move(receiver));
         };
       });
@@ -354,6 +410,7 @@ void ZarrShardedChunkCache::Write(
           &ZarrArrayToArrayCodec::PreparedState::Write>(
       *this, std::move(request.transform), std::move(receiver),
       [transaction = std::move(request.transaction),
+       component_index = request.component_index,
        store_data_equal_to_fill_value =
            request.store_data_equal_to_fill_value](auto entry) {
         internal::OpenTransactionPtr shard_transaction = transaction;
@@ -366,7 +423,7 @@ void ZarrShardedChunkCache::Write(
                        AnyFlowReceiver<absl::Status, internal::WriteChunk,
                                        IndexTransform<>>&& receiver) {
           entry->sub_chunk_cache.get()->Write(
-              {{shard_transaction, std::move(transform)},
+              {{shard_transaction, std::move(transform)}, component_index,
                store_data_equal_to_fill_value},
               std::move(receiver));
         };
@@ -481,7 +538,7 @@ void ZarrShardedChunkCache::Entry::DoInitialize() {
               *sharding_state.sub_chunk_codec_chain,
               std::move(sharding_kvstore), cache.executor(),
               ZarrShardingCodec::PreparedState::Ptr(&sharding_state),
-              cache.data_cache_pool_);
+              cache.dtype_, cache.data_cache_pool_, cache.open_as_void_);
           zarr_chunk_cache = new_cache.release();
           return std::unique_ptr<internal::Cache>(&zarr_chunk_cache->cache());
         })
diff --git a/tensorstore/driver/zarr3/chunk_cache.h b/tensorstore/driver/zarr3/chunk_cache.h
index dd40e43ac..a39eb1dc8 100644
--- a/tensorstore/driver/zarr3/chunk_cache.h
+++ b/tensorstore/driver/zarr3/chunk_cache.h
@@ -31,6 +31,7 @@
 #include "tensorstore/driver/read_request.h"
 #include "tensorstore/driver/write_request.h"
 #include "tensorstore/driver/zarr3/codec/codec.h"
+#include "tensorstore/driver/zarr3/dtype.h"
 #include "tensorstore/index.h"
 #include "tensorstore/index_space/index_transform.h"
 #include "tensorstore/internal/cache/cache.h"
@@ -72,6 +73,7 @@ class ZarrChunkCache {
   virtual const Executor& executor() const = 0;
 
   struct ReadRequest : internal::DriverReadRequest {
+    size_t component_index = 0;
     absl::Time staleness_bound;
     bool fill_missing_data_reads;
   };
@@ -81,6 +83,7 @@ class ZarrChunkCache {
                     IndexTransform<>>&& receiver) = 0;
 
   struct WriteRequest : internal::DriverWriteRequest {
+    size_t component_index = 0;
     bool store_data_equal_to_fill_value;
   };
@@ -154,7 +157,9 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
 
   explicit ZarrLeafChunkCache(kvstore::DriverPtr store,
                               ZarrCodecChain::PreparedState::Ptr codec_state,
-                              internal::CachePool::WeakPtr data_cache_pool);
+                              ZarrDType dtype,
+                              internal::CachePool::WeakPtr data_cache_pool,
+                              bool open_as_void = false);
 
   void Read(ZarrChunkCache::ReadRequest request,
             AnyFlowReceiver<absl::Status, internal::ReadChunk,
@@ -239,6 +248,8 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
 
   kvstore::DriverPtr base_kvstore_;
   ZarrCodecChain::PreparedState::Ptr codec_state_;
+  ZarrDType dtype_;
+  bool open_as_void_;
 
   // Data cache pool, if it differs from `this->pool()` (which is equal to the
   // metadata cache pool).
@@ -253,11 +264,13 @@ class ZarrShardSubChunkCache : public ChunkCacheImpl {
   explicit ZarrShardSubChunkCache(
       kvstore::DriverPtr store, Executor executor,
       ZarrShardingCodec::PreparedState::Ptr sharding_state,
-      internal::CachePool::WeakPtr data_cache_pool)
+      ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
+      bool open_as_void = false)
       : ChunkCacheImpl(std::move(store),
                        ZarrCodecChain::PreparedState::Ptr(
                            sharding_state->sub_chunk_codec_state),
-                       std::move(data_cache_pool)),
+                       std::move(dtype), std::move(data_cache_pool),
+                       open_as_void),
         sharding_state_(std::move(sharding_state)),
         executor_(std::move(executor)) {}
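
Note (review aid, not part of the patch): the structured-dtype path in `ZarrLeafChunkCache::DecodeChunk` above is a strided gather — one `memcpy` per element per field. A standalone sketch of just that step (illustrative only; the name `ExtractField` is not from the codebase):

    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Copy one field out of a buffer of packed structs.  `src` holds
    // num_elements structs of bytes_per_outer_element bytes each; the field
    // lives at [byte_offset, byte_offset + field_size) within every struct.
    std::vector<unsigned char> ExtractField(const unsigned char* src,
                                            size_t num_elements,
                                            size_t bytes_per_outer_element,
                                            size_t byte_offset,
                                            size_t field_size) {
      std::vector<unsigned char> dst(num_elements * field_size);
      for (size_t i = 0; i < num_elements; ++i) {
        std::memcpy(dst.data() + i * field_size,
                    src + i * bytes_per_outer_element + byte_offset,
                    field_size);
      }
      return dst;
    }
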
diff --git a/tensorstore/driver/zarr3/driver.cc b/tensorstore/driver/zarr3/driver.cc
index a516c1a7b..f65533197 100644
--- a/tensorstore/driver/zarr3/driver.cc
+++ b/tensorstore/driver/zarr3/driver.cc
@@ -19,6 +19,7 @@
 #include <stddef.h>
 #include <stdint.h>
 #include <algorithm>
+#include <cstring>
 #include <memory>
 #include <string>
 #include <utility>
@@ -79,6 +80,8 @@ namespace tensorstore {
 namespace internal_zarr3 {
 
+constexpr size_t kVoidFieldIndex = size_t(-1);
+
 // Avoid anonymous namespace to workaround MSVC bug.
 //
 // https://developercommunity.visualstudio.com/t/Bug-involving-virtual-functions-templat/10424129
@@ -103,9 +106,12 @@ class ZarrDriverSpec
                                     /*Parent=*/KvsDriverSpec>;
 
   ZarrMetadataConstraints metadata_constraints;
+  std::string selected_field;
+  bool open_as_void;
 
   constexpr static auto ApplyMembers = [](auto& x, auto f) {
-    return f(internal::BaseCast<KvsDriverSpec>(x), x.metadata_constraints);
+    return f(internal::BaseCast<KvsDriverSpec>(x), x.metadata_constraints,
+             x.selected_field, x.open_as_void);
   };
 
   static inline const auto default_json_binder = jb::Sequence(
@@ -121,14 +127,34 @@ class ZarrDriverSpec
       jb::Member(
           "metadata",
          jb::Validate(
              [](const auto& options, auto* obj) {
-                TENSORSTORE_RETURN_IF_ERROR(obj->schema.Set(
-                    obj->metadata_constraints.data_type.value_or(DataType())));
+                if (obj->metadata_constraints.data_type) {
+                  if (auto dtype = GetScalarDataType(
+                          *obj->metadata_constraints.data_type)) {
+                    TENSORSTORE_RETURN_IF_ERROR(obj->schema.Set(*dtype));
+                  } else if (obj->schema.dtype().valid()) {
+                    return absl::InvalidArgumentError(
+                        "schema dtype must be unspecified for structured "
+                        "zarr3 data types");
+                  } else {
+                    // Leave dtype unspecified; structured dtypes are handled
+                    // at the metadata level only.
+                  }
+                }
                 TENSORSTORE_RETURN_IF_ERROR(obj->schema.Set(
                     RankConstraint{obj->metadata_constraints.rank}));
                 return absl::OkStatus();
               },
               jb::Projection<&ZarrDriverSpec::metadata_constraints>(
-                  jb::DefaultInitializedValue())));
+                  jb::DefaultInitializedValue()))),
+      jb::Member("field", jb::Projection<&ZarrDriverSpec::selected_field>(
+                              jb::DefaultValue(
+                                  [](auto* obj) { *obj = std::string{}; }))),
+      jb::Member("open_as_void",
+                 jb::Projection<&ZarrDriverSpec::open_as_void>(jb::DefaultValue(
+                     [](auto* v) { *v = false; }))));
 
   absl::Status ApplyOptions(SpecOptions&& options) override {
     if (options.minimal_spec) {
@@ -145,12 +171,74 @@ class ZarrDriverSpec
 
   Result<SharedArray<const void>> GetFillValue(
       IndexTransformView<> transform) const override {
     SharedArray<const void> fill_value{schema.fill_value()};
-    const auto& metadata = metadata_constraints;
-    if (metadata.fill_value) {
-      fill_value = *metadata.fill_value;
+    const auto& constraints = metadata_constraints;
+
+    // If constraints don't specify a fill value, just use the schema's.
+    if (!constraints.fill_value || constraints.fill_value->empty()) {
+      return fill_value;
+    }
+
+    const auto& vec = *constraints.fill_value;
+
+    // If we don't have dtype information, we can't do field-aware logic.
+    if (!constraints.data_type) {
+      if (!vec.empty()) return vec[0];
+      return fill_value;
+    }
+
+    const ZarrDType& dtype = *constraints.data_type;
+
+    // Determine which field this spec refers to (or void access).
+    TENSORSTORE_ASSIGN_OR_RETURN(
+        size_t field_index,
+        GetFieldIndex(dtype, selected_field, open_as_void));
+
+    // Normal field access: just return that field's fill value.
+    if (field_index != kVoidFieldIndex) {
+      if (field_index < vec.size()) {
+        return vec[field_index];
+      }
+      // Fallback to "no fill".
+      return SharedArray<const void>();
+    }
+
+    // Void access: synthesize a byte-level fill value.
+    //
+    // We want a 1D byte array of length bytes_per_outer_element whose contents
+    // are exactly the Zarr-defined struct layout built from per-field fills.
+
+    // Special case: "raw bytes" field (single byte_t field with flexible
+    // shape).  In that case the existing fill array already has the correct
+    // bytes.
+    if (dtype.fields.size() == 1 &&
+        dtype.fields[0].dtype.id() == DataTypeId::byte_t &&
+        !dtype.fields[0].flexible_shape.empty()) {
+      // vec[0] should be a byte array of size bytes_per_outer_element.
+      return vec[0];
+    }
+
+    const Index nbytes = dtype.bytes_per_outer_element;
+
+    auto byte_arr = AllocateArray(
+        span<const Index>({nbytes}), c_order, default_init,
+        dtype_v<::tensorstore::dtypes::byte_t>);
+    auto* dst = static_cast<unsigned char*>(byte_arr.data());
+    std::memset(dst, 0, static_cast<size_t>(nbytes));
+
+    // Pack each field's scalar fill into its byte_offset region.
+    for (size_t i = 0; i < dtype.fields.size() && i < vec.size(); ++i) {
+      const auto& field = dtype.fields[i];
+      const auto& field_fill = vec[i];
+      if (!field_fill.valid()) continue;
+
+      // We assume a single outer element per field here (which is exactly how
+      // FillValueJsonBinder constructs per-field fill values).
+      std::memcpy(dst + field.byte_offset,
+                  static_cast<const void*>(field_fill.data()),
+                  static_cast<size_t>(field.num_bytes));
     }
-    return fill_value;
+    return byte_arr;
   }
 
   Result<DimensionUnitsVector> GetDimensionUnits() const override {
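
Note (review aid, not part of the patch): the void-access branch of `GetFillValue` above packs per-field fill values into one byte image of the struct. Concretely, a sketch for dtype [["x", "|u1"], ["y", "<i2"]] with fills x = 7 and y = -1:

    #include <cstdint>
    #include <cstring>

    unsigned char fill[3] = {0, 0, 0};  // bytes_per_outer_element == 3
    uint8_t x_fill = 7;
    int16_t y_fill = -1;
    std::memcpy(fill + 0, &x_fill, sizeof x_fill);  // byte_offset 0, 1 byte
    std::memcpy(fill + 1, &y_fill, sizeof y_fill);  // byte_offset 1, 2 bytes
    // fill is now {0x07, 0xff, 0xff} on a little-endian host; the synthesized
    // fill value is this 1-D byte array, later broadcast over the chunk dims.
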
@@ -247,12 +335,29 @@ class DataCacheBase
   void GetChunkGridBounds(const void* metadata_ptr, MutableBoxView<> bounds,
                           DimensionSet& implicit_lower_bounds,
                           DimensionSet& implicit_upper_bounds) override {
     const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
-    assert(bounds.rank() == static_cast<DimensionIndex>(metadata.shape.size()));
-    std::fill(bounds.origin().begin(), bounds.origin().end(), Index(0));
+    assert(bounds.rank() >= static_cast<DimensionIndex>(metadata.shape.size()));
+    std::fill(bounds.origin().begin(),
+              bounds.origin().begin() + metadata.shape.size(), Index(0));
     std::copy(metadata.shape.begin(), metadata.shape.end(),
               bounds.shape().begin());
     implicit_lower_bounds = false;
-    implicit_upper_bounds = true;
+    implicit_upper_bounds = false;
+    for (DimensionIndex i = 0;
+         i < static_cast<DimensionIndex>(metadata.shape.size()); ++i) {
+      implicit_upper_bounds[i] = true;
+    }
+    if (bounds.rank() > static_cast<DimensionIndex>(metadata.shape.size()) &&
+        metadata.data_type.fields.size() == 1) {
+      const auto& field = metadata.data_type.fields[0];
+      if (static_cast<DimensionIndex>(metadata.shape.size() +
+                                      field.field_shape.size()) ==
+          bounds.rank()) {
+        for (size_t i = 0; i < field.field_shape.size(); ++i) {
+          bounds.shape()[metadata.shape.size() + i] = field.field_shape[i];
+          bounds.origin()[metadata.shape.size() + i] = 0;
+        }
+      }
+    }
   }
 
   Result<std::shared_ptr<const void>> GetResizedMetadata(
@@ -273,21 +378,102 @@ class DataCacheBase
   }
 
   static internal::ChunkGridSpecification GetChunkGridSpecification(
-      const ZarrMetadata& metadata) {
-    auto fill_value =
-        BroadcastArray(metadata.fill_value, BoxView<>(metadata.rank)).value();
+      const ZarrMetadata& metadata, size_t field_index = 0) {
+    assert(!metadata.fill_value.empty());
     internal::ChunkGridSpecification::ComponentList components;
-    auto& component = components.emplace_back(
-        internal::AsyncWriteArray::Spec{
-            std::move(fill_value),
-            // Since all dimensions are resizable, just
-            // specify unbounded `valid_data_bounds`.
-            Box<>(metadata.rank),
-            ContiguousLayoutPermutation<>(
-                span(metadata.inner_order.data(), metadata.rank))},
-        metadata.chunk_shape);
-    component.array_spec.fill_value_comparison_kind =
-        EqualityComparisonKind::identical;
+
+    // Special case: void access - create single component for entire struct
+    if (field_index == kVoidFieldIndex) {
+      // For void access, create a zero-filled byte array as the fill value
+      const Index bytes_per_element =
+          metadata.data_type.bytes_per_outer_element;
+      auto base_fill_value = AllocateArray(
+          span<const Index>({bytes_per_element}), c_order, value_init,
+          dtype_v<::tensorstore::dtypes::byte_t>);
+
+      // Broadcast to shape [unbounded, unbounded, ..., struct_size]
+      std::vector<Index> target_shape(metadata.rank, kInfIndex);
+      target_shape.push_back(bytes_per_element);
+      auto chunk_fill_value =
+          BroadcastArray(base_fill_value, BoxView<>(target_shape)).value();
+
+      // Add extra dimension for struct size in bytes
+      std::vector<Index> chunk_shape_with_bytes = metadata.chunk_shape;
+      chunk_shape_with_bytes.push_back(bytes_per_element);
+
+      // Create permutation: copy existing inner_order and add the new
+      // dimension
+      std::vector<DimensionIndex> void_permutation(metadata.rank + 1);
+      std::copy_n(metadata.inner_order.data(), metadata.rank,
+                  void_permutation.begin());
+      void_permutation[metadata.rank] = metadata.rank;  // the bytes dimension
+
+      auto& component = components.emplace_back(
+          internal::AsyncWriteArray::Spec{
+              std::move(chunk_fill_value),
+              // Since all dimensions are resizable, just
+              // specify unbounded `valid_data_bounds`.
+              Box<>(metadata.rank + 1),
+              ContiguousLayoutPermutation<>(
+                  span(void_permutation.data(), metadata.rank + 1))},
+          chunk_shape_with_bytes);
+      component.array_spec.fill_value_comparison_kind =
+          EqualityComparisonKind::identical;
+      return internal::ChunkGridSpecification(std::move(components));
+    }
+
+    // Create one component per field (like zarr v2)
+    for (size_t field_i = 0; field_i < metadata.data_type.fields.size();
+         ++field_i) {
+      const auto& field = metadata.data_type.fields[field_i];
+      auto fill_value = metadata.fill_value[field_i];
+      if (!fill_value.valid()) {
+        // Use value-initialized rank-0 fill value (like zarr v2)
+        fill_value = AllocateArray(span<const Index>{}, c_order, value_init,
+                                   field.dtype);
+      }
+
+      // Handle fields with shape (e.g. raw_bytes)
+      const size_t field_rank = field.field_shape.size();
+
+      // 1. Construct target shape for broadcasting
+      std::vector<Index> target_shape(metadata.rank, kInfIndex);
+      target_shape.insert(target_shape.end(), field.field_shape.begin(),
+                          field.field_shape.end());
+
+      auto chunk_fill_value =
+          BroadcastArray(fill_value, BoxView<>(target_shape)).value();
+
+      // 2. Construct component chunk shape
+      std::vector<Index> component_chunk_shape = metadata.chunk_shape;
+      component_chunk_shape.insert(component_chunk_shape.end(),
+                                   field.field_shape.begin(),
+                                   field.field_shape.end());
+
+      // 3. Construct permutation
+      std::vector<DimensionIndex> component_permutation(metadata.rank +
+                                                        field_rank);
+      std::copy_n(metadata.inner_order.data(), metadata.rank,
+                  component_permutation.begin());
+      std::iota(component_permutation.begin() + metadata.rank,
+                component_permutation.end(), metadata.rank);
+
+      // 4. Construct bounds
+      Box<> valid_data_bounds(metadata.rank + field_rank);
+      for (size_t i = 0; i < field_rank; ++i) {
+        valid_data_bounds[metadata.rank + i] =
+            IndexInterval::UncheckedSized(0, field.field_shape[i]);
+      }
+
+      auto& component = components.emplace_back(
+          internal::AsyncWriteArray::Spec{
+              std::move(chunk_fill_value),
+              // The chunked dimensions are resizable, so leave them
+              // unbounded; only the field dimensions get explicit bounds.
+              std::move(valid_data_bounds),
+              ContiguousLayoutPermutation<>(component_permutation)},
+          component_chunk_shape);
+      component.array_spec.fill_value_comparison_kind =
+          EqualityComparisonKind::identical;
+    }
     return internal::ChunkGridSpecification(std::move(components));
   }
 
@@ -312,7 +498,7 @@ class DataCacheBase
         [](std::string& out, DimensionIndex dim, Index grid_index) {
           absl::StrAppend(&out, grid_index);
         },
-        rank, grid_indices);
+        rank, grid_indices.subspan(0, rank));
     return key;
   }
 
@@ -325,17 +511,21 @@ class DataCacheBase
         key_prefix_.size() +
         (metadata.chunk_key_encoding.kind == ChunkKeyEncoding::kDefault ? 2
                                                                         : 0));
-    return internal::ParseGridIndexKeyWithDimensionSeparator(
-        metadata.chunk_key_encoding.separator,
-        [](std::string_view part, DimensionIndex dim, Index& grid_index) {
-          if (part.empty() || !absl::ascii_isdigit(part.front()) ||
-              !absl::ascii_isdigit(part.back()) ||
-              !absl::SimpleAtoi(part, &grid_index)) {
-            return false;
-          }
-          return true;
-        },
-        key, grid_indices);
+    if (!internal::ParseGridIndexKeyWithDimensionSeparator(
+            metadata.chunk_key_encoding.separator,
+            [](std::string_view part, DimensionIndex dim, Index& grid_index) {
+              if (part.empty() || !absl::ascii_isdigit(part.front()) ||
+                  !absl::ascii_isdigit(part.back()) ||
+                  !absl::SimpleAtoi(part, &grid_index)) {
+                return false;
+              }
+              return true;
+            },
+            key, grid_indices.subspan(0, metadata.rank))) {
+      return false;
+    }
+    // Grid indices for the synthetic field/bytes dimensions are always 0.
+    std::fill(grid_indices.begin() + metadata.rank, grid_indices.end(), 0);
+    return true;
   }
 
   Index MinGridIndexForLexicographicalOrder(
@@ -348,7 +538,7 @@ class DataCacheBase
         *static_cast<const ZarrMetadata*>(initial_metadata().get());
     if (metadata.chunk_key_encoding.kind == ChunkKeyEncoding::kDefault) {
       std::string key = tensorstore::StrCat(key_prefix_, "c");
-      for (DimensionIndex i = 0; i < cell_indices.size(); ++i) {
+      for (DimensionIndex i = 0; i < metadata.rank; ++i) {
         tensorstore::StrAppend(
             &key, std::string_view(&metadata.chunk_key_encoding.separator, 1),
             cell_indices[i]);
@@ -358,7 +548,7 @@ class DataCacheBase
       // Use "0" for rank 0 as a special case.
       std::string key = tensorstore::StrCat(
           key_prefix_, cell_indices.empty() ? 0 : cell_indices[0]);
-      for (DimensionIndex i = 1; i < cell_indices.size(); ++i) {
+      for (DimensionIndex i = 1; i < metadata.rank; ++i) {
         tensorstore::StrAppend(
             &key, std::string_view(&metadata.chunk_key_encoding.separator, 1),
             cell_indices[i]);
@@ -368,9 +558,13 @@ class DataCacheBase
 
   Result<IndexTransform<>> GetExternalToInternalTransform(
       const void* metadata_ptr, size_t component_index) override {
-    assert(component_index == 0);
+    // component_index corresponds to the selected field index
     const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
+    const auto& field = metadata.data_type.fields[component_index];
     const DimensionIndex rank = metadata.rank;
+    const DimensionIndex field_rank = field.field_shape.size();
+    const DimensionIndex total_rank = rank + field_rank;
+
     std::string_view normalized_dimension_names[kMaxRank];
     for (DimensionIndex i = 0; i < rank; ++i) {
       if (const auto& name = metadata.dimension_names[i]; name.has_value()) {
         normalized_dimension_names[i] = *name;
       }
     }
     auto builder =
-        tensorstore::IndexTransformBuilder<>(rank, rank)
-            .input_shape(metadata.shape)
-            .input_labels(span(&normalized_dimension_names[0], rank));
-    builder.implicit_upper_bounds(true);
+        tensorstore::IndexTransformBuilder<>(total_rank, total_rank);
+    std::vector<Index> full_shape = metadata.shape;
+    full_shape.insert(full_shape.end(), field.field_shape.begin(),
+                      field.field_shape.end());
+    builder.input_shape(full_shape);
+    builder.input_labels(span(&normalized_dimension_names[0], total_rank));
+
+    DimensionSet implicit_upper_bounds(false);
     for (DimensionIndex i = 0; i < rank; ++i) {
+      implicit_upper_bounds[i] = true;
+    }
+    builder.implicit_upper_bounds(implicit_upper_bounds);
+
+    for (DimensionIndex i = 0; i < total_rank; ++i) {
       builder.output_single_input_dimension(i, i);
     }
     return builder.Finalize();
@@ -391,10 +594,16 @@ class DataCacheBase
   absl::Status GetBoundSpecData(KvsDriverSpec& spec_base,
                                 const void* metadata_ptr,
                                 size_t component_index) override {
-    assert(component_index == 0);
     auto& spec = static_cast<ZarrDriverSpec&>(spec_base);
     const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
     spec.metadata_constraints = ZarrMetadataConstraints(metadata);
+    // Encode selected_field from component_index
+    if (metadata.data_type.has_fields &&
+        component_index < metadata.data_type.fields.size()) {
+      spec.selected_field = metadata.data_type.fields[component_index].name;
+    } else {
+      spec.selected_field.clear();
+    }
     return absl::OkStatus();
   }
 
@@ -402,9 +611,16 @@ class DataCacheBase
       const void* metadata_ptr, size_t component_index) override {
     const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
     ChunkLayout chunk_layout;
+    SpecRankAndFieldInfo info;
+    info.chunked_rank = metadata.rank;
+    if (!metadata.data_type.fields.empty()) {
+      info.field = &metadata.data_type.fields[0];
+    }
+    std::optional<span<const Index>> chunk_shape_span;
+    chunk_shape_span.emplace(metadata.chunk_shape.data(),
+                             metadata.chunk_shape.size());
     TENSORSTORE_RETURN_IF_ERROR(SetChunkLayoutFromMetadata(
-        metadata.data_type, metadata.rank, metadata.chunk_shape,
-        &metadata.codec_specs, chunk_layout));
+        info, chunk_shape_span, &metadata.codec_specs, chunk_layout));
     TENSORSTORE_RETURN_IF_ERROR(chunk_layout.Finalize());
     return chunk_layout;
   }
@@ -424,7 +640,10 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase {
   template <typename... U>
   explicit ZarrDataCache(Initializer&& initializer, std::string key_prefix,
                          U&&... arg)
       : ChunkCacheImpl(std::move(initializer.store), std::forward<U>(arg)...),
         DataCacheBase(std::move(initializer), std::move(key_prefix)),
-        grid_(DataCacheBase::GetChunkGridSpecification(metadata())) {}
+        grid_(DataCacheBase::GetChunkGridSpecification(
+            metadata(),
+            // For void access, build the single raw-bytes component.
+            ChunkCacheImpl::open_as_void_ ? kVoidFieldIndex : 0)) {}
 
   const internal::LexicographicalGridIndexKeyParser&
   GetChunkStorageKeyParser() final {
@@ -450,6 +669,51 @@ class ZarrDataCache : public ChunkCacheImpl, public DataCacheBase {
     return DataCacheBase::executor();
   }
 
+  // Override to handle void access - check the stored flag to see if this is
+  // a void-access cache
+  Result<IndexTransform<>> GetExternalToInternalTransform(
+      const void* metadata_ptr, size_t component_index) override {
+    const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
+
+    const bool is_void_access = ChunkCacheImpl::open_as_void_;
+
+    if (is_void_access) {
+      // For void access, create transform with extra bytes dimension
+      const DimensionIndex rank = metadata.rank;
+      const Index bytes_per_element =
+          metadata.data_type.bytes_per_outer_element;
+      const DimensionIndex total_rank = rank + 1;
+
+      std::string_view normalized_dimension_names[kMaxRank];
+      for (DimensionIndex i = 0; i < rank; ++i) {
+        if (const auto& name = metadata.dimension_names[i]; name.has_value()) {
+          normalized_dimension_names[i] = *name;
+        }
+      }
+
+      auto builder =
+          tensorstore::IndexTransformBuilder<>(total_rank, total_rank);
+      std::vector<Index> full_shape = metadata.shape;
+      full_shape.push_back(bytes_per_element);
+      builder.input_shape(full_shape);
+      builder.input_labels(span(&normalized_dimension_names[0], total_rank));
+
+      DimensionSet implicit_upper_bounds(false);
+      for (DimensionIndex i = 0; i < rank; ++i) {
+        implicit_upper_bounds[i] = true;
+      }
+      builder.implicit_upper_bounds(implicit_upper_bounds);
+
+      for (DimensionIndex i = 0; i < total_rank; ++i) {
+        builder.output_single_input_dimension(i, i);
+      }
+      return builder.Finalize();
+    }
+
+    // Not void access - delegate to base implementation
+    return DataCacheBase::GetExternalToInternalTransform(metadata_ptr,
+                                                         component_index);
+  }
+
   internal::ChunkGridSpecification grid_;
 };
 
@@ -470,7 +734,14 @@ class ZarrDriver : public ZarrDriverBase {
   Result<SharedArray<const void>> GetFillValue(
       IndexTransformView<> transform) override {
     const auto& metadata = this->metadata();
-    return metadata.fill_value;
+    if (metadata.fill_value.empty()) {
+      return SharedArray<const void>();
+    }
+    size_t index = this->component_index();
+    if (index >= metadata.fill_value.size()) {
+      return absl::OutOfRangeError("Component index out of bounds");
+    }
+    return metadata.fill_value[index];
   }
 
   Future<ArrayStorageStatistics> GetStorageStatistics(
@@ -490,7 +761,8 @@ class ZarrDriver : public ZarrDriverBase {
       AnyFlowReceiver<absl::Status, internal::ReadChunk, IndexTransform<>>
          receiver) override {
     return cache()->zarr_chunk_cache().Read(
-        {std::move(request), GetCurrentDataStalenessBound(),
+        {std::move(request), this->component_index(),
+         GetCurrentDataStalenessBound(),
          this->fill_value_mode_.fill_missing_data_reads},
        std::move(receiver));
   }
@@ -500,7 +772,7 @@ class ZarrDriver : public ZarrDriverBase {
      AnyFlowReceiver<absl::Status, internal::WriteChunk, IndexTransform<>>
          receiver) override {
     return cache()->zarr_chunk_cache().Write(
-        {std::move(request),
+        {std::move(request), this->component_index(),
         this->fill_value_mode_.store_data_equal_to_fill_value},
        std::move(receiver));
   }
@@ -567,12 +839,16 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
 
   std::string GetDataCacheKey(const void* metadata) override {
     std::string result;
+    const auto& zarr_metadata = *static_cast<const ZarrMetadata*>(metadata);
     internal::EncodeCacheKey(
-        &result, spec().store.path,
-        static_cast<const ZarrMetadata*>(metadata)->GetCompatibilityKey());
+        &result, spec().store.path, zarr_metadata.GetCompatibilityKey(),
+        spec().open_as_void ? "void" : "normal");
     return result;
   }
 
   Result<std::shared_ptr<const void>> Create(const void* existing_metadata,
                                              CreateOptions options) override {
     if (existing_metadata) {
     TENSORSTORE_ASSIGN_OR_RETURN(
         auto metadata,
         internal_zarr3::GetNewMetadata(spec().metadata_constraints,
-                                       spec().schema),
+                                       spec().schema, spec().selected_field,
+                                       spec().open_as_void),
         tensorstore::MaybeAnnotateStatus(
             _, "Cannot create using specified \"metadata\" and schema"));
     return metadata;
@@ -596,9 +872,28 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
       DataCacheInitializer&& initializer) override {
     const auto& metadata =
         *static_cast<const ZarrMetadata*>(initializer.metadata.get());
+    // For void access, modify the dtype to indicate special handling
+    ZarrDType dtype = metadata.data_type;
+    if (spec().open_as_void) {
+      // Create a synthetic single-field dtype covering the whole struct
+      dtype = ZarrDType{
+          /*.has_fields=*/false,
+          /*.fields=*/
+          {ZarrDType::Field{
+              ZarrDType::BaseDType{
+                  "", dtype_v<::tensorstore::dtypes::byte_t>,
+                  {metadata.data_type.bytes_per_outer_element}},
+              /*.outer_shape=*/{},
+              /*.name=*/"",
+              /*.field_shape=*/{metadata.data_type.bytes_per_outer_element},
+              /*.num_inner_elements=*/
+              metadata.data_type.bytes_per_outer_element,
+              /*.byte_offset=*/0,
+              /*.num_bytes=*/metadata.data_type.bytes_per_outer_element}},
+          /*.bytes_per_outer_element=*/
+          metadata.data_type.bytes_per_outer_element};
+    }
     return internal_zarr3::MakeZarrChunkCache(
         *metadata.codecs, std::move(initializer), spec().store.path,
-        metadata.codec_state, /*data_cache_pool=*/*cache_pool());
+        metadata.codec_state, dtype,
+        /*data_cache_pool=*/*cache_pool(), spec().open_as_void);
   }
 
   Result<size_t> GetComponentIndex(const void* metadata_ptr,
                                    OpenMode open_mode) override {
     const auto& metadata = *static_cast<const ZarrMetadata*>(metadata_ptr);
     TENSORSTORE_RETURN_IF_ERROR(
         ValidateMetadata(metadata, spec().metadata_constraints));
+    TENSORSTORE_ASSIGN_OR_RETURN(
+        auto field_index,
+        GetFieldIndex(metadata.data_type, spec().selected_field,
+                      spec().open_as_void));
+    // For void access, map to component index 0
+    if (field_index == kVoidFieldIndex) {
+      field_index = 0;
+    }
     TENSORSTORE_RETURN_IF_ERROR(
-        ValidateMetadataSchema(metadata, spec().schema));
-    return 0;
+        ValidateMetadataSchema(metadata, field_index, spec().schema));
+    return field_index;
   }
 };
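
Note (review aid, not part of the patch): the new `dtype.cc` below accepts, besides the plain scalar names, three extended forms. For orientation, a sketch of the JSON inputs the parser recognizes (values here are examples only):

    #include <nlohmann/json.hpp>

    // Plain scalar (unchanged behavior):
    ::nlohmann::json a = "int16";
    // Raw bits, "r<N>" with N a positive multiple of 8 -> N/8 bytes/element:
    ::nlohmann::json b = "r24";
    // Fixed-length raw bytes:
    ::nlohmann::json c = {{"name", "raw_bytes"},
                          {"configuration", {{"length_bytes", 3}}}};
    // Structured dtype; each field is [name, dtype] or [name, dtype, shape]:
    ::nlohmann::json fields =
        ::nlohmann::json::array({{"x", "uint8"}, {"y", "int16"}});
    ::nlohmann::json d = {{"name", "structured"},
                          {"configuration", {{"fields", fields}}}};
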
diff --git a/tensorstore/driver/zarr3/dtype.cc b/tensorstore/driver/zarr3/dtype.cc
new file mode 100644
index 000000000..b8aacaa68
--- /dev/null
+++ b/tensorstore/driver/zarr3/dtype.cc
@@ -0,0 +1,402 @@
+// Copyright 2025 The TensorStore Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensorstore/driver/zarr3/dtype.h"
+
+#include <stdint.h>
+
+#include <string>
+
+#include "absl/base/optimization.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/numbers.h"
+#include "tensorstore/data_type.h"
+#include "tensorstore/internal/json_binding/json_binding.h"
+#include "tensorstore/util/endian.h"
+#include "tensorstore/util/extents.h"
+#include "tensorstore/util/quote_string.h"
+#include "tensorstore/util/str_cat.h"
+
+namespace tensorstore {
+namespace internal_zarr3 {
+
+Result<ZarrDType::BaseDType> ParseBaseDType(std::string_view dtype) {
+  using D = ZarrDType::BaseDType;
+  const auto make_dtype = [&](DataType result_dtype) -> Result<D> {
+    return D{std::string(dtype), result_dtype, {}};
+  };
+
+  if (dtype == "bool") return make_dtype(dtype_v<bool>);
+  if (dtype == "uint8") return make_dtype(dtype_v<uint8_t>);
+  if (dtype == "uint16") return make_dtype(dtype_v<uint16_t>);
+  if (dtype == "uint32") return make_dtype(dtype_v<uint32_t>);
+  if (dtype == "uint64") return make_dtype(dtype_v<uint64_t>);
+  if (dtype == "int8") return make_dtype(dtype_v<int8_t>);
+  if (dtype == "int16") return make_dtype(dtype_v<int16_t>);
+  if (dtype == "int32") return make_dtype(dtype_v<int32_t>);
+  if (dtype == "int64") return make_dtype(dtype_v<int64_t>);
+  if (dtype == "bfloat16")
+    return make_dtype(dtype_v<::tensorstore::dtypes::bfloat16_t>);
+  if (dtype == "float16")
+    return make_dtype(dtype_v<::tensorstore::dtypes::float16_t>);
+  if (dtype == "float32")
+    return make_dtype(dtype_v<::tensorstore::dtypes::float32_t>);
+  if (dtype == "float64")
+    return make_dtype(dtype_v<::tensorstore::dtypes::float64_t>);
+  if (dtype == "complex64")
+    return make_dtype(dtype_v<::tensorstore::dtypes::complex64_t>);
+  if (dtype == "complex128")
+    return make_dtype(dtype_v<::tensorstore::dtypes::complex128_t>);
+
+  // Handle the r<N> raw bits type, where N is the number of bits (must be a
+  // multiple of 8)
+  if (dtype.size() > 1 && dtype[0] == 'r' && absl::ascii_isdigit(dtype[1])) {
+    std::string_view suffix = dtype.substr(1);
+    Index num_bits = 0;
+    if (!absl::SimpleAtoi(suffix, &num_bits) || num_bits == 0 ||
+        num_bits % 8 != 0) {
+      return absl::InvalidArgumentError(tensorstore::StrCat(
+          dtype,
+          " data type is invalid; expected r<N> where N is a positive "
+          "multiple of 8"));
+    }
+    Index num_bytes = num_bits / 8;
+    return ZarrDType::BaseDType{std::string(dtype),
+                                dtype_v<::tensorstore::dtypes::byte_t>,
+                                {num_bytes}};
+  }
+
+  // Handle a bare "r", or "r" followed by an invalid suffix
+  if (dtype.size() >= 1 && dtype[0] == 'r') {
+    return absl::InvalidArgumentError(tensorstore::StrCat(
+        dtype,
+        " data type is invalid; expected r<N> where N is a positive "
+        "multiple of 8"));
+  }
+
+  constexpr std::string_view kSupported =
+      "bool, uint8, uint16, uint32, uint64, int8, int16, int32, int64, "
+      "bfloat16, float16, float32, float64, complex64, complex128, r<N>";
+  return absl::InvalidArgumentError(tensorstore::StrCat(
+      dtype,
+      " data type is not one of the supported data types: ", kSupported));
+}
+
+namespace {
+
+/// Parses a zarr metadata "dtype" JSON specification, but does not compute any
+/// derived values, and does not check for duplicate field names.
+///
+/// This is called by `ParseDType`.
+///
+/// \param value The zarr metadata "dtype" JSON specification.
+/// \param out[out] Must be non-null.  Filled with the parsed dtype on success.
+/// \error `absl::StatusCode::kInvalidArgument` if `value` is invalid.
+// Helper to parse fields array (used by both array format and object format) +absl::Status ParseFieldsArray(const nlohmann::json& fields_json, + ZarrDType& out) { + out.has_fields = true; + return internal_json::JsonParseArray( + fields_json, + [&](ptrdiff_t size) { + out.fields.resize(size); + return absl::OkStatus(); + }, + [&](const ::nlohmann::json& x, ptrdiff_t field_i) { + auto& field = out.fields[field_i]; + return internal_json::JsonParseArray( + x, + [&](ptrdiff_t size) { + if (size < 2 || size > 3) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected array of size 2 or 3, but received: ", x.dump())); + } + return absl::OkStatus(); + }, + [&](const ::nlohmann::json& v, ptrdiff_t i) { + switch (i) { + case 0: + if (internal_json::JsonRequireValueAs(v, &field.name).ok()) { + if (!field.name.empty()) return absl::OkStatus(); + } + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected non-empty string, but received: ", v.dump())); + case 1: { + std::string dtype_string; + TENSORSTORE_RETURN_IF_ERROR( + internal_json::JsonRequireValueAs(v, &dtype_string)); + TENSORSTORE_ASSIGN_OR_RETURN( + static_cast(field), + ParseBaseDType(dtype_string)); + return absl::OkStatus(); + } + case 2: { + return internal_json::JsonParseArray( + v, + [&](ptrdiff_t size) { + field.outer_shape.resize(size); + return absl::OkStatus(); + }, + [&](const ::nlohmann::json& x, ptrdiff_t j) { + return internal_json::JsonRequireInteger( + x, &field.outer_shape[j], /*strict=*/true, 1, + kInfIndex); + }); + } + default: + ABSL_UNREACHABLE(); // COV_NF_LINE + } + }); + }); +} + +Result ParseDTypeNoDerived(const nlohmann::json& value) { + ZarrDType out; + if (value.is_string()) { + // Single field. + out.has_fields = false; + out.fields.resize(1); + TENSORSTORE_ASSIGN_OR_RETURN( + static_cast(out.fields[0]), + ParseBaseDType(value.get())); + return out; + } + // Handle extended object format: + // {"name": "structured", "configuration": {"fields": [...]}} + if (value.is_object()) { + if (value.contains("name") && value.contains("configuration")) { + std::string type_name; + TENSORSTORE_RETURN_IF_ERROR( + internal_json::JsonRequireValueAs(value["name"], &type_name)); + if (type_name == "structured") { + const auto& config = value["configuration"]; + if (!config.is_object() || !config.contains("fields")) { + return absl::InvalidArgumentError( + "Structured data type requires 'configuration' object with " + "'fields' array"); + } + TENSORSTORE_RETURN_IF_ERROR(ParseFieldsArray(config["fields"], out)); + return out; + } + if (type_name == "raw_bytes") { + const auto& config = value["configuration"]; + if (!config.is_object() || !config.contains("length_bytes")) { + return absl::InvalidArgumentError( + "raw_bytes data type requires 'configuration' object with " + "'length_bytes' field"); + } + Index length_bytes; + TENSORSTORE_RETURN_IF_ERROR( + internal_json::JsonRequireValueAs(config["length_bytes"], &length_bytes)); + if (length_bytes <= 0) { + return absl::InvalidArgumentError( + "raw_bytes length_bytes must be positive"); + } + out.has_fields = false; + out.fields.resize(1); + out.fields[0].encoded_dtype = "raw_bytes"; + out.fields[0].dtype = dtype_v; + out.fields[0].flexible_shape = {length_bytes}; + out.fields[0].outer_shape = {}; + out.fields[0].name = ""; + out.fields[0].field_shape = {length_bytes}; + out.fields[0].num_inner_elements = length_bytes; + out.fields[0].byte_offset = 0; + out.fields[0].num_bytes = length_bytes; + out.bytes_per_outer_element = length_bytes; + return out; + 
} + // For other named types, try to parse as a base dtype + out.has_fields = false; + out.fields.resize(1); + TENSORSTORE_ASSIGN_OR_RETURN( + static_cast(out.fields[0]), + ParseBaseDType(type_name)); + return out; + } + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected string, array, or object with 'name' and 'configuration', " + "but received: ", + value.dump())); + } + // Handle array format: [["field1", "type1"], ["field2", "type2"], ...] + TENSORSTORE_RETURN_IF_ERROR(ParseFieldsArray(value, out)); + return out; +} + +} // namespace + +absl::Status ValidateDType(ZarrDType& dtype) { + dtype.bytes_per_outer_element = 0; + for (size_t field_i = 0; field_i < dtype.fields.size(); ++field_i) { + auto& field = dtype.fields[field_i]; + if (std::any_of( + dtype.fields.begin(), dtype.fields.begin() + field_i, + [&](const ZarrDType::Field& f) { return f.name == field.name; })) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Field name ", QuoteString(field.name), " occurs more than once")); + } + field.field_shape.resize(field.flexible_shape.size() + + field.outer_shape.size()); + std::copy(field.flexible_shape.begin(), field.flexible_shape.end(), + std::copy(field.outer_shape.begin(), field.outer_shape.end(), + field.field_shape.begin())); + + field.num_inner_elements = ProductOfExtents(span(field.field_shape)); + if (field.num_inner_elements == std::numeric_limits::max()) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Product of dimensions ", span(field.field_shape), " is too large")); + } + if (internal::MulOverflow(field.num_inner_elements, + static_cast(field.dtype->size), + &field.num_bytes)) { + return absl::InvalidArgumentError("Field size in bytes is too large"); + } + field.byte_offset = dtype.bytes_per_outer_element; + if (internal::AddOverflow(dtype.bytes_per_outer_element, field.num_bytes, + &dtype.bytes_per_outer_element)) { + return absl::InvalidArgumentError( + "Total number of bytes per outer array element is too large"); + } + } + return absl::OkStatus(); +} + +std::optional GetScalarDataType(const ZarrDType& dtype) { + if (!dtype.has_fields && !dtype.fields.empty()) { + return dtype.fields[0].dtype; + } + return std::nullopt; +} + +Result ParseDType(const nlohmann::json& value) { + TENSORSTORE_ASSIGN_OR_RETURN(ZarrDType dtype, ParseDTypeNoDerived(value)); + TENSORSTORE_RETURN_IF_ERROR(ValidateDType(dtype)); + return dtype; +} + +bool operator==(const ZarrDType::BaseDType& a, + const ZarrDType::BaseDType& b) { + return a.encoded_dtype == b.encoded_dtype && a.dtype == b.dtype && + a.flexible_shape == b.flexible_shape; +} + +bool operator!=(const ZarrDType::BaseDType& a, + const ZarrDType::BaseDType& b) { + return !(a == b); +} + +bool operator==(const ZarrDType::Field& a, const ZarrDType::Field& b) { + return static_cast(a) == + static_cast(b) && + a.outer_shape == b.outer_shape && a.name == b.name && + a.field_shape == b.field_shape && + a.num_inner_elements == b.num_inner_elements && + a.byte_offset == b.byte_offset && a.num_bytes == b.num_bytes; +} + +bool operator!=(const ZarrDType::Field& a, const ZarrDType::Field& b) { + return !(a == b); +} + +bool operator==(const ZarrDType& a, const ZarrDType& b) { + return a.has_fields == b.has_fields && + a.bytes_per_outer_element == b.bytes_per_outer_element && + a.fields == b.fields; +} + +bool operator!=(const ZarrDType& a, const ZarrDType& b) { return !(a == b); } + +void to_json(::nlohmann::json& out, const ZarrDType::Field& field) { + using array_t = ::nlohmann::json::array_t; + if 
(field.outer_shape.empty()) { + out = array_t{field.name, field.encoded_dtype}; + } else { + out = array_t{field.name, field.encoded_dtype, field.outer_shape}; + } +} + +void to_json(::nlohmann::json& out, // NOLINT + const ZarrDType& dtype) { + if (!dtype.has_fields) { + out = dtype.fields[0].encoded_dtype; + } else { + out = dtype.fields; + } +} + +TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(ZarrDType, [](auto is_loading, + const auto& options, + auto* obj, auto* j) { + if constexpr (is_loading) { + TENSORSTORE_ASSIGN_OR_RETURN(*obj, ParseDType(*j)); + } else { + to_json(*j, *obj); + } + return absl::OkStatus(); +}) + +namespace { + +Result MakeBaseDType(std::string_view name, + DataType dtype) { + ZarrDType::BaseDType base_dtype; + base_dtype.dtype = dtype; + base_dtype.encoded_dtype = std::string(name); + return base_dtype; +} + +} // namespace + +Result ChooseBaseDType(DataType dtype) { + if (dtype == dtype_v) return MakeBaseDType("bool", dtype); + if (dtype == dtype_v) return MakeBaseDType("uint8", dtype); + if (dtype == dtype_v) return MakeBaseDType("uint16", dtype); + if (dtype == dtype_v) return MakeBaseDType("uint32", dtype); + if (dtype == dtype_v) return MakeBaseDType("uint64", dtype); + if (dtype == dtype_v) return MakeBaseDType("int8", dtype); + if (dtype == dtype_v) return MakeBaseDType("int16", dtype); + if (dtype == dtype_v) return MakeBaseDType("int32", dtype); + if (dtype == dtype_v) return MakeBaseDType("int64", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::bfloat16_t>) + return MakeBaseDType("bfloat16", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::float16_t>) + return MakeBaseDType("float16", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::float32_t>) + return MakeBaseDType("float32", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::float64_t>) + return MakeBaseDType("float64", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::complex64_t>) + return MakeBaseDType("complex64", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::complex128_t>) + return MakeBaseDType("complex128", dtype); + if (dtype == dtype_v<::tensorstore::dtypes::byte_t>) { + ZarrDType::BaseDType base_dtype; + base_dtype.dtype = dtype; + base_dtype.encoded_dtype = "r8"; + base_dtype.flexible_shape = {1}; + return base_dtype; + } + if (dtype == dtype_v<::tensorstore::dtypes::char_t>) { + // char_t encodes as r8, which parses back to byte_t + ZarrDType::BaseDType base_dtype; + base_dtype.dtype = dtype_v<::tensorstore::dtypes::byte_t>; + base_dtype.encoded_dtype = "r8"; + base_dtype.flexible_shape = {1}; + return base_dtype; + } + return absl::InvalidArgumentError( + tensorstore::StrCat("Data type not supported: ", dtype)); +} + +} // namespace internal_zarr3 +} // namespace tensorstore diff --git a/tensorstore/driver/zarr3/dtype.h b/tensorstore/driver/zarr3/dtype.h new file mode 100644 index 000000000..73a6b0961 --- /dev/null +++ b/tensorstore/driver/zarr3/dtype.h @@ -0,0 +1,144 @@ +// Copyright 2025 The TensorStore Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORSTORE_DRIVER_ZARR3_DTYPE_H_
+#define TENSORSTORE_DRIVER_ZARR3_DTYPE_H_
+
+/// \file
+/// Support for encoding/decoding zarr "dtype" specifications.
+/// See: https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#data-type
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+
+#include "tensorstore/data_type.h"
+#include "tensorstore/index.h"
+#include "tensorstore/internal/json_binding/bindable.h"
+#include "tensorstore/util/endian.h"
+#include "tensorstore/util/result.h"
+
+namespace tensorstore {
+namespace internal_zarr3 {
+
+/// Decoded representation of a zarr "dtype" specification.
+///
+/// A zarr "dtype" is a JSON value that is either:
+///
+/// 1. A string, which specifies a single data type (e.g. "int32").  In this
+///    case, the zarr array is considered to have a single, unnamed field.
+///
+/// 2. An array, where each element of the array is of the form:
+///    `[name, type]` or `[name, type, shape]`, where `name` is a JSON string
+///    specifying the unique, non-empty field name, `type` is a data type
+///    string, and `shape` is an optional "inner" array shape (specified as a
+///    JSON array of non-negative integers) which defaults to the rank-0 shape
+///    `[]` if not specified.
+///
+/// 3. An object of the form `{"name": ..., "configuration": ...}`, which
+///    specifies either the "structured" extension (with a
+///    `configuration.fields` array as in form 2) or the "raw_bytes" extension
+///    (with a `configuration.length_bytes` member).
+///
+/// Each field is encoded according to `type` into a fixed-size sequence of
+/// bytes.  If the optional "inner" array `shape` is specified, the individual
+/// elements are encoded in C order.  The encoding of each multi-field array
+/// element is simply the concatenation of the encodings of each field.
+struct ZarrDType {
+  /// Decoded representation of a single value.
+  struct BaseDType {
+    /// Data type string.
+    std::string encoded_dtype;
+
+    /// Corresponding DataType used for in-memory representation.
+    DataType dtype;
+
+    /// For "flexible" data types that are themselves arrays, this specifies
+    /// the shape.  For regular data types, this is empty.
+    std::vector<Index> flexible_shape;
+  };
+
+  /// Decoded representation of a single field.
+  struct Field : public BaseDType {
+    /// Optional `shape` dimensions specified by a zarr "dtype" field specified
+    /// as a JSON array.  If the zarr dtype was specified as a single `typestr`
+    /// value, or as a two-element array, this is empty.
+    std::vector<Index> outer_shape;
+
+    /// Field name.  Must be non-empty and unique if the zarr "dtype" was
+    /// specified as an array.  Otherwise, is empty.
+    std::string name;
+
+    /// The inner array dimensions of this field, equal to the concatenation of
+    /// `outer_shape` and `flexible_shape` (derived value).
+    std::vector<Index> field_shape;
+
+    /// Product of `field_shape` dimensions (derived value).
+    Index num_inner_elements;
+
+    /// Byte offset of this field within an "outer" element (derived value).
+    Index byte_offset;
+
+    /// Number of bytes occupied by this field within an "outer" element
+    /// (derived value).
+    Index num_bytes;
+  };
+
+  /// Equal to `true` if the zarr "dtype" was specified as an array, in which
+  /// case all fields must have a unique, non-empty `name`.  If `false`, there
+  /// must be a single field with an empty `name`.
+  bool has_fields;
+
+  /// Decoded representation of the fields.
+  std::vector<Field> fields;
+
+  /// Bytes per "outer" element (derived value).
+  Index bytes_per_outer_element;
+
+  TENSORSTORE_DECLARE_JSON_DEFAULT_BINDER(ZarrDType,
+                                          internal_json_binding::NoOptions)
+
+  friend void to_json(::nlohmann::json& out,  // NOLINT
+                      const ZarrDType& dtype);
+};
+
+bool operator==(const ZarrDType::BaseDType& a,
+                const ZarrDType::BaseDType& b);
+bool operator!=(const ZarrDType::BaseDType& a,
+                const ZarrDType::BaseDType& b);
+bool operator==(const ZarrDType::Field& a, const ZarrDType::Field& b);
+bool operator!=(const ZarrDType::Field& a, const ZarrDType::Field& b);
+bool operator==(const ZarrDType& a, const ZarrDType& b);
+bool operator!=(const ZarrDType& a, const ZarrDType& b);
+
+/// Parses a zarr metadata "dtype" JSON specification.
+///
+/// \error `absl::StatusCode::kInvalidArgument` if `value` is not valid.
+Result<ZarrDType> ParseDType(const ::nlohmann::json& value);
+
+/// Validates `dtype` and computes derived values.
+///
+/// \error `absl::StatusCode::kInvalidArgument` if two fields have the same
+///     name.
+/// \error `absl::StatusCode::kInvalidArgument` if the field size is too large.
+absl::Status ValidateDType(ZarrDType& dtype);
+
+/// Returns the underlying TensorStore `DataType` if `dtype` represents an
+/// unstructured scalar array, otherwise `std::nullopt`.
+std::optional<DataType> GetScalarDataType(const ZarrDType& dtype);
+
+/// Parses a Zarr 3 data type string.
+///
+/// \error `absl::StatusCode::kInvalidArgument` if `dtype` is not valid.
+Result<ZarrDType::BaseDType> ParseBaseDType(std::string_view dtype);
+
+/// Chooses a zarr data type corresponding to `dtype`.
+Result<ZarrDType::BaseDType> ChooseBaseDType(DataType dtype);
+
+}  // namespace internal_zarr3
+}  // namespace tensorstore
+
+#endif  // TENSORSTORE_DRIVER_ZARR3_DTYPE_H_
diff --git a/tensorstore/driver/zarr3/dtype_test.cc b/tensorstore/driver/zarr3/dtype_test.cc
new file mode 100644
index 000000000..a41830069
--- /dev/null
+++ b/tensorstore/driver/zarr3/dtype_test.cc
@@ -0,0 +1,311 @@
+// Copyright 2025 The TensorStore Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
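+
+// Worked layout example mirroring the TwoNamedFields test below (derived
+// values are computed by ValidateDType):
+//
+//   dtype: [["x", "int8", [2, 3]], ["y", "int16", [5]]]
+//   x: field_shape={2,3}, num_inner_elements=6, byte_offset=0, num_bytes=6
+//   y: field_shape={5},   num_inner_elements=5, byte_offset=6, num_bytes=10
+//   bytes_per_outer_element = 6 + 10 = 16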
+ +#include "tensorstore/driver/zarr3/dtype.h" + +#include +#include + +#include +#include + +#include +#include +#include "absl/status/status.h" +#include +#include "tensorstore/data_type.h" +#include "tensorstore/index.h" +#include "tensorstore/internal/testing/json_gtest.h" +#include "tensorstore/util/status_testutil.h" +#include "tensorstore/util/str_cat.h" + +namespace { + +using ::tensorstore::DataType; +using ::tensorstore::dtype_v; +using ::tensorstore::Index; +using ::tensorstore::kInfIndex; +using ::tensorstore::StatusIs; +using ::tensorstore::internal_zarr3::ChooseBaseDType; +using ::tensorstore::internal_zarr3::ParseBaseDType; +using ::tensorstore::internal_zarr3::ParseDType; +using ::tensorstore::internal_zarr3::ZarrDType; +using ::testing::HasSubstr; +using ::testing::MatchesRegex; + +void CheckBaseDType(std::string dtype, DataType r, + std::vector flexible_shape) { + EXPECT_THAT(ParseBaseDType(dtype), ::testing::Optional(ZarrDType::BaseDType{ + dtype, r, flexible_shape})) + << dtype; +} + +TEST(ParseBaseDType, Success) { + CheckBaseDType("bool", dtype_v, {}); + CheckBaseDType("int8", dtype_v, {}); + CheckBaseDType("uint8", dtype_v, {}); + CheckBaseDType("int16", dtype_v, {}); + CheckBaseDType("uint16", dtype_v, {}); + CheckBaseDType("int32", dtype_v, {}); + CheckBaseDType("uint32", dtype_v, {}); + CheckBaseDType("int64", dtype_v, {}); + CheckBaseDType("uint64", dtype_v, {}); + CheckBaseDType("float16", dtype_v, {}); + CheckBaseDType("bfloat16", dtype_v, {}); + CheckBaseDType("float32", dtype_v, {}); + CheckBaseDType("float64", dtype_v, {}); + CheckBaseDType("complex64", dtype_v, {}); + CheckBaseDType("complex128", dtype_v, {}); + CheckBaseDType("r8", dtype_v, {1}); + CheckBaseDType("r16", dtype_v, {2}); + CheckBaseDType("r64", dtype_v, {8}); +} + +TEST(ParseBaseDType, Failure) { + EXPECT_THAT( + ParseBaseDType(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("data type is not one of the supported data types"))); + EXPECT_THAT(ParseBaseDType("float"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(ParseBaseDType("string"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(ParseBaseDType(""))); + EXPECT_THAT(ParseBaseDType("r7"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("data type is invalid; expected r"))); + EXPECT_THAT(ParseBaseDType("r0"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("data type is invalid; expected r"))); +} + +void CheckDType(const ::nlohmann::json& json, const ZarrDType& expected) { + SCOPED_TRACE(json.dump()); + TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto dtype, ParseDType(json)); + EXPECT_EQ(expected, dtype); + // Check round trip. 
+ EXPECT_EQ(json, ::nlohmann::json(dtype)); +} + +TEST(ParseDType, SimpleStringBool) { + CheckDType("bool", ZarrDType{ + /*.has_fields=*/false, + /*.fields=*/ + { + {{ + /*.encoded_dtype=*/"bool", + /*.dtype=*/dtype_v, + /*.flexible_shape=*/{}, + }, + /*.outer_shape=*/{}, + /*.name=*/"", + /*.field_shape=*/{}, + /*.num_inner_elements=*/1, + /*.byte_offset=*/0, + /*.num_bytes=*/1}, + }, + /*.bytes_per_outer_element=*/1, + }); +} + +TEST(ParseDType, SingleNamedFieldChar) { + // Zarr 3 doesn't support fixed size strings natively in core, so we use uint8 for testing bytes + CheckDType(::nlohmann::json::array_t{{"x", "uint8"}}, + ZarrDType{ + /*.has_fields=*/true, + /*.fields=*/ + { + {{ + /*.encoded_dtype=*/"uint8", + /*.dtype=*/dtype_v, + /*.flexible_shape=*/{}, + }, + /*.outer_shape=*/{}, + /*.name=*/"x", + /*.field_shape=*/{}, + /*.num_inner_elements=*/1, + /*.byte_offset=*/0, + /*.num_bytes=*/1}, + }, + /*.bytes_per_outer_element=*/1, + }); +} + +TEST(ParseDType, TwoNamedFields) { + CheckDType( + ::nlohmann::json::array_t{{"x", "int8", {2, 3}}, {"y", "int16", {5}}}, + ZarrDType{ + /*.has_fields=*/true, + /*.fields=*/ + { + {{ + /*.encoded_dtype=*/"int8", + /*.dtype=*/dtype_v, + /*.flexible_shape=*/{}, + }, + /*.outer_shape=*/{2, 3}, + /*.name=*/"x", + /*.field_shape=*/{2, 3}, + /*.num_inner_elements=*/2 * 3, + /*.byte_offset=*/0, + /*.num_bytes=*/1 * 2 * 3}, + {{ + /*.encoded_dtype=*/"int16", + /*.dtype=*/dtype_v, + /*.flexible_shape=*/{}, + }, + /*.outer_shape=*/{5}, + /*.name=*/"y", + /*.field_shape=*/{5}, + /*.num_inner_elements=*/5, + /*.byte_offset=*/1 * 2 * 3, + /*.num_bytes=*/2 * 5}, + }, + /*.bytes_per_outer_element=*/1 * 2 * 3 + 2 * 5, + }); +} + +TEST(ParseDType, FieldSpecTooShort) { + EXPECT_THAT( + ParseDType(::nlohmann::json::array_t{{"x"}}), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Error parsing value at position 0: " + "Expected array of size 2 or 3, but received: [\"x\"]"))); +} + +TEST(ParseDType, FieldSpecTooLong) { + EXPECT_THAT( + ParseDType(::nlohmann::json::array_t{{"x", "int16", {2, 3}, 5}}), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Error parsing value at position 0: " + "Expected array of size 2 or 3, but received: " + "[\"x\",\"int16\",[2,3],5]"))); +} + +TEST(ParseDType, InvalidFieldName) { + EXPECT_THAT( + ParseDType(::nlohmann::json::array_t{{3, "int16"}}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Error parsing value at position 0: " + "Error parsing value at position 0: " + "Expected non-empty string, but received: 3"))); +} + +TEST(ParseDType, EmptyFieldName) { + EXPECT_THAT( + ParseDType(::nlohmann::json::array_t{{"", "int16"}}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Error parsing value at position 0: " + "Error parsing value at position 0: " + "Expected non-empty string, but received: \"\""))); +} + +TEST(ParseDType, DuplicateFieldName) { + EXPECT_THAT( + ParseDType(::nlohmann::json::array_t{{"x", "int16"}, {"x", "uint16"}}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Field name \"x\" occurs more than once"))); +} + +TEST(ParseDType, NonStringFieldBaseDType) { + EXPECT_THAT(ParseDType(::nlohmann::json::array_t{{"x", 3}}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Error parsing value at position 0: " + "Error parsing value at position 1: " + "Expected string, but received: 3"))); +} + +TEST(ParseDType, InvalidFieldBaseDType) { + EXPECT_THAT(ParseDType(::nlohmann::json::array_t{{"x", "unknown"}}), + 
+              StatusIs(absl::StatusCode::kInvalidArgument,
+                       HasSubstr("Error parsing value at position 0: "
+                                 "Error parsing value at position 1: "
+                                 "unknown data type is not one of the "
+                                 "supported data types")));
+}
+
+TEST(ParseDType, ProductOfDimensionsOverflow) {
+  EXPECT_THAT(
+      ParseDType(
+          ::nlohmann::json::array_t{{"x", "int8", {kInfIndex, kInfIndex}}}),
+      StatusIs(absl::StatusCode::kInvalidArgument,
+               MatchesRegex(".*Product of dimensions .* is too large.*")));
+}
+
+TEST(ParseDType, FieldSizeInBytesOverflow) {
+  EXPECT_THAT(
+      ParseDType(::nlohmann::json::array_t{{"x", "float64", {kInfIndex}}}),
+      StatusIs(absl::StatusCode::kInvalidArgument,
+               HasSubstr("Field size in bytes is too large")));
+}
+
+TEST(ParseDType, BytesPerOuterElementOverflow) {
+  EXPECT_THAT(
+      ParseDType(::nlohmann::json::array_t{{"x", "int16", {kInfIndex}},
+                                           {"y", "int16", {kInfIndex}}}),
+      StatusIs(
+          absl::StatusCode::kInvalidArgument,
+          HasSubstr(
+              "Total number of bytes per outer array element is too large")));
+}
+
+TEST(ChooseBaseDTypeTest, RoundTrip) {
+  constexpr tensorstore::DataType kSupportedDataTypes[] = {
+      dtype_v<bool>,
+      dtype_v<uint8_t>,
+      dtype_v<uint16_t>,
+      dtype_v<uint32_t>,
+      dtype_v<uint64_t>,
+      dtype_v<int8_t>,
+      dtype_v<int16_t>,
+      dtype_v<int32_t>,
+      dtype_v<int64_t>,
+      dtype_v<::tensorstore::dtypes::bfloat16_t>,
+      dtype_v<::tensorstore::dtypes::float16_t>,
+      dtype_v<::tensorstore::dtypes::float32_t>,
+      dtype_v<::tensorstore::dtypes::float64_t>,
+      dtype_v<::tensorstore::dtypes::complex64_t>,
+      dtype_v<::tensorstore::dtypes::complex128_t>,
+      dtype_v<::tensorstore::dtypes::byte_t>,
+      dtype_v<::tensorstore::dtypes::char_t>,
+  };
+  for (auto dtype : kSupportedDataTypes) {
+    SCOPED_TRACE(tensorstore::StrCat("dtype=", dtype));
+    TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto base_zarr_dtype,
+                                     ChooseBaseDType(dtype));
+    // byte_t and char_t both encode as r8, which parses back to byte_t
+    DataType expected_dtype = dtype;
+    if (dtype == dtype_v<::tensorstore::dtypes::char_t>) {
+      expected_dtype = dtype_v<::tensorstore::dtypes::byte_t>;
+    }
+    EXPECT_EQ(expected_dtype, base_zarr_dtype.dtype);
+    TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+        auto parsed, ParseBaseDType(base_zarr_dtype.encoded_dtype));
+    EXPECT_EQ(expected_dtype, parsed.dtype);
+    EXPECT_EQ(base_zarr_dtype.flexible_shape, parsed.flexible_shape);
+    EXPECT_EQ(base_zarr_dtype.encoded_dtype, parsed.encoded_dtype);
+  }
+}
+
+TEST(ChooseBaseDTypeTest, Invalid) {
+  struct X {};
+  EXPECT_THAT(ChooseBaseDType(dtype_v<X>),
+              StatusIs(absl::StatusCode::kInvalidArgument,
+                       HasSubstr("Data type not supported")));
+  EXPECT_THAT(ChooseBaseDType(dtype_v<::tensorstore::dtypes::string_t>),
+              StatusIs(absl::StatusCode::kInvalidArgument,
+                       HasSubstr("Data type not supported: string")));
+}
+
+}  // namespace
diff --git a/tensorstore/driver/zarr3/metadata.cc b/tensorstore/driver/zarr3/metadata.cc
index 528d373ae..ba4454de4 100644
--- a/tensorstore/driver/zarr3/metadata.cc
+++ b/tensorstore/driver/zarr3/metadata.cc
@@ -31,7 +31,10 @@
 #include <string>
 #include <vector>
+#include <cstring>
+
 #include "absl/algorithm/container.h"
+#include "absl/strings/escaping.h"
 #include "absl/base/casts.h"
 #include "absl/base/optimization.h"
 #include "absl/meta/type_traits.h"
@@ -50,6 +53,7 @@
 #include "tensorstore/driver/zarr3/codec/codec_spec.h"
 #include "tensorstore/driver/zarr3/codec/sharding_indexed.h"
 #include "tensorstore/driver/zarr3/default_nan.h"
+#include "tensorstore/driver/zarr3/dtype.h"
 #include "tensorstore/driver/zarr3/name_configuration_json_binder.h"
 #include "tensorstore/index.h"
 #include "tensorstore/index_space/dimension_units.h"
@@ -246,30 +250,189 @@ constexpr std::array
       FillValueDataTypeFunctions::Make<::tensorstore::dtypes::T>(); \
   /**/
   TENSORSTORE_ZARR3_FOR_EACH_DATA_TYPE(TENSORSTORE_INTERNAL_DO_DEF)
+  // Add char_t support for string data types
+  functions[static_cast<size_t>(DataTypeId::char_t)] =
+      FillValueDataTypeFunctions::Make<::tensorstore::dtypes::char_t>();
+  // byte_t is handled specially to use uint8_t
functions #undef TENSORSTORE_INTERNAL_DO_DEF return functions; }(); } // namespace -absl::Status FillValueJsonBinder::operator()(std::true_type is_loading, - internal_json_binding::NoOptions, - SharedArray* obj, - ::nlohmann::json* j) const { +FillValueJsonBinder::FillValueJsonBinder(ZarrDType dtype, + bool allow_missing_dtype) + : dtype(std::move(dtype)), allow_missing_dtype(allow_missing_dtype) {} + +FillValueJsonBinder::FillValueJsonBinder(DataType data_type, + bool allow_missing_dtype) + : allow_missing_dtype(allow_missing_dtype) { + dtype.has_fields = false; + dtype.fields.resize(1); + auto& field = dtype.fields[0]; + field.name.clear(); + field.outer_shape.clear(); + field.flexible_shape.clear(); + field.field_shape.clear(); + field.num_inner_elements = 1; + field.byte_offset = 0; + field.num_bytes = data_type->size; + field.dtype = data_type; + field.encoded_dtype = std::string(data_type.name()); +} + +absl::Status FillValueJsonBinder::operator()( + std::true_type is_loading, internal_json_binding::NoOptions, + std::vector>* obj, ::nlohmann::json* j) const { + obj->resize(dtype.fields.size()); + if (dtype.fields.size() == 1) { + // Special case: raw_bytes (single field with byte_t and flexible shape) + if (dtype.fields[0].dtype.id() == DataTypeId::byte_t && + !dtype.fields[0].flexible_shape.empty()) { + // Handle base64-encoded fill value for raw_bytes + if (!j->is_string()) { + return absl::InvalidArgumentError( + "Expected base64-encoded string for raw_bytes fill_value"); + } + std::string b64_decoded; + if (!absl::Base64Unescape(j->get(), &b64_decoded)) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected valid base64-encoded fill value, but received: ", + j->dump())); + } + // Verify size matches expected byte array size + Index expected_size = dtype.fields[0].num_inner_elements; + if (static_cast(b64_decoded.size()) != expected_size) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected ", expected_size, + " base64-encoded bytes for fill_value, but received ", + b64_decoded.size(), " bytes")); + } + // Create fill value array + auto fill_arr = AllocateArray(dtype.fields[0].field_shape, c_order, + default_init, dtype.fields[0].dtype); + std::memcpy(fill_arr.data(), b64_decoded.data(), b64_decoded.size()); + (*obj)[0] = std::move(fill_arr); + } else { + TENSORSTORE_RETURN_IF_ERROR( + DecodeSingle(*j, dtype.fields[0].dtype, (*obj)[0])); + } + } else { + // For structured types, handle both array format and base64-encoded string + if (j->is_string()) { + // Decode base64-encoded fill value for entire struct + std::string b64_decoded; + if (!absl::Base64Unescape(j->get(), &b64_decoded)) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected valid base64-encoded fill value, but received: ", + j->dump())); + } + // Verify size matches expected struct size + if (static_cast(b64_decoded.size()) != + dtype.bytes_per_outer_element) { + return absl::InvalidArgumentError(tensorstore::StrCat( + "Expected ", dtype.bytes_per_outer_element, + " base64-encoded bytes for fill_value, but received ", + b64_decoded.size(), " bytes")); + } + // Extract per-field fill values from decoded bytes + for (size_t i = 0; i < dtype.fields.size(); ++i) { + const auto& field = dtype.fields[i]; + auto arr = AllocateArray(span{}, c_order, default_init, + field.dtype); + std::memcpy(arr.data(), b64_decoded.data() + field.byte_offset, + field.dtype->size); + (*obj)[i] = std::move(arr); + } + } else if (j->is_array()) { + if (j->size() != dtype.fields.size()) { + 
return internal_json::ExpectedError( + *j, tensorstore::StrCat("array of size ", dtype.fields.size())); + } + for (size_t i = 0; i < dtype.fields.size(); ++i) { + TENSORSTORE_RETURN_IF_ERROR( + DecodeSingle((*j)[i], dtype.fields[i].dtype, (*obj)[i])); + } + } else { + return internal_json::ExpectedError(*j, + "array or base64-encoded string"); + } + } + return absl::OkStatus(); +} + +absl::Status FillValueJsonBinder::operator()( + std::false_type is_loading, internal_json_binding::NoOptions, + const std::vector>* obj, + ::nlohmann::json* j) const { + if (dtype.fields.size() == 1) { + return EncodeSingle((*obj)[0], dtype.fields[0].dtype, *j); + } + // Structured fill value + *j = ::nlohmann::json::array(); + for (size_t i = 0; i < dtype.fields.size(); ++i) { + ::nlohmann::json item; + TENSORSTORE_RETURN_IF_ERROR( + EncodeSingle((*obj)[i], dtype.fields[i].dtype, item)); + j->push_back(std::move(item)); + } + return absl::OkStatus(); +} + +absl::Status FillValueJsonBinder::DecodeSingle(::nlohmann::json& j, + DataType data_type, + SharedArray& out) const { + if (!data_type.valid()) { + if (allow_missing_dtype) { + out = SharedArray(); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + "data_type must be specified before fill_value"); + } auto arr = AllocateArray(span{}, c_order, default_init, data_type); void* data = arr.data(); - *obj = std::move(arr); - return kFillValueDataTypeFunctions[static_cast(data_type.id())] - .decode(data, *j); + out = std::move(arr); + // Special handling for byte_t: use uint8_t functions since they're binary compatible + auto type_id = data_type.id(); + if (type_id == DataTypeId::byte_t) { + type_id = DataTypeId::uint8_t; + } + + const auto& functions = + kFillValueDataTypeFunctions[static_cast(type_id)]; + if (!functions.decode) { + if (allow_missing_dtype) { + out = SharedArray(); + return absl::OkStatus(); + } + return absl::FailedPreconditionError( + "fill_value unsupported for specified data_type"); + } + return functions.decode(data, j); } -absl::Status FillValueJsonBinder::operator()(std::false_type is_loading, - internal_json_binding::NoOptions, - const SharedArray* obj, - ::nlohmann::json* j) const { - return kFillValueDataTypeFunctions[static_cast(data_type.id())] - .encode(obj->data(), *j); +absl::Status FillValueJsonBinder::EncodeSingle( + const SharedArray& arr, DataType data_type, + ::nlohmann::json& j) const { + if (!data_type.valid()) { + return absl::InvalidArgumentError( + "data_type must be specified before fill_value"); + } + // Special handling for byte_t: use uint8_t functions since they're binary compatible + auto type_id = data_type.id(); + if (type_id == DataTypeId::byte_t) { + type_id = DataTypeId::uint8_t; + } + + const auto& functions = + kFillValueDataTypeFunctions[static_cast(type_id)]; + if (!functions.encode) { + return absl::FailedPreconditionError( + "fill_value unsupported for specified data_type"); + } + return functions.encode(arr.data(), j); } TENSORSTORE_DEFINE_JSON_DEFAULT_BINDER(ChunkKeyEncoding, [](auto is_loading, @@ -357,7 +520,7 @@ constexpr auto MetadataJsonBinder = [] { rank = &obj->rank; } - auto ensure_data_type = [&]() -> Result { + auto ensure_data_type = [&]() -> Result { if constexpr (std::is_same_v) { return obj->data_type; } @@ -378,19 +541,18 @@ constexpr auto MetadataJsonBinder = [] { maybe_optional_member("node_type", jb::Constant([] { return "array"; })), jb::Member("data_type", - jb::Projection<&Self::data_type>(maybe_optional(jb::Validate( - [](const auto& options, auto* obj) { - 
return ValidateDataType(*obj); - }, - jb::DataTypeJsonBinder)))), + jb::Projection<&Self::data_type>(maybe_optional( + jb::DefaultBinder<>))), jb::Member( "fill_value", jb::Projection<&Self::fill_value>(maybe_optional( [&](auto is_loading, const auto& options, auto* obj, auto* j) { TENSORSTORE_ASSIGN_OR_RETURN(auto data_type, ensure_data_type()); - return FillValueJsonBinder{data_type}(is_loading, options, - obj, j); + constexpr bool allow_missing_dtype = + std::is_same_v; + return FillValueJsonBinder{data_type, allow_missing_dtype}( + is_loading, options, obj, j); }))), non_compatibility_field( jb::Member("shape", jb::Projection<&Self::shape>( @@ -475,11 +637,35 @@ std::string ZarrMetadata::GetCompatibilityKey() const { } absl::Status ValidateMetadata(ZarrMetadata& metadata) { + // Determine if this is a structured type with multiple fields + const bool is_structured = + metadata.data_type.fields.size() > 1 || + (metadata.data_type.fields.size() == 1 && + !metadata.data_type.fields[0].outer_shape.empty()); + + // Build the codec shape - for structured types, include bytes dimension + std::vector codec_shape(metadata.chunk_shape.begin(), + metadata.chunk_shape.end()); + if (is_structured) { + codec_shape.push_back(metadata.data_type.bytes_per_outer_element); + } + if (!metadata.codecs) { ArrayCodecResolveParameters decoded; - decoded.dtype = metadata.data_type; - decoded.rank = metadata.rank; - decoded.fill_value = metadata.fill_value; + if (!is_structured) { + decoded.dtype = metadata.data_type.fields[0].dtype; + decoded.rank = metadata.rank; + } else { + // For structured types, use byte dtype with extra dimension + decoded.dtype = dtype_v; + decoded.rank = metadata.rank + 1; + } + // Fill value for codec resolve might be complex. + // For structured types, create a byte fill value + if (metadata.fill_value.size() == 1 && !is_structured) { + decoded.fill_value = metadata.fill_value[0]; + } + BytesCodecResolveParameters encoded; TENSORSTORE_ASSIGN_OR_RETURN( metadata.codecs, @@ -488,10 +674,19 @@ absl::Status ValidateMetadata(ZarrMetadata& metadata) { // Get codec chunk layout info. 
ArrayDataTypeAndShapeInfo array_info; - array_info.dtype = metadata.data_type; - array_info.rank = metadata.rank; - std::copy_n(metadata.chunk_shape.begin(), metadata.rank, - array_info.shape.emplace().begin()); + if (!is_structured) { + array_info.dtype = metadata.data_type.fields[0].dtype; + array_info.rank = metadata.rank; + std::copy_n(metadata.chunk_shape.begin(), metadata.rank, + array_info.shape.emplace().begin()); + } else { + array_info.dtype = dtype_v; + array_info.rank = metadata.rank + 1; + auto& shape = array_info.shape.emplace(); + std::copy_n(metadata.chunk_shape.begin(), metadata.rank, shape.begin()); + shape[metadata.rank] = metadata.data_type.bytes_per_outer_element; + } + ArrayCodecChunkLayoutInfo layout_info; TENSORSTORE_RETURN_IF_ERROR( metadata.codec_specs.GetDecodedChunkLayout(array_info, layout_info)); @@ -505,24 +700,41 @@ absl::Status ValidateMetadata(ZarrMetadata& metadata) { } TENSORSTORE_ASSIGN_OR_RETURN(metadata.codec_state, - metadata.codecs->Prepare(metadata.chunk_shape)); + metadata.codecs->Prepare(codec_shape)); return absl::OkStatus(); } absl::Status ValidateMetadata(const ZarrMetadata& metadata, const ZarrMetadataConstraints& constraints) { using internal::MetadataMismatchError; - if (constraints.data_type && *constraints.data_type != metadata.data_type) { - return MetadataMismatchError("data_type", constraints.data_type->name(), - metadata.data_type.name()); - } - if (constraints.fill_value && - !AreArraysIdenticallyEqual(*constraints.fill_value, - metadata.fill_value)) { - auto binder = FillValueJsonBinder{metadata.data_type}; - auto constraint_json = jb::ToJson(*constraints.fill_value, binder).value(); - auto metadata_json = jb::ToJson(metadata.fill_value, binder).value(); - return MetadataMismatchError("fill_value", constraint_json, metadata_json); + if (constraints.data_type) { + // Compare ZarrDType + if (::nlohmann::json(*constraints.data_type) != + ::nlohmann::json(metadata.data_type)) { + return MetadataMismatchError( + "data_type", ::nlohmann::json(*constraints.data_type).dump(), + ::nlohmann::json(metadata.data_type).dump()); + } + } + if (constraints.fill_value) { + // Compare vector of arrays + if (constraints.fill_value->size() != metadata.fill_value.size()) { + return MetadataMismatchError("fill_value size", + constraints.fill_value->size(), + metadata.fill_value.size()); + } + for (size_t i = 0; i < metadata.fill_value.size(); ++i) { + if (!AreArraysIdenticallyEqual((*constraints.fill_value)[i], + metadata.fill_value[i])) { + auto binder = FillValueJsonBinder{metadata.data_type}; + auto constraint_json = + jb::ToJson(*constraints.fill_value, binder).value(); + auto metadata_json = + jb::ToJson(metadata.fill_value, binder).value(); + return MetadataMismatchError("fill_value", constraint_json, + metadata_json); + } + } } if (constraints.shape && *constraints.shape != metadata.shape) { return MetadataMismatchError("shape", *constraints.shape, metadata.shape); @@ -574,24 +786,97 @@ absl::Status ValidateMetadata(const ZarrMetadata& metadata, metadata.unknown_extension_attributes); } +namespace { +std::string GetFieldNames(const ZarrDType& dtype) { + std::vector field_names; + for (const auto& field : dtype.fields) { + field_names.push_back(field.name); + } + return ::nlohmann::json(field_names).dump(); +} +} // namespace + +constexpr size_t kVoidFieldIndex = size_t(-1); + +Result GetFieldIndex(const ZarrDType& dtype, + std::string_view selected_field, + bool open_as_void) { + // Special case: open_as_void requests raw byte access (works for 
any dtype) + + if (open_as_void) { + if (dtype.fields.empty()) { + return absl::FailedPreconditionError( + "Requested void access but dtype has no fields"); + } + return kVoidFieldIndex; + } + + if (selected_field.empty()) { + if (dtype.fields.size() != 1) { + return absl::FailedPreconditionError(tensorstore::StrCat( + "Must specify a \"field\" that is one of: ", GetFieldNames(dtype))); + } + return 0; + } + if (!dtype.has_fields) { + return absl::FailedPreconditionError( + tensorstore::StrCat("Requested field ", QuoteString(selected_field), + " but dtype does not have named fields")); + } + for (size_t field_index = 0; field_index < dtype.fields.size(); + ++field_index) { + if (dtype.fields[field_index].name == selected_field) return field_index; + } + return absl::FailedPreconditionError( + tensorstore::StrCat("Requested field ", QuoteString(selected_field), + " is not one of: ", GetFieldNames(dtype))); +} + +SpecRankAndFieldInfo GetSpecRankAndFieldInfo(const ZarrMetadata& metadata, + size_t field_index) { + SpecRankAndFieldInfo info; + info.chunked_rank = metadata.rank; + info.field = &metadata.data_type.fields[field_index]; + if (!info.field->field_shape.empty()) { + info.chunked_rank += info.field->field_shape.size(); + } + return info; +} + Result> GetEffectiveDomain( - DimensionIndex rank, std::optional> shape, + const SpecRankAndFieldInfo& info, + std::optional> metadata_shape, std::optional>> dimension_names, - const Schema& schema, bool* dimension_names_used = nullptr) { + const Schema& schema, bool* dimension_names_used) { + const DimensionIndex rank = info.chunked_rank; if (dimension_names_used) *dimension_names_used = false; auto domain = schema.domain(); - if (!shape && !dimension_names && !domain.valid()) { + if (!metadata_shape && !dimension_names && !domain.valid()) { if (schema.rank() == 0) return {std::in_place, 0}; - // No information about the domain available. return {std::in_place}; } - // Rank is already validated by caller. assert(RankConstraint::EqualOrUnspecified(schema.rank(), rank)); IndexDomainBuilder builder(std::max(schema.rank().rank, rank)); - if (shape) { - builder.shape(*shape); - builder.implicit_upper_bounds(true); + if (metadata_shape) { + if (static_cast(metadata_shape->size()) < rank && + info.field && !info.field->field_shape.empty() && + static_cast(metadata_shape->size() + + info.field->field_shape.size()) == rank) { + std::vector full_shape(metadata_shape->begin(), + metadata_shape->end()); + full_shape.insert(full_shape.end(), info.field->field_shape.begin(), + info.field->field_shape.end()); + builder.shape(full_shape); + DimensionSet implicit_upper_bounds(false); + for (size_t i = 0; i < metadata_shape->size(); ++i) { + implicit_upper_bounds[i] = true; + } + builder.implicit_upper_bounds(implicit_upper_bounds); + } else { + builder.shape(*metadata_shape); + builder.implicit_upper_bounds(true); + } } else { builder.origin(GetConstantVector(builder.rank())); } @@ -602,12 +887,12 @@ Result> GetEffectiveDomain( normalized_dimension_names[i] = *name; } } - // Use dimension_names as labels if they are valid. 
- if (internal::ValidateDimensionLabelsAreUnique(normalized_dimension_names) + if (internal::ValidateDimensionLabelsAreUnique( + span(&normalized_dimension_names[0], rank)) .ok()) { - if (dimension_names_used) *dimension_names_used = true; builder.labels( span(&normalized_dimension_names[0], rank)); + if (dimension_names_used) *dimension_names_used = true; } } @@ -618,36 +903,53 @@ Result> GetEffectiveDomain( tensorstore::MaybeAnnotateStatus( _, "Mismatch between metadata and schema"))); return WithImplicitDimensions(domain, false, true); - return domain; } Result> GetEffectiveDomain( const ZarrMetadataConstraints& metadata_constraints, const Schema& schema, bool* dimension_names_used) { - return GetEffectiveDomain( - metadata_constraints.rank, metadata_constraints.shape, - metadata_constraints.dimension_names, schema, dimension_names_used); + SpecRankAndFieldInfo info; + info.chunked_rank = metadata_constraints.rank; + if (info.chunked_rank == dynamic_rank && metadata_constraints.shape) { + info.chunked_rank = metadata_constraints.shape->size(); + } + + std::optional> shape_span; + if (metadata_constraints.shape) { + shape_span.emplace(metadata_constraints.shape->data(), + metadata_constraints.shape->size()); + } + std::optional>> names_span; + if (metadata_constraints.dimension_names) { + names_span.emplace(metadata_constraints.dimension_names->data(), + metadata_constraints.dimension_names->size()); + } + + return GetEffectiveDomain(info, shape_span, names_span, schema, + dimension_names_used); } absl::Status SetChunkLayoutFromMetadata( - DataType dtype, DimensionIndex rank, + const SpecRankAndFieldInfo& info, std::optional> chunk_shape, const ZarrCodecChainSpec* codecs, ChunkLayout& chunk_layout) { - TENSORSTORE_RETURN_IF_ERROR(chunk_layout.Set(RankConstraint{rank})); - rank = chunk_layout.rank(); - if (rank == dynamic_rank) return absl::OkStatus(); + const DimensionIndex rank = info.chunked_rank; + if (rank == dynamic_rank) { + return absl::OkStatus(); + } + TENSORSTORE_RETURN_IF_ERROR(chunk_layout.Set(RankConstraint(rank))); + TENSORSTORE_RETURN_IF_ERROR(chunk_layout.Set( + ChunkLayout::GridOrigin(GetConstantVector(rank)))); if (chunk_shape) { assert(chunk_shape->size() == rank); TENSORSTORE_RETURN_IF_ERROR( chunk_layout.Set(ChunkLayout::WriteChunkShape(*chunk_shape))); } - TENSORSTORE_RETURN_IF_ERROR(chunk_layout.Set( - ChunkLayout::GridOrigin(GetConstantVector(rank)))); if (codecs) { ArrayDataTypeAndShapeInfo array_info; - array_info.dtype = dtype; + array_info.dtype = info.field ? 
info.field->dtype : dtype_v; array_info.rank = rank; if (chunk_shape) { std::copy_n(chunk_shape->begin(), rank, @@ -669,30 +971,47 @@ absl::Status SetChunkLayoutFromMetadata( span(layout_info.codec_chunk_shape->data(), rank)))); } } + return absl::OkStatus(); } -Result GetEffectiveChunkLayout( +absl::Status SetChunkLayoutFromMetadata( DataType dtype, DimensionIndex rank, std::optional> chunk_shape, - const ZarrCodecChainSpec* codecs, const Schema& schema) { - auto chunk_layout = schema.chunk_layout(); - TENSORSTORE_RETURN_IF_ERROR(SetChunkLayoutFromMetadata( - dtype, rank, chunk_shape, codecs, chunk_layout)); - return chunk_layout; + const ZarrCodecChainSpec* codecs, ChunkLayout& chunk_layout) { + SpecRankAndFieldInfo info; + info.chunked_rank = rank; + info.field = nullptr; + return SetChunkLayoutFromMetadata(info, chunk_shape, codecs, chunk_layout); } Result GetEffectiveChunkLayout( const ZarrMetadataConstraints& metadata_constraints, const Schema& schema) { - assert(RankConstraint::EqualOrUnspecified(metadata_constraints.rank, - schema.rank())); - return GetEffectiveChunkLayout( - metadata_constraints.data_type.value_or(DataType{}), - std::max(metadata_constraints.rank, schema.rank().rank), - metadata_constraints.chunk_shape, + // Approximation: assume whole array access or simple array + SpecRankAndFieldInfo info; + info.chunked_rank = std::max(metadata_constraints.rank, schema.rank().rank); + if (info.chunked_rank == dynamic_rank && metadata_constraints.shape) { + info.chunked_rank = metadata_constraints.shape->size(); + } + if (info.chunked_rank == dynamic_rank && metadata_constraints.chunk_shape) { + info.chunked_rank = metadata_constraints.chunk_shape->size(); + } + // We can't easily know field info from constraints unless we parse data_type. + // If data_type is present and has 1 field, we can check it. + // For now, basic implementation. + + ChunkLayout chunk_layout = schema.chunk_layout(); + std::optional> chunk_shape_span; + if (metadata_constraints.chunk_shape) { + chunk_shape_span.emplace(metadata_constraints.chunk_shape->data(), + metadata_constraints.chunk_shape->size()); + } + TENSORSTORE_RETURN_IF_ERROR(SetChunkLayoutFromMetadata( + info, chunk_shape_span, metadata_constraints.codec_specs ? 
&*metadata_constraints.codec_specs : nullptr, - schema); + chunk_layout)); + return chunk_layout; } Result GetDimensionUnits( @@ -732,53 +1051,63 @@ CodecSpec GetCodecFromMetadata(const ZarrMetadata& metadata) { } absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata, - const Schema& schema) { - if (!RankConstraint::EqualOrUnspecified(metadata.rank, schema.rank())) { + size_t field_index, const Schema& schema) { + auto info = GetSpecRankAndFieldInfo(metadata, field_index); + const auto& field = metadata.data_type.fields[field_index]; + + if (!RankConstraint::EqualOrUnspecified(schema.rank(), info.chunked_rank)) { return absl::FailedPreconditionError(tensorstore::StrCat( "Rank specified by schema (", schema.rank(), - ") does not match rank specified by metadata (", metadata.rank, ")")); + ") does not match rank specified by metadata (", info.chunked_rank, + ")")); } if (schema.domain().valid()) { + std::optional> metadata_shape_span; + metadata_shape_span.emplace(metadata.shape.data(), metadata.shape.size()); + std::optional>> dimension_names_span; + dimension_names_span.emplace(metadata.dimension_names.data(), + metadata.dimension_names.size()); TENSORSTORE_RETURN_IF_ERROR(GetEffectiveDomain( - metadata.rank, metadata.shape, metadata.dimension_names, schema)); + info, metadata_shape_span, dimension_names_span, schema, + /*dimension_names_used=*/nullptr)); } if (auto dtype = schema.dtype(); - !IsPossiblySameDataType(metadata.data_type, dtype)) { + !IsPossiblySameDataType(field.dtype, dtype)) { return absl::FailedPreconditionError( - tensorstore::StrCat("data_type from metadata (", metadata.data_type, + tensorstore::StrCat("data_type from metadata (", field.dtype, ") does not match dtype in schema (", dtype, ")")); } if (schema.chunk_layout().rank() != dynamic_rank) { - TENSORSTORE_ASSIGN_OR_RETURN( - auto chunk_layout, - GetEffectiveChunkLayout(metadata.data_type, metadata.rank, - metadata.chunk_shape, &metadata.codec_specs, - schema)); + ChunkLayout chunk_layout = schema.chunk_layout(); + std::optional> chunk_shape_span; + chunk_shape_span.emplace(metadata.chunk_shape.data(), + metadata.chunk_shape.size()); + TENSORSTORE_RETURN_IF_ERROR(SetChunkLayoutFromMetadata( + info, chunk_shape_span, &metadata.codec_specs, chunk_layout)); if (chunk_layout.codec_chunk_shape().hard_constraint) { return absl::InvalidArgumentError("codec_chunk_shape not supported"); } } if (auto schema_fill_value = schema.fill_value(); schema_fill_value.valid()) { - const auto& fill_value = metadata.fill_value; + const auto& fill_value = metadata.fill_value[field_index]; TENSORSTORE_ASSIGN_OR_RETURN( auto broadcast_fill_value, tensorstore::BroadcastArray(schema_fill_value, span{})); TENSORSTORE_ASSIGN_OR_RETURN( SharedArray converted_fill_value, tensorstore::MakeCopy(std::move(broadcast_fill_value), - skip_repeated_elements, metadata.data_type)); + skip_repeated_elements, field.dtype)); if (!AreArraysIdenticallyEqual(converted_fill_value, fill_value)) { auto binder = FillValueJsonBinder{metadata.data_type}; - auto schema_json = jb::ToJson(converted_fill_value, binder).value(); - auto metadata_json = jb::ToJson(metadata.fill_value, binder).value(); + // Error message generation might be tricky with binder return absl::FailedPreconditionError(tensorstore::StrCat( "Invalid fill_value: schema requires fill value of ", - schema_json.dump(), ", but metadata specifies fill value of ", - metadata_json.dump())); + schema_fill_value, ", but metadata specifies fill value of ", + fill_value)); } } @@ -804,8 +1133,14 @@ 
absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
   return absl::OkStatus();
 }
 
+absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
+                                    const Schema& schema) {
+  return ValidateMetadataSchema(metadata, /*field_index=*/0, schema);
+}
+
 Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
-    const ZarrMetadataConstraints& metadata_constraints, const Schema& schema) {
+    const ZarrMetadataConstraints& metadata_constraints, const Schema& schema,
+    std::string_view selected_field, bool open_as_void) {
   auto metadata = std::make_shared<ZarrMetadata>();
 
   metadata->zarr_format = metadata_constraints.zarr_format.value_or(3);
@@ -813,51 +1148,85 @@ Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
       metadata_constraints.chunk_key_encoding.value_or(ChunkKeyEncoding{
           /*.kind=*/ChunkKeyEncoding::kDefault, /*.separator=*/'/'});
 
+  // Determine the data type first.
+  if (metadata_constraints.data_type) {
+    metadata->data_type = *metadata_constraints.data_type;
+  } else if (!selected_field.empty()) {
+    return absl::InvalidArgumentError(
+        "\"dtype\" must be specified in \"metadata\" if \"field\" is "
+        "specified");
+  } else if (auto dtype = schema.dtype(); dtype.valid()) {
+    TENSORSTORE_ASSIGN_OR_RETURN(
+        static_cast<ZarrDType::BaseDType&>(
+            metadata->data_type.fields.emplace_back()),
+        ChooseBaseDType(dtype));
+    metadata->data_type.has_fields = false;
+    TENSORSTORE_RETURN_IF_ERROR(ValidateDType(metadata->data_type));
+  } else {
+    return absl::InvalidArgumentError("dtype must be specified");
+  }
+
+  TENSORSTORE_ASSIGN_OR_RETURN(
+      size_t field_index,
+      GetFieldIndex(metadata->data_type, selected_field, open_as_void));
+  // For void access, `GetFieldIndex` returns the `kVoidFieldIndex` sentinel;
+  // map it to field 0 (as in `GetComponentIndex`) before indexing `fields`.
+  if (field_index == kVoidFieldIndex) {
+    field_index = 0;
+  }
+  SpecRankAndFieldInfo info;
+  info.field = &metadata->data_type.fields[field_index];
+  info.chunked_rank = metadata_constraints.rank;
+  if (info.chunked_rank == dynamic_rank && metadata_constraints.shape) {
+    info.chunked_rank = metadata_constraints.shape->size();
+  }
+  if (info.chunked_rank == dynamic_rank &&
+      schema.rank().rank != dynamic_rank) {
+    info.chunked_rank = schema.rank().rank;
+  }
+
   // Set domain
-  bool dimension_names_used;
+  bool dimension_names_used = false;
+  std::optional<span<const Index>> constraint_shape_span;
+  if (metadata_constraints.shape) {
+    constraint_shape_span.emplace(metadata_constraints.shape->data(),
+                                  metadata_constraints.shape->size());
+  }
+  std::optional<span<const std::optional<std::string>>> constraint_names_span;
+  if (metadata_constraints.dimension_names) {
+    constraint_names_span.emplace(
+        metadata_constraints.dimension_names->data(),
+        metadata_constraints.dimension_names->size());
+  }
   TENSORSTORE_ASSIGN_OR_RETURN(
-      auto domain,
-      GetEffectiveDomain(metadata_constraints, schema, &dimension_names_used));
+      auto domain, GetEffectiveDomain(info, constraint_shape_span,
+                                      constraint_names_span, schema,
+                                      &dimension_names_used));
   if (!domain.valid() || !IsFinite(domain.box())) {
     return absl::InvalidArgumentError("domain must be specified");
   }
-  const DimensionIndex rank = metadata->rank = domain.rank();
-  metadata->shape.assign(domain.shape().begin(), domain.shape().end());
+  const DimensionIndex rank = domain.rank();
+  metadata->rank = rank;
+  info.chunked_rank = rank;
+  metadata->shape.assign(domain.shape().begin(),
+                         domain.shape().begin() + rank);
   metadata->dimension_names.assign(domain.labels().begin(),
-                                   domain.labels().end());
-  // Normalize empty string dimension names to `std::nullopt`.  This is more
-  // consistent with the zarr v3 dimension name semantics, and ensures that the
-  // `dimension_names` metadata field will be excluded entirely if all dimension
-  // names are the empty string.
- // - // However, if empty string dimension names were specified explicitly in - // `metadata_constraints`, leave them exactly as specified. + domain.labels().begin() + rank); + for (DimensionIndex i = 0; i < rank; ++i) { auto& name = metadata->dimension_names[i]; if (!name || !name->empty()) continue; - // Dimension name equals the empty string. - if (dimension_names_used && (*metadata_constraints.dimension_names)[i]) { - // Empty dimension name was explicitly specified in - // `metadata_constraints`, leave it as is. + if (dimension_names_used && metadata_constraints.dimension_names && + (*metadata_constraints.dimension_names)[i]) { assert((*metadata_constraints.dimension_names)[i]->empty()); continue; } - // Name was not explicitly specified in `metadata_constraints` as an empty - // string. Normalize it to `std::nullopt`. name = std::nullopt; } - // Set dtype - auto dtype = schema.dtype(); - if (!dtype.valid()) { - return absl::InvalidArgumentError("dtype must be specified"); - } - TENSORSTORE_RETURN_IF_ERROR(ValidateDataType(dtype)); - metadata->data_type = dtype; - if (metadata_constraints.fill_value) { metadata->fill_value = *metadata_constraints.fill_value; } else if (auto fill_value = schema.fill_value(); fill_value.valid()) { + // Assuming single field if setting from schema + if (metadata->data_type.fields.size() != 1) { + return absl::InvalidArgumentError( + "Cannot specify fill_value through schema for structured zarr data " + "type"); + } const auto status = [&] { TENSORSTORE_ASSIGN_OR_RETURN( auto broadcast_fill_value, @@ -865,23 +1234,26 @@ Result> GetNewMetadata( TENSORSTORE_ASSIGN_OR_RETURN( auto converted_fill_value, tensorstore::MakeCopy(std::move(broadcast_fill_value), - skip_repeated_elements, metadata->data_type)); - metadata->fill_value = std::move(converted_fill_value); + skip_repeated_elements, + metadata->data_type.fields[0].dtype)); + metadata->fill_value.push_back(std::move(converted_fill_value)); return absl::OkStatus(); }(); TENSORSTORE_RETURN_IF_ERROR( status, tensorstore::MaybeAnnotateStatus(_, "Invalid fill_value")); } else { - metadata->fill_value = tensorstore::AllocateArray( - /*shape=*/span(), c_order, value_init, - metadata->data_type); + metadata->fill_value.resize(metadata->data_type.fields.size()); + for (size_t i = 0; i < metadata->fill_value.size(); ++i) { + metadata->fill_value[i] = tensorstore::AllocateArray( + /*shape=*/span(), c_order, value_init, + metadata->data_type.fields[i].dtype); + } } metadata->user_attributes = metadata_constraints.user_attributes; metadata->unknown_extension_attributes = metadata_constraints.unknown_extension_attributes; - // Set dimension units TENSORSTORE_ASSIGN_OR_RETURN( auto dimension_units, GetEffectiveDimensionUnits(rank, metadata_constraints.dimension_units, @@ -895,12 +1267,16 @@ Result> GetNewMetadata( TENSORSTORE_ASSIGN_OR_RETURN(auto codec_spec, GetEffectiveCodec(metadata_constraints, schema)); - // Set chunk shape - ArrayCodecResolveParameters decoded; - decoded.dtype = metadata->data_type; + if (metadata->data_type.fields.size() == 1 && + metadata->data_type.fields[0].outer_shape.empty()) { + decoded.dtype = metadata->data_type.fields[0].dtype; + } else { + decoded.dtype = dtype_v; + } decoded.rank = metadata->rank; - decoded.fill_value = metadata->fill_value; + if (metadata->fill_value.size() == 1) + decoded.fill_value = metadata->fill_value[0]; TENSORSTORE_ASSIGN_OR_RETURN( auto chunk_layout, GetEffectiveChunkLayout(metadata_constraints, schema)); @@ -920,8 +1296,6 @@ Result> GetNewMetadata( if 
diff --git a/tensorstore/driver/zarr3/metadata.h b/tensorstore/driver/zarr3/metadata.h
index 05b8c6be3..857210546 100644
--- a/tensorstore/driver/zarr3/metadata.h
+++ b/tensorstore/driver/zarr3/metadata.h
@@ -33,6 +33,7 @@
 #include "tensorstore/data_type.h"
 #include "tensorstore/driver/zarr3/codec/codec.h"
 #include "tensorstore/driver/zarr3/codec/codec_chain_spec.h"
+#include "tensorstore/driver/zarr3/dtype.h"
 #include "tensorstore/index.h"
 #include "tensorstore/index_space/dimension_units.h"
 #include "tensorstore/index_space/index_domain.h"
@@ -72,19 +73,35 @@ struct ChunkKeyEncoding {
 };
 
 struct FillValueJsonBinder {
-  DataType data_type;
+  ZarrDType dtype;
+  bool allow_missing_dtype = false;
+
+  FillValueJsonBinder() = default;
+  explicit FillValueJsonBinder(ZarrDType dtype,
+                               bool allow_missing_dtype = false);
+  explicit FillValueJsonBinder(DataType dtype,
+                               bool allow_missing_dtype = false);
+
   absl::Status operator()(std::true_type is_loading,
                           internal_json_binding::NoOptions,
-                          SharedArray<const void>* obj,
+                          std::vector<SharedArray<const void>>* obj,
                           ::nlohmann::json* j) const;
   absl::Status operator()(std::false_type is_loading,
                           internal_json_binding::NoOptions,
-                          const SharedArray<const void>* obj,
+                          const std::vector<SharedArray<const void>>* obj,
                           ::nlohmann::json* j) const;
+
+ private:
+  absl::Status DecodeSingle(::nlohmann::json& j, DataType data_type,
+                            SharedArray<const void>& out) const;
+  absl::Status EncodeSingle(const SharedArray<const void>& arr,
+                            DataType data_type,
+                            ::nlohmann::json& j) const;
 };
 
+struct SpecRankAndFieldInfo;
+
 struct ZarrMetadata {
   // The following members are common to `ZarrMetadata` and
   // `ZarrMetadataConstraints`, except that in `ZarrMetadataConstraints` some
@@ -94,14 +111,14 @@ struct ZarrMetadata {
   int zarr_format;
   std::vector<Index> shape;
-  DataType data_type;
+  ZarrDType data_type;
   ::nlohmann::json::object_t user_attributes;
   std::optional<DimensionUnitsVector> dimension_units;
   std::vector<std::optional<std::string>> dimension_names;
   ChunkKeyEncoding chunk_key_encoding;
   std::vector<Index> chunk_shape;
   ZarrCodecChainSpec codec_specs;
-  SharedArray<const void> fill_value;
+  std::vector<SharedArray<const void>> fill_value;
   ::nlohmann::json::object_t unknown_extension_attributes;
 
   std::string GetCompatibilityKey() const;
@@ -123,14 +140,14 @@ struct ZarrMetadataConstraints {
   std::optional<int> zarr_format;
   std::optional<std::vector<Index>> shape;
-  std::optional<DataType> data_type;
+  std::optional<ZarrDType> data_type;
   ::nlohmann::json::object_t user_attributes;
   std::optional<DimensionUnitsVector> dimension_units;
   std::optional<std::vector<std::optional<std::string>>> dimension_names;
   std::optional<ChunkKeyEncoding> chunk_key_encoding;
   std::optional<std::vector<Index>> chunk_shape;
   std::optional<ZarrCodecChainSpec> codec_specs;
-  std::optional<SharedArray<const void>> fill_value;
+  std::optional<std::vector<SharedArray<const void>>> fill_value;
   ::nlohmann::json::object_t unknown_extension_attributes;
 
   TENSORSTORE_DECLARE_JSON_DEFAULT_BINDER(ZarrMetadataConstraints,
@@ -159,6 +176,10 @@ Result<IndexDomain<>> GetEffectiveDomain(
 
 /// Sets chunk layout constraints implied by `dtype`, `rank`, `chunk_shape`, and
 /// `codecs`.
+absl::Status SetChunkLayoutFromMetadata(
+    const SpecRankAndFieldInfo& info,
+    std::optional<span<const Index>> chunk_shape,
+    const ZarrCodecChainSpec* codecs, ChunkLayout& chunk_layout);
 absl::Status SetChunkLayoutFromMetadata(
     DataType dtype, DimensionIndex rank,
     std::optional<span<const Index>> chunk_shape,
@@ -198,6 +219,8 @@ Result<internal::IntrusivePtr<TensorStoreCodecSpec>> GetEffectiveCodec(
 CodecSpec GetCodecFromMetadata(const ZarrMetadata& metadata);
 
 /// Validates that `schema` is compatible with `metadata`.
+absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
+                                    size_t field_index, const Schema& schema);
 absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
                                     const Schema& schema);
 
@@ -206,10 +229,24 @@ absl::Status ValidateMetadataSchema(const ZarrMetadata& metadata,
 /// \error `absl::StatusCode::kInvalidArgument` if any required fields are
 ///     unspecified.
 Result<std::shared_ptr<const ZarrMetadata>> GetNewMetadata(
-    const ZarrMetadataConstraints& metadata_constraints, const Schema& schema);
+    const ZarrMetadataConstraints& metadata_constraints,
+    const Schema& schema, std::string_view selected_field = {},
+    bool open_as_void = false);
 
 absl::Status ValidateDataType(DataType dtype);
 
+Result<size_t> GetFieldIndex(const ZarrDType& dtype,
+                             std::string_view selected_field,
+                             bool open_as_void = false);
+
+struct SpecRankAndFieldInfo {
+  DimensionIndex chunked_rank = dynamic_rank;
+  const ZarrDType::Field* field = nullptr;
+};
+
+SpecRankAndFieldInfo GetSpecRankAndFieldInfo(const ZarrMetadata& metadata,
+                                             size_t field_index);
+
 }  // namespace internal_zarr3
 }  // namespace tensorstore
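Aside: `SpecRankAndFieldInfo` pairs the chunked (outer) rank with the selected field, and a field with a non-empty `field_shape` contributes trailing dimensions to the user-visible domain. A sketch of that rank arithmetic under those assumptions; the helper itself is hypothetical, not part of this change:

```cpp
// Illustrative helper: full user-visible rank = chunked rank plus any
// trailing dimensions contributed by the selected field.
DimensionIndex GetFullRank(const SpecRankAndFieldInfo& info) {
  if (info.chunked_rank == dynamic_rank) return dynamic_rank;
  DimensionIndex full_rank = info.chunked_rank;
  if (info.field) {
    full_rank += static_cast<DimensionIndex>(info.field->field_shape.size());
  }
  return full_rank;
}
```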
"d"}})); @@ -115,7 +134,8 @@ TEST(MetadataTest, ParseValidNoDimensionNames) { TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto metadata, ZarrMetadata::FromJson(json)); EXPECT_THAT(metadata.shape, ::testing::ElementsAre(10, 11, 12)); EXPECT_THAT(metadata.chunk_shape, ::testing::ElementsAre(1, 2, 3)); - EXPECT_THAT(metadata.data_type, tensorstore::dtype_v); + ASSERT_EQ(metadata.data_type.fields.size(), 1); + EXPECT_EQ(tensorstore::dtype_v, metadata.data_type.fields[0].dtype); EXPECT_THAT(metadata.dimension_names, ::testing::ElementsAre(std::nullopt, std::nullopt, std::nullopt)); EXPECT_THAT(metadata.user_attributes, MatchesJson({{"a", "b"}, {"c", "d"}})); @@ -418,7 +438,7 @@ Result> TestGetNewMetadata( TENSORSTORE_RETURN_IF_ERROR(status); TENSORSTORE_ASSIGN_OR_RETURN( auto constraints, ZarrMetadataConstraints::FromJson(constraints_json)); - return GetNewMetadata(constraints, schema); + return GetNewMetadata(constraints, schema, /*selected_field=*/{}, /*open_as_void=*/false); } TEST(GetNewMetadataTest, DuplicateDimensionNames) { @@ -486,7 +506,9 @@ TEST(MetadataTest, DataTypes) { } TENSORSTORE_ASSERT_OK_AND_ASSIGN(auto metadata, ZarrMetadata::FromJson(json)); - EXPECT_EQ(tensorstore::GetDataType(data_type_name), metadata.data_type); + ASSERT_FALSE(metadata.data_type.fields.empty()); + EXPECT_EQ(tensorstore::GetDataType(data_type_name), + metadata.data_type.fields[0].dtype); } } @@ -503,18 +525,20 @@ TEST(MetadataTest, InvalidDataType) { template void TestFillValue(std::vector> cases, bool skip_to_json = false) { - auto binder = FillValueJsonBinder{dtype_v}; + FillValueJsonBinder binder(MakeScalarZarrDType(dtype_v)); for (const auto& [value, json] : cases) { SharedArray expected_fill_value = tensorstore::MakeScalarArray(value); if (!skip_to_json) { - EXPECT_THAT(jb::ToJson(expected_fill_value, binder), + std::vector> vec{expected_fill_value}; + EXPECT_THAT(jb::ToJson(vec, binder), ::testing::Optional(MatchesJson(json))) << "value=" << value << ", json=" << json; } - EXPECT_THAT(jb::FromJson>(json, binder), - ::testing::Optional( - tensorstore::MatchesArrayIdentically(expected_fill_value))) + EXPECT_THAT( + jb::FromJson>>(json, binder), + ::testing::Optional(::testing::ElementsAre( + tensorstore::MatchesArrayIdentically(expected_fill_value)))) << "json=" << json; } } @@ -522,10 +546,11 @@ void TestFillValue(std::vector> cases, template void TestFillValueInvalid( std::vector> cases) { - auto binder = FillValueJsonBinder{dtype_v}; + FillValueJsonBinder binder(MakeScalarZarrDType(dtype_v)); for (const auto& [json, matcher] : cases) { EXPECT_THAT( - jb::FromJson>(json, binder).status(), + jb::FromJson>>(json, binder) + .status(), StatusIs(absl::StatusCode::kInvalidArgument, MatchesRegex(matcher))) << "json=" << json; } diff --git a/tensorstore/driver/zarr3/schema.yml b/tensorstore/driver/zarr3/schema.yml index 4f9733415..9491027b1 100644 --- a/tensorstore/driver/zarr3/schema.yml +++ b/tensorstore/driver/zarr3/schema.yml @@ -17,6 +17,31 @@ allOf: automatically. When creating a new array, the new metadata is obtained by combining these metadata constraints with any `Schema` constraints. $ref: driver/zarr3/Metadata + field: + type: string + title: Field selection for structured arrays. + description: | + Name of the field to select from a structured array. When specified, + the tensorstore will provide access to only the specified field of + each element in the structured array. + open_as_void: + type: boolean + default: false + title: Raw byte access mode. 
diff --git a/tensorstore/driver/zarr3/schema.yml b/tensorstore/driver/zarr3/schema.yml
index 4f9733415..9491027b1 100644
--- a/tensorstore/driver/zarr3/schema.yml
+++ b/tensorstore/driver/zarr3/schema.yml
@@ -17,6 +17,27 @@ allOf:
         automatically.  When creating a new array, the new metadata is obtained
         by combining these metadata constraints with any `Schema` constraints.
       $ref: driver/zarr3/Metadata
+    field:
+      type: string
+      title: Field selection for structured arrays.
+      description: |
+        Name of the field to select from a structured (multi-field) data
+        type.  When specified, the TensorStore provides access to only that
+        field of each array element.  Must not be combined with
+        `open_as_void`.
+    open_as_void:
+      type: boolean
+      default: false
+      title: Raw byte access mode.
+      description: |
+        When `true`, opens the array as raw bytes rather than interpreting
+        its data type.  The resulting array gains one additional trailing
+        dimension whose extent is the byte size of each element.  Must not
+        be combined with `field`.
+  not:
+    allOf:
+    - required: ["field"]
+    - required: ["open_as_void"]
 examples:
 - driver: zarr3
   kvstore: