Report Arrow IPC parse failures as errors.

parse_ipc_message and ReadContext::record_batch now return Result instead of
Option, so malformed record data is reported as an Arrow ParseError rather
than producing an empty record iterator. The ArrowError display text now
includes the underlying cause.

diff --git a/crates/fluss/src/error.rs b/crates/fluss/src/error.rs
index b1d5d13b..63438b19 100644
--- a/crates/fluss/src/error.rs
+++ b/crates/fluss/src/error.rs
@@ -39,7 +39,7 @@ pub enum Error {
     #[error("Row convert error")]
     RowConvertError(String),
 
-    #[error("arrow error")]
+    #[error("Arrow error: {0}")]
     ArrowError(#[from] ArrowError),
 
     #[error("Write error: {0}")]
diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs
index 806c9a58..e343e3c8 100644
--- a/crates/fluss/src/record/arrow.rs
+++ b/crates/fluss/src/record/arrow.rs
@@ -34,6 +34,7 @@ use arrow::{
         writer::StreamWriter,
     },
 };
+use arrow_schema::ArrowError::ParseError;
 use arrow_schema::SchemaRef;
 use arrow_schema::{DataType as ArrowDataType, Field};
 use byteorder::WriteBytesExt;
@@ -489,19 +490,15 @@ impl<'a> LogRecordBatch<'a> {
         let data = &self.data[RECORDS_OFFSET..];
 
         let record_batch = read_context.record_batch(data)?;
-        let log_record_iterator = match record_batch {
-            None => LogRecordIterator::empty(),
-            Some(record_batch) => {
-                let arrow_reader = ArrowReader::new(Arc::new(record_batch));
-                LogRecordIterator::Arrow(ArrowLogRecordIterator {
-                    reader: arrow_reader,
-                    base_offset: self.base_log_offset(),
-                    timestamp: self.commit_timestamp(),
-                    row_id: 0,
-                    change_type: ChangeType::AppendOnly,
-                })
-            }
-        };
+        let arrow_reader = ArrowReader::new(Arc::new(record_batch));
+        let log_record_iterator = LogRecordIterator::Arrow(ArrowLogRecordIterator {
+            reader: arrow_reader,
+            base_offset: self.base_log_offset(),
+            timestamp: self.commit_timestamp(),
+            row_id: 0,
+            change_type: ChangeType::AppendOnly,
+        });
+
         Ok(log_record_iterator)
     }
 }
