diff --git a/Cargo.lock b/Cargo.lock index 38fa83dd12119..15a3b64fb3a13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1756,6 +1756,7 @@ dependencies = [ "datafusion-session", "datafusion-sql", "doc-comment", + "encoding_rs", "env_logger", "flate2", "futures", @@ -2017,6 +2018,7 @@ dependencies = [ "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", + "encoding_rs", "futures", "object_store", "regex", @@ -2830,6 +2832,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "endian-type" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index d057261f7a2e1..c4825e7878510 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -151,6 +151,7 @@ datafusion-sql = { path = "datafusion/sql", version = "52.2.0" } datafusion-substrait = { path = "datafusion/substrait", version = "52.2.0" } doc-comment = "0.3" +encoding_rs = "0.8" env_logger = "0.11" flate2 = "1.1.9" futures = "0.3" diff --git a/README.md b/README.md index 630d4295bd427..661610205f084 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,7 @@ Optional features: - `avro`: support for reading the [Apache Avro] format - `backtrace`: include backtrace information in error messages +- `encoding_rs`: support for reading files with character encodings other than UTF-8 - `parquet_encryption`: support for using [Parquet Modular Encryption] - `serde`: enable arrow-schema's `serde` feature diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index d71af206c78d5..c045d609f1934 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -2921,6 +2921,7 @@ config_namespace! { /// /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. pub newlines_in_values: Option, default = None + pub charset: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED /// Compression level for the output file. The valid range depends on the /// compression algorithm: @@ -3033,6 +3034,13 @@ impl CsvOptions { self } + /// Specifies the character encoding the file is encoded with. + /// - defaults to UTF-8 + pub fn with_charset(mut self, charset: impl Into) -> Self { + self.charset = Some(charset.into()); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 8965948a0f4e2..0e55d5b2bce8a 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -68,6 +68,7 @@ default = [ "recursive_protection", "sql", ] +encoding_rs = ["datafusion-datasource-csv/encoding_rs"] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] @@ -171,6 +172,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } +encoding_rs = { workspace = true } env_logger = { workspace = true } glob = { workspace = true } insta = { workspace = true } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 51d799a5b65c1..d22aeec0f154c 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -1619,4 +1619,49 @@ mod tests { Ok(()) } + + #[cfg(feature = "encoding_rs")] + #[tokio::test] + async fn test_read_shift_jis_csv() -> Result<()> { + use std::io::Write; + + // Encode a test CSV into SHIFT-JIS + let data = r#"ID,Name,Price,Description,Notes +001,山本 大輔,\2945,桜餅と抹茶のセット,数量限定 +002,加藤 由美,\9575,和牛ステーキセット,取り寄せ中 +003,田中 太郎,\1853,抹茶アイスクリーム,ポイント2倍 +004,渡辺 さくら,\9494,和牛ステーキセット,送料無料 +005,加藤 由美,\558,和牛ステーキセット,新商品 +006,渡辺 さくら,\7704,天ぷら盛り合わせ,割引対象外 +007,田中 太郎,\212,桜餅と抹茶のセット,取り寄せ中 +008,中村 陽子,\8847,和牛ステーキセット,期間限定 +009,伊藤 健太,\5997,季節の野菜カレー,お一人様1点限り +010,高橋 美咲,\6594,季節の野菜カレー,冷凍保存"#; + let (data, _, _) = encoding_rs::SHIFT_JIS.encode(data); + + // Write the CSV data to a temp file + let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; + tmp.write_all(&*data)?; + let path = tmp.path().to_str().unwrap().to_string(); + + // Read the file + let ctx = SessionContext::new(); + let opts = CsvReadOptions::new().has_header(true).charset("SHIFT-JIS"); + let batches = ctx.read_csv(path, opts).await?.collect().await?; + + // Check + let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); + assert_eq!(num_rows, 10); + + let names = batches[0] + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "山本 大輔"); + assert_eq!(names.value(1), "加藤 由美"); + assert_eq!(names.value(2), "田中 太郎"); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index bd0ac36087381..68e8596e12141 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -85,6 +85,8 @@ pub struct CsvReadOptions<'a> { pub file_extension: &'a str, /// Partition Columns pub table_partition_cols: Vec<(String, DataType)>, + /// Character encoding + pub charset: Option<&'a str>, /// File compression type pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted @@ -118,6 +120,7 @@ impl<'a> CsvReadOptions<'a> { newlines_in_values: false, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], + charset: None, file_compression_type: FileCompressionType::UNCOMPRESSED, file_sort_order: vec![], comment: None, @@ -209,6 +212,12 @@ impl<'a> CsvReadOptions<'a> { self } + /// Configure the character set encoding + pub fn charset(mut self, charset: &'a str) -> Self { + self.charset = Some(charset); + self + } + /// Configure file compression type pub fn file_compression_type( mut self, @@ -633,6 +642,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_terminator(self.terminator) .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) + .with_charset(self.charset.map(ToOwned::to_owned)) .with_file_compression_type(self.file_compression_type.to_owned()) .with_null_regex(self.null_regex.clone()) .with_truncated_rows(self.truncated_rows); diff --git a/datafusion/datasource-csv/Cargo.toml b/datafusion/datasource-csv/Cargo.toml index 295092512742b..0f83fa8ac79fc 100644 --- a/datafusion/datasource-csv/Cargo.toml +++ b/datafusion/datasource-csv/Cargo.toml @@ -42,6 +42,7 @@ datafusion-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } +encoding_rs = { workspace = true, optional = true } futures = { workspace = true } object_store = { workspace = true } regex = { workspace = true } diff --git a/datafusion/datasource-csv/src/charset.rs b/datafusion/datasource-csv/src/charset.rs new file mode 100644 index 0000000000000..d5e17e174ff0a --- /dev/null +++ b/datafusion/datasource-csv/src/charset.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::io::{BufRead, Read}; + +use arrow::array::RecordBatch; +use arrow::error::ArrowError; +use datafusion_common::{DataFusionError, Result}; +use datafusion_datasource::decoder::Decoder; +use encoding_rs::{CoderResult, Encoding, UTF_8}; + +use self::buffer::Buffer; + +/// Default capacity of the buffer used to decode non-UTF-8 charset streams +static DECODE_BUFFER_CAP: usize = 8 * 1024; + +pub fn lookup_charset(enc: Option<&str>) -> Result> { + match enc { + Some(enc) => match Encoding::for_label(enc.as_bytes()) { + Some(enc) => Ok(Some(enc).filter(|enc| *enc != UTF_8)), + None => Err(DataFusionError::Configuration(format!( + "Unknown character set '{enc}'" + )))?, + }, + None => Ok(None), + } +} + +/// A record batch `Decoder` that decodes input bytes from the specified +/// character encoding to UTF-8 before passing them onto the inner `Decoder`. +#[derive(Debug)] +pub struct CharsetBatchDecoder { + inner: T, + decoder: CharsetDecoder, +} + +impl CharsetBatchDecoder { + pub fn new(inner: T, encoding: &'static Encoding) -> Self { + let decoder = CharsetDecoder::new(encoding); + Self { inner, decoder } + } +} + +impl Decoder for CharsetBatchDecoder { + fn decode(&mut self, buf: &[u8]) -> Result { + let last = buf.is_empty(); + let mut buf_offset = 0; + + if !self.decoder.is_empty() { + let decoded = self.inner.decode(self.decoder.read())?; + self.decoder.consume(decoded); + + if decoded == 0 { + return Ok(buf_offset); + } + } + + loop { + let (read, input_empty) = self.decoder.fill(&buf[buf_offset..], last); + buf_offset += read; + + let decoded = self.inner.decode(self.decoder.read())?; + self.decoder.consume(decoded); + + if input_empty || decoded == 0 { + break; + } + } + + Ok(buf_offset) + } + + fn flush(&mut self) -> Result, ArrowError> { + self.inner.flush() + } + + fn can_flush_early(&self) -> bool { + self.inner.can_flush_early() + } +} + +/// A `BufRead` adapter that decodes input bytes from the +/// specified character encoding to UTF-8. +#[derive(Debug)] +pub struct CharsetReader { + inner: R, + decoder: CharsetDecoder, +} + +impl CharsetReader { + pub fn new(inner: R, encoding: &'static Encoding) -> Self { + let decoder = CharsetDecoder::new(encoding); + Self { inner, decoder } + } +} + +impl Read for CharsetReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let src = self.fill_buf()?; + let len = src.len().min(buf.len()); + buf[..len].copy_from_slice(&src[..len]); + Ok(len) + } +} + +impl BufRead for CharsetReader { + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + if self.decoder.needs_input() { + let buf = self.inner.fill_buf()?; + let (read, _) = self.decoder.fill(buf, buf.is_empty()); + self.inner.consume(read); + } + + Ok(self.decoder.read()) + } + + fn consume(&mut self, amount: usize) { + self.decoder.consume(amount); + } +} + +/// Converts bytes from some character encoding to UTF-8, +/// using an internal fixed-size buffer +pub struct CharsetDecoder { + charset_decoder: encoding_rs::Decoder, + buffer: Buffer, + finished: bool, +} + +impl CharsetDecoder { + /// Creates a new `CharsetDecoder`. + fn new(encoding: &'static Encoding) -> Self { + Self { + charset_decoder: encoding.new_decoder(), + buffer: Buffer::with_capacity(DECODE_BUFFER_CAP), + finished: false, + } + } + + /// Returns `true` if the internal buffer is empty. + fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Returns `true` if the decoder needs more input to make progress. + fn needs_input(&self) -> bool { + !self.finished && self.buffer.is_empty() + } + + /// Fills the internal buffer by decoding the provided bytes, returning + /// the number of bytes consumed and whether the input was exhausted. + fn fill(&mut self, src: &[u8], last: bool) -> (usize, bool) { + if self.finished { + return (0, true); + } + + self.buffer.backshift(); + + let dst = self.buffer.write_buf(); + let (res, read, written, _) = self.charset_decoder.decode_to_utf8(src, dst, last); + self.buffer.advance(written); + + if last && res == CoderResult::InputEmpty { + self.finished = true; + } + + (read, res == CoderResult::InputEmpty) + } + + /// Returns the unread decoded bytes in the internal buffer. + fn read(&self) -> &[u8] { + self.buffer.read_buf() + } + + /// Marks the given amount of bytes from the internal buffer as having been read. + /// Subsequent calls to `read` only return bytes that have not been marked as read. + fn consume(&mut self, amount: usize) { + self.buffer.consume(amount); + } +} + +impl Debug for CharsetDecoder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CharsetDecoder") + .field("charset_decoder", self.charset_decoder.encoding()) + .field("buffer", &self.buffer) + .field("finished", &self.finished) + .finish() + } +} + +mod buffer { + /// A fixed-sized buffer that maintains both + /// a read position and a write position + #[derive(Debug)] + pub struct Buffer { + buf: Box<[u8]>, + read_ptr: usize, + write_ptr: usize, + } + + impl Buffer { + /// Creates a new `Buffer` with the specified capacity + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buf: vec![0; capacity].into_boxed_slice(), + read_ptr: 0, + write_ptr: 0, + } + } + + /// Whether there are no more bytes available to be read + pub fn is_empty(&self) -> bool { + self.read_ptr == self.write_ptr + } + + /// Returns the unread portion of the buffer + pub fn read_buf(&self) -> &[u8] { + &self.buf[self.read_ptr..self.write_ptr] + } + + /// Advances the read position by `amount` bytes + pub fn consume(&mut self, amount: usize) { + self.read_ptr += amount; + debug_assert!(self.read_ptr <= self.write_ptr); + } + + /// Returns the portion of the buffer available for writing + pub fn write_buf(&mut self) -> &mut [u8] { + &mut self.buf[self.write_ptr..] + } + + /// Advances the write position by `amount` bytes + pub fn advance(&mut self, amount: usize) { + self.write_ptr += amount; + debug_assert!(self.write_ptr <= self.buf.len()) + } + + /// Moves any unread bytes to the start of the buffer, + /// creating more space for writing new data + pub fn backshift(&mut self) { + self.buf.copy_within(self.read_ptr..self.write_ptr, 0); + self.write_ptr -= self.read_ptr; + self.read_ptr = 0; + } + } +} diff --git a/datafusion/datasource-csv/src/file_format.rs b/datafusion/datasource-csv/src/file_format.rs index 7a253d81db9f8..ea9a5e92f02a9 100644 --- a/datafusion/datasource-csv/src/file_format.rs +++ b/datafusion/datasource-csv/src/file_format.rs @@ -22,6 +22,7 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::sync::Arc; +use crate::charset::lookup_charset; use crate::source::CsvSource; use arrow::array::RecordBatch; @@ -294,6 +295,13 @@ impl CsvFormat { self } + /// Sets the character encoding of the CSV. + /// Defaults to UTF-8 if unspecified. + pub fn with_charset(mut self, charset: Option) -> Self { + self.options.charset = charset; + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -540,6 +548,8 @@ impl CsvFormat { pin_mut!(stream); + let charset = lookup_charset(self.options.charset.as_deref())?; + while let Some(chunk) = stream.next().await.transpose()? { record_number += 1; let first_chunk = record_number == 0; @@ -569,8 +579,15 @@ impl CsvFormat { format = format.with_comment(comment); } - let (Schema { fields, .. }, records_read) = - format.infer_schema(chunk.reader(), Some(records_to_read))?; + let (Schema { fields, .. }, records_read) = match charset { + #[cfg(feature = "encoding_rs")] + Some(enc) => { + use crate::charset::CharsetReader; + let reader = CharsetReader::new(chunk.reader(), enc); + format.infer_schema(reader, Some(records_to_read))? + } + None => format.infer_schema(chunk.reader(), Some(records_to_read))?, + }; records_to_read -= records_read; total_records_read += records_read; diff --git a/datafusion/datasource-csv/src/mod.rs b/datafusion/datasource-csv/src/mod.rs index fdfee05d86a79..72e3c29b00bc3 100644 --- a/datafusion/datasource-csv/src/mod.rs +++ b/datafusion/datasource-csv/src/mod.rs @@ -20,6 +20,8 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +#[cfg(feature = "encoding_rs")] +mod charset; pub mod file_format; pub mod source; @@ -30,6 +32,7 @@ use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::{file::FileSource, file_scan_config::FileScanConfig}; use datafusion_execution::object_store::ObjectStoreUrl; + pub use file_format::*; /// Returns a [`FileScanConfig`] for given `file_groups` @@ -43,3 +46,19 @@ pub fn partitioned_csv_config( .build(), ) } + +#[cfg(not(feature = "encoding_rs"))] +mod charset { + use core::convert::Infallible; + use datafusion_common::{DataFusionError, Result}; + + pub fn lookup_charset(enc: Option<&str>) -> Result> { + match enc { + Some(_) => Err(DataFusionError::NotImplemented( + "The 'encoding_rs' feature must be enabled to decode non-UTF-8 encodings" + .to_string(), + ))?, + None => Ok(None), + } + } +} diff --git a/datafusion/datasource-csv/src/source.rs b/datafusion/datasource-csv/src/source.rs index 77a0dc9cf7995..558be2687a225 100644 --- a/datafusion/datasource-csv/src/source.rs +++ b/datafusion/datasource-csv/src/source.rs @@ -17,11 +17,12 @@ //! Execution plan for reading CSV files +use datafusion_datasource::decoder::deserialize_reader; use datafusion_datasource::projection::{ProjectionOpener, SplitProjection}; use datafusion_physical_plan::projection::ProjectionExprs; use std::any::Any; use std::fmt; -use std::io::{Read, Seek, SeekFrom}; +use std::io::{BufReader, Read, Seek, SeekFrom}; use std::sync::Arc; use std::task::Poll; @@ -46,6 +47,7 @@ use datafusion_physical_plan::{ DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, }; +use crate::charset::lookup_charset; use crate::file_format::CsvDecoder; use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; @@ -182,10 +184,6 @@ impl CsvSource { } impl CsvSource { - fn open(&self, reader: R) -> Result> { - Ok(self.builder().build(reader)?) - } - fn builder(&self) -> csv::ReaderBuilder { let mut builder = csv::ReaderBuilder::new(Arc::clone(self.table_schema.file_schema())) @@ -371,6 +369,7 @@ impl FileOpener for CsvOpener { config.options.truncated_rows = Some(config.truncate_rows()); let file_compression_type = self.file_compression_type.to_owned(); + let charset = lookup_charset(self.config.options.charset.as_deref())?; if partitioned_file.range.is_some() { assert!( @@ -410,43 +409,61 @@ impl FileOpener for CsvOpener { .get_opts(&partitioned_file.object_meta.location, options) .await?; + let decoder = config.builder().build_decoder(); + let decoder = CsvDecoder::new(decoder); + match result.payload { #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { let is_whole_file_scanned = partitioned_file.range.is_none(); - let decoder = if is_whole_file_scanned { - // Don't seek if no range as breaks FIFO files + let reader = if is_whole_file_scanned { + // Don't seek if no range as that would break FIFO files file_compression_type.convert_read(file)? } else { - file.seek(SeekFrom::Start(result.range.start as _))?; - file_compression_type.convert_read( - file.take((result.range.end - result.range.start) as u64), - )? + let bytes = (result.range.end - result.range.start) as u64; + file.seek(SeekFrom::Start(result.range.start as u64))?; + file_compression_type.convert_read(file.take(bytes))? }; - let mut reader = config.open(decoder)?; + let reader = BufReader::new(reader); + + let mut reader = match charset { + #[cfg(feature = "encoding_rs")] + Some(enc) => { + use crate::charset::CharsetBatchDecoder; + let decoder = CharsetBatchDecoder::new(decoder, enc); + deserialize_reader(reader, decoder) + } + None => deserialize_reader(reader, decoder), + }; // Use std::iter::from_fn to wrap execution of iterator's next() method. let iterator = std::iter::from_fn(move || { let mut timer = baseline_metrics.elapsed_compute().timer(); let result = reader.next(); timer.stop(); - result + result.map(|r| r.map_err(Into::into)) }); - Ok(futures::stream::iter(iterator) - .map(|r| r.map_err(Into::into)) - .boxed()) + Ok(futures::stream::iter(iterator).boxed()) } - GetResultPayload::Stream(s) => { - let decoder = config.builder().build_decoder(); - let s = s.map_err(DataFusionError::from); - let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - - let stream = deserialize_stream( - input, - DecoderDeserializer::new(CsvDecoder::new(decoder)), - ); + GetResultPayload::Stream(stream) => { + let stream = stream.map_err(DataFusionError::from).boxed(); + + let stream = file_compression_type.convert_stream(stream)?.fuse(); + + let stream = match charset { + #[cfg(feature = "encoding_rs")] + Some(enc) => { + use crate::charset::CharsetBatchDecoder; + let decoder = CharsetBatchDecoder::new(decoder, enc); + deserialize_stream(stream, DecoderDeserializer::new(decoder)) + } + None => { + deserialize_stream(stream, DecoderDeserializer::new(decoder)) + } + }; + Ok(stream.map_err(Into::into).boxed()) } } diff --git a/datafusion/datasource/src/decoder.rs b/datafusion/datasource/src/decoder.rs index 9f9fc0d94bb1c..ef759f0fb7e2f 100644 --- a/datafusion/datasource/src/decoder.rs +++ b/datafusion/datasource/src/decoder.rs @@ -18,8 +18,7 @@ //! Module containing helper methods for the various file formats //! See write.rs for write related helper methods -use ::arrow::array::RecordBatch; - +use arrow::array::RecordBatch; use arrow::error::ArrowError; use bytes::Buf; use bytes::Bytes; @@ -29,19 +28,9 @@ use futures::stream::BoxStream; use futures::{Stream, ready}; use std::collections::VecDeque; use std::fmt; +use std::io::BufRead; use std::task::Poll; -/// Possible outputs of a [`BatchDeserializer`]. -#[derive(Debug, PartialEq)] -pub enum DeserializerOutput { - /// A successfully deserialized [`RecordBatch`]. - RecordBatch(RecordBatch), - /// The deserializer requires more data to make progress. - RequiresMoreData, - /// The input data has been exhausted. - InputExhausted, -} - /// Trait defining a scheme for deserializing byte streams into structured data. /// Implementors of this trait are responsible for converting raw bytes into /// `RecordBatch` objects. @@ -49,8 +38,8 @@ pub trait BatchDeserializer: Send + fmt::Debug { /// Feeds a message for deserialization, updating the internal state of /// this `BatchDeserializer`. Note that one can call this function multiple /// times before calling `next`, which will queue multiple messages for - /// deserialization. Returns the number of bytes consumed. - fn digest(&mut self, message: T) -> usize; + /// deserialization. + fn digest(&mut self, message: T); /// Attempts to deserialize any pending messages and returns a /// `DeserializerOutput` to indicate progress. @@ -61,6 +50,17 @@ pub trait BatchDeserializer: Send + fmt::Debug { fn finish(&mut self); } +/// Possible outputs of a [`BatchDeserializer`]. +#[derive(Debug, PartialEq)] +pub enum DeserializerOutput { + /// A successfully deserialized [`RecordBatch`]. + RecordBatch(RecordBatch), + /// The deserializer requires more data to make progress. + RequiresMoreData, + /// The input data has been exhausted. + InputExhausted, +} + /// A general interface for decoders such as [`arrow::json::reader::Decoder`] and /// [`arrow::csv::reader::Decoder`]. Defines an interface similar to /// [`Decoder::decode`] and [`Decoder::flush`] methods, but also includes @@ -86,24 +86,37 @@ pub trait Decoder: Send + fmt::Debug { fn can_flush_early(&self) -> bool; } -impl fmt::Debug for DecoderDeserializer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Deserializer") - .field("buffered_queue", &self.buffered_queue) - .field("finalized", &self.finalized) - .finish() +/// A generic, decoder-based deserialization scheme for processing encoded data. +/// +/// This struct is responsible for converting a stream of bytes, which represent +/// encoded data, into a stream of `RecordBatch` objects, following the specified +/// schema and formatting options. It also handles any buffering necessary to satisfy +/// the `Decoder` interface. +pub struct DecoderDeserializer { + /// The underlying decoder used for deserialization + pub(crate) decoder: T, + /// The buffer used to store the remaining bytes to be decoded + pub(crate) buffered_queue: VecDeque, + /// Whether the input stream has been fully consumed + pub(crate) finalized: bool, +} + +impl DecoderDeserializer { + /// Creates a new `DecoderDeserializer` with the provided decoder. + pub fn new(decoder: T) -> Self { + DecoderDeserializer { + decoder, + buffered_queue: VecDeque::new(), + finalized: false, + } } } impl BatchDeserializer for DecoderDeserializer { - fn digest(&mut self, message: Bytes) -> usize { - if message.is_empty() { - return 0; + fn digest(&mut self, message: Bytes) { + if !message.is_empty() { + self.buffered_queue.push_back(message); } - - let consumed = message.len(); - self.buffered_queue.push_back(message); - consumed } fn next(&mut self) -> Result { @@ -139,29 +152,12 @@ impl BatchDeserializer for DecoderDeserializer { } } -/// A generic, decoder-based deserialization scheme for processing encoded data. -/// -/// This struct is responsible for converting a stream of bytes, which represent -/// encoded data, into a stream of `RecordBatch` objects, following the specified -/// schema and formatting options. It also handles any buffering necessary to satisfy -/// the `Decoder` interface. -pub struct DecoderDeserializer { - /// The underlying decoder used for deserialization - pub(crate) decoder: T, - /// The buffer used to store the remaining bytes to be decoded - pub(crate) buffered_queue: VecDeque, - /// Whether the input stream has been fully consumed - pub(crate) finalized: bool, -} - -impl DecoderDeserializer { - /// Creates a new `DecoderDeserializer` with the provided decoder. - pub fn new(decoder: T) -> Self { - DecoderDeserializer { - decoder, - buffered_queue: VecDeque::new(), - finalized: false, - } +impl fmt::Debug for DecoderDeserializer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Deserializer") + .field("buffered_queue", &self.buffered_queue) + .field("finalized", &self.finalized) + .finish() } } @@ -178,7 +174,7 @@ pub fn deserialize_stream<'a>( futures::stream::poll_fn(move |cx| { loop { match ready!(input.poll_next_unpin(cx)).transpose()? { - Some(b) => _ = deserializer.digest(b), + Some(b) => deserializer.digest(b), None => deserializer.finish(), }; @@ -191,3 +187,27 @@ pub fn deserialize_stream<'a>( }) .boxed() } + +/// Creates an iterator of [`RecordBatch`]es that consumes bytes from an inner [`BufRead`] +/// and deserializes them using the provided decoder. +pub fn deserialize_reader<'a>( + mut reader: impl BufRead + Send + 'a, + mut decoder: impl Decoder + 'a, +) -> Box> + Send + 'a> { + let mut read = move || { + loop { + let buf = reader.fill_buf()?; + + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + + if decoded == 0 || decoder.can_flush_early() { + break; + } + } + + decoder.flush() + }; + + Box::new(std::iter::from_fn(move || read().transpose())) +} diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 62c6bbe85612a..8600baeb5a339 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -476,6 +476,7 @@ message CsvOptions { bytes terminator = 17; // Optional terminator character as a byte bytes truncated_rows = 18; // Indicates if truncated rows are allowed optional uint32 compression_level = 19; // Optional compression level + string charset = 20; // Optional character encoding } // Options controlling CSV format diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index ca8a269958d73..ee9e01c296521 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -984,6 +984,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { escape: proto_opts.escape.first().copied(), double_quote: proto_opts.double_quote.first().map(|h| *h != 0), newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0), + charset: (!proto_opts.charset.is_empty()).then(|| proto_opts.charset.clone()), compression: proto_opts.compression().into(), compression_level: proto_opts.compression_level, schema_infer_max_rec: proto_opts.schema_infer_max_rec.map(|h| h as usize), diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index b00e7546bba20..840b5470eb3f8 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1701,6 +1701,9 @@ impl serde::Serialize for CsvOptions { if self.compression_level.is_some() { len += 1; } + if !self.charset.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1781,6 +1784,9 @@ impl serde::Serialize for CsvOptions { if let Some(v) = self.compression_level.as_ref() { struct_ser.serialize_field("compressionLevel", v)?; } + if !self.charset.is_empty() { + struct_ser.serialize_field("charset", &self.charset)?; + } struct_ser.end() } } @@ -1823,6 +1829,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "truncatedRows", "compression_level", "compressionLevel", + "charset", ]; #[allow(clippy::enum_variant_names)] @@ -1846,6 +1853,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Terminator, TruncatedRows, CompressionLevel, + Charset, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1886,6 +1894,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "terminator" => Ok(GeneratedField::Terminator), "truncatedRows" | "truncated_rows" => Ok(GeneratedField::TruncatedRows), "compressionLevel" | "compression_level" => Ok(GeneratedField::CompressionLevel), + "charset" => Ok(GeneratedField::Charset), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1924,6 +1933,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut terminator__ = None; let mut truncated_rows__ = None; let mut compression_level__ = None; + let mut charset__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -2062,6 +2072,12 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::Charset => { + if charset__.is_some() { + return Err(serde::de::Error::duplicate_field("charset")); + } + charset__ = Some(map_.next_value()?); + } } } Ok(CsvOptions { @@ -2084,6 +2100,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { terminator: terminator__.unwrap_or_default(), truncated_rows: truncated_rows__.unwrap_or_default(), compression_level: compression_level__, + charset: charset__.unwrap_or_default(), }) } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index a09826a29be52..1a8273bea36e2 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -672,6 +672,9 @@ pub struct CsvOptions { /// Optional compression level #[prost(uint32, optional, tag = "19")] pub compression_level: ::core::option::Option, + /// Optional character encoding + #[prost(string, tag = "20")] + pub charset: ::prost::alloc::string::String, } /// Options controlling CSV format #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 79e3306a4df1b..3f334d7fbf830 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -986,6 +986,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { newlines_in_values: opts .newlines_in_values .map_or_else(Vec::new, |h| vec![h as u8]), + charset: opts.charset.clone().unwrap_or_default(), compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec.map(|h| h as u64), date_format: opts.date_format.clone().unwrap_or_default(), diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index a09826a29be52..1a8273bea36e2 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -672,6 +672,9 @@ pub struct CsvOptions { /// Optional compression level #[prost(uint32, optional, tag = "19")] pub compression_level: ::core::option::Option, + /// Optional character encoding + #[prost(string, tag = "20")] + pub charset: ::prost::alloc::string::String, } /// Options controlling CSV format #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 08f42b0af7290..4f02de697b96e 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -45,6 +45,7 @@ impl CsvOptionsProto { terminator: options.terminator.map_or(vec![], |v| vec![v]), escape: options.escape.map_or(vec![], |v| vec![v]), double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]), + charset: options.charset.clone().unwrap_or_default(), compression: options.compression as i32, schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64), date_format: options.date_format.clone().unwrap_or_default(), @@ -95,6 +96,11 @@ impl From<&CsvOptionsProto> for CsvOptions { } else { None }, + charset: if !proto.charset.is_empty() { + Some(proto.charset.clone()) + } else { + None + }, compression: match proto.compression { 0 => CompressionTypeVariant::GZIP, 1 => CompressionTypeVariant::BZIP2,