@@ -518,15 +515,16 @@ impl<'a> LogRecordBatch<'a> {
 /// * `data` - The byte slice containing the IPC message.
 ///
 /// # Returns
-/// Returns `Some((batch_metadata, body_buffer, version))` on success:
+/// Returns `Ok((batch_metadata, body_buffer, version))` on success:
 /// - `batch_metadata`: The RecordBatch metadata from the IPC message.
 /// - `body_buffer`: The buffer containing the record batch body data.
 /// - `version`: The Arrow IPC metadata version.
 ///
-/// Returns `None` if the data is malformed or too short.
+/// Returns `Err` wrapping an Arrow `ParseError` when the data is too short, the
+/// continuation marker is wrong, the metadata is invalid, or it is not a record batch.
 fn parse_ipc_message(
     data: &[u8],
-) -> Option<(
+) -> Result<(
     arrow::ipc::RecordBatch<'_>,
     Buffer,
     arrow::ipc::MetadataVersion,
@@ -534,29 +532,37 @@ fn parse_ipc_message(
     const CONTINUATION_MARKER: u32 = 0xFFFFFFFF;
 
     if data.len() < 8 {
-        return None;
+        Err(ParseError(format!("Invalid data length: {}", data.len())))?
     }
 
     let continuation = LittleEndian::read_u32(&data[0..4]);
     let metadata_size = LittleEndian::read_u32(&data[4..8]) as usize;
 
     if continuation != CONTINUATION_MARKER {
-        return None;
+        Err(ParseError(format!(
+            "Invalid continuation marker: {continuation}"
+        )))?
     }
 
-    if data.len() < 8 + metadata_size {
-        return None;
+    if data.len() - 8 < metadata_size {
+        Err(ParseError(format!(
+            "Invalid data length. Remaining data length {} is shorter than specified size {}",
+            data.len() - 8,
+            metadata_size
+        )))?
     }
 
     let metadata_bytes = &data[8..8 + metadata_size];
-    let message = root_as_message(metadata_bytes).ok()?;
-    let batch_metadata = message.header_as_record_batch()?;
+    let message = root_as_message(metadata_bytes).map_err(|err| ParseError(err.to_string()))?;
+    let batch_metadata = message
+        .header_as_record_batch()
+        .ok_or_else(|| ParseError(String::from("Not a record batch")))?;
 
     let body_start = 8 + metadata_size;
     let body_data = &data[body_start..];
     let body_buffer = Buffer::from(body_data);
 
-    Some((batch_metadata, body_buffer, message.version()))
+    Ok((batch_metadata, body_buffer, message.version()))
 }
 
 pub fn to_arrow_schema(fluss_schema: &DataType) -> SchemaRef {
@@ -577,7 +583,7 @@ pub fn to_arrow_schema(fluss_schema: &DataType) -> SchemaRef {
             SchemaRef::new(arrow_schema::Schema::new(fields))
         }
         _ => {
-            panic!("must be row data tyoe.")
+            panic!("must be row data type.")
         }
     }
 }
@@ -766,11 +772,8 @@ impl ReadContext {
             .map(|p| p.ordered_fields.as_slice())
     }
 
-    pub fn record_batch(&self, data: &[u8]) -> Result<Option<RecordBatch>> {
-        let (batch_metadata, body_buffer, version) = match parse_ipc_message(data) {
-            Some(result) => result,
-            None => return Ok(None),
-        };
+    pub fn record_batch(&self, data: &[u8]) -> Result<RecordBatch> {
+        let (batch_metadata, body_buffer, version) = parse_ipc_message(data)?;
 
         // the record batch from server must be ordered by field pos,
         // according to project to decide what arrow schema to use
@@ -807,7 +810,7 @@ impl ReadContext {
             }
             _ => record_batch,
         };
-        Ok(Some(record_batch))
+        Ok(record_batch)
     }
 }
 
@@ -1017,4 +1020,55 @@ mod tests {
     fn test_timestamp_ltz_invalid_precision() {
         to_arrow_type(&DataTypes::timestamp_ltz_with_precision(10));
     }
+
+    #[test]
+    fn test_parse_ipc_message() {
+        let empty_body: &[u8] = &le_bytes(&[0xFFFFFFFF, 0x00000000]);
+        let result = parse_ipc_message(empty_body);
+        assert_eq!(
+            result.unwrap_err().to_string(),
+            String::from("Arrow error: Parser error: Range [0, 4) is out of bounds.\n\n")
+        );
+
+        let invalid_data = &[];
+        assert_eq!(
+            parse_ipc_message(invalid_data).unwrap_err().to_string(),
+            String::from("Arrow error: Parser error: Invalid data length: 0")
+        );
+
+        let data_with_invalid_continuation: &[u8] = &le_bytes(&[0x00000001, 0x00000000]);
+        assert_eq!(
+            parse_ipc_message(data_with_invalid_continuation)
+                .unwrap_err()
+                .to_string(),
+            String::from("Arrow error: Parser error: Invalid continuation marker: 1")
+        );
+
+        let data_with_invalid_length: &[u8] = &le_bytes(&[0xFFFFFFFF, 0x00000001]);
+        assert_eq!(
+            parse_ipc_message(data_with_invalid_length)
+                .unwrap_err()
+                .to_string(),
+            String::from(
+                "Arrow error: Parser error: Invalid data length. \
+                 Remaining data length 0 is shorter than specified size 1"
+            )
+        );
+
+        let data_without_record_batch = &le_bytes(&[0xFFFFFFFF, 0x00000004, 0x00000000]);
+        assert_eq!(
+            parse_ipc_message(data_without_record_batch)
+                .unwrap_err()
+                .to_string(),
+            String::from("Arrow error: Parser error: Not a record batch")
+        );
+    }
+
+    fn le_bytes(vals: &[u32]) -> Vec<u8> {
+        let mut out = Vec::with_capacity(vals.len() * 4);
+        for &v in vals {
+            out.extend_from_slice(&v.to_le_bytes());
+        }
+        out
+    }
 }