From edcf1a7530f66cb39b8ed517f79c108b9a2a92c4 Mon Sep 17 00:00:00 2001 From: jhw Date: Mon, 25 May 2026 21:55:44 +0800 Subject: [PATCH] feat(phase-2): refine sql select options --- crates/cli/src/commands/sql.rs | 292 ++++++++++++++++++++++++++++++-- crates/core/src/lib.rs | 7 +- crates/core/src/select.rs | 84 ++++++++++ crates/s3/src/select.rs | 297 ++++++++++++++++++++++++++++++--- 4 files changed, 640 insertions(+), 40 deletions(-) diff --git a/crates/cli/src/commands/sql.rs b/crates/cli/src/commands/sql.rs index fb8144a..797793b 100644 --- a/crates/cli/src/commands/sql.rs +++ b/crates/cli/src/commands/sql.rs @@ -2,8 +2,10 @@ use clap::{Args, ValueEnum}; use rc_core::{ - AliasManager, ObjectStore, SelectCompression, SelectInputFormat, SelectOptions, - SelectOutputFormat, parse_object_path, + AliasManager, ObjectStore, SelectCompression, SelectCsvFileHeaderInfo, SelectCsvInputOptions, + SelectCsvOutputOptions, SelectInputFormat, SelectJsonInputOptions, SelectJsonInputType, + SelectJsonOutputOptions, SelectOptions, SelectOutputFormat, SelectQuoteFields, + SelectScanRangeOptions, SelectSseCustomerOptions, parse_object_path, }; use rc_s3::S3Client; @@ -31,6 +33,74 @@ pub struct SqlArgs { /// Compression of the stored object (input decompression) #[arg(long, value_enum, default_value_t = CompressionArg::None)] pub compression: CompressionArg, + + /// CSV input header handling + #[arg(long, value_enum, default_value_t = CsvFileHeaderInfoArg::None)] + pub csv_file_header_info: CsvFileHeaderInfoArg, + + /// CSV input field delimiter + #[arg(long)] + pub csv_input_field_delimiter: Option, + + /// CSV input quote character + #[arg(long)] + pub csv_input_quote: Option, + + /// CSV input quote escape character + #[arg(long)] + pub csv_input_quote_escape: Option, + + /// CSV input comment character + #[arg(long)] + pub csv_input_comment: Option, + + /// CSV output field delimiter + #[arg(long)] + pub csv_output_field_delimiter: Option, + + /// CSV output record delimiter + #[arg(long)] + pub csv_output_record_delimiter: Option, + + /// CSV output quote character + #[arg(long)] + pub csv_output_quote: Option, + + /// CSV output quote escape character + #[arg(long)] + pub csv_output_quote_escape: Option, + + /// CSV output quote behavior + #[arg(long, value_enum, default_value_t = QuoteFieldsArg::AsNeeded)] + pub csv_output_quote_fields: QuoteFieldsArg, + + /// JSON input shape + #[arg(long, value_enum, default_value_t = JsonTypeArg::Lines)] + pub json_type: JsonTypeArg, + + /// JSON output record delimiter + #[arg(long)] + pub json_output_record_delimiter: Option, + + /// Select ScanRange start byte + #[arg(long)] + pub scan_start: Option, + + /// Select ScanRange end byte + #[arg(long)] + pub scan_end: Option, + + /// SSE-C customer algorithm, usually AES256 + #[arg(long)] + pub sse_customer_algorithm: Option, + + /// SSE-C customer key + #[arg(long)] + pub sse_customer_key: Option, + + /// SSE-C customer key MD5 + #[arg(long)] + pub sse_customer_key_md5: Option, } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -54,6 +124,25 @@ pub enum CompressionArg { Bzip2, } +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum CsvFileHeaderInfoArg { + None, + Ignore, + Use, +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum JsonTypeArg { + Lines, + Document, +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum QuoteFieldsArg { + Always, + AsNeeded, +} + impl From for SelectInputFormat { fn from(value: InputFormatArg) -> Self { match value { @@ -83,6 +172,34 @@ impl From for SelectCompression { } } +impl From for SelectCsvFileHeaderInfo { + fn from(value: CsvFileHeaderInfoArg) -> Self { + match value { + CsvFileHeaderInfoArg::None => SelectCsvFileHeaderInfo::None, + CsvFileHeaderInfoArg::Ignore => SelectCsvFileHeaderInfo::Ignore, + CsvFileHeaderInfoArg::Use => SelectCsvFileHeaderInfo::Use, + } + } +} + +impl From for SelectJsonInputType { + fn from(value: JsonTypeArg) -> Self { + match value { + JsonTypeArg::Lines => SelectJsonInputType::Lines, + JsonTypeArg::Document => SelectJsonInputType::Document, + } + } +} + +impl From for SelectQuoteFields { + fn from(value: QuoteFieldsArg) -> Self { + match value { + QuoteFieldsArg::Always => SelectQuoteFields::Always, + QuoteFieldsArg::AsNeeded => SelectQuoteFields::AsNeeded, + } + } +} + /// Execute the `sql` command. pub async fn execute(args: SqlArgs, output_config: OutputConfig) -> ExitCode { let formatter = Formatter::new(output_config); @@ -100,6 +217,11 @@ pub async fn execute(args: SqlArgs, output_config: OutputConfig) -> ExitCode { } }; + if let Err(message) = validate_select_args(&args) { + formatter.error(&message); + return ExitCode::UsageError; + } + let alias_manager = match AliasManager::new() { Ok(am) => am, Err(e) => { @@ -129,6 +251,35 @@ pub async fn execute(args: SqlArgs, output_config: OutputConfig) -> ExitCode { input_format: args.input_format.into(), output_format: args.output_format.into(), compression: args.compression.into(), + csv_input: SelectCsvInputOptions { + file_header_info: args.csv_file_header_info.into(), + field_delimiter: args.csv_input_field_delimiter, + quote_character: args.csv_input_quote, + quote_escape_character: args.csv_input_quote_escape, + comments: args.csv_input_comment, + }, + csv_output: SelectCsvOutputOptions { + field_delimiter: args.csv_output_field_delimiter, + record_delimiter: args.csv_output_record_delimiter, + quote_character: args.csv_output_quote, + quote_escape_character: args.csv_output_quote_escape, + quote_fields: args.csv_output_quote_fields.into(), + }, + json_input: SelectJsonInputOptions { + input_type: args.json_type.into(), + }, + json_output: SelectJsonOutputOptions { + record_delimiter: args.json_output_record_delimiter, + }, + scan_range: SelectScanRangeOptions { + start: args.scan_start, + end: args.scan_end, + }, + sse_customer: SelectSseCustomerOptions { + algorithm: args.sse_customer_algorithm, + key: args.sse_customer_key, + key_md5: args.sse_customer_key_md5, + }, }; let mut stdout = tokio::io::stdout(); @@ -145,6 +296,75 @@ pub async fn execute(args: SqlArgs, output_config: OutputConfig) -> ExitCode { } } +fn validate_select_args(args: &SqlArgs) -> std::result::Result<(), String> { + validate_single_byte( + "--csv-input-field-delimiter", + args.csv_input_field_delimiter.as_deref(), + )?; + validate_single_byte("--csv-input-quote", args.csv_input_quote.as_deref())?; + validate_single_byte( + "--csv-input-quote-escape", + args.csv_input_quote_escape.as_deref(), + )?; + validate_single_byte("--csv-input-comment", args.csv_input_comment.as_deref())?; + validate_single_byte( + "--csv-output-field-delimiter", + args.csv_output_field_delimiter.as_deref(), + )?; + validate_record_delimiter( + "--csv-output-record-delimiter", + args.csv_output_record_delimiter.as_deref(), + )?; + validate_single_byte("--csv-output-quote", args.csv_output_quote.as_deref())?; + validate_single_byte( + "--csv-output-quote-escape", + args.csv_output_quote_escape.as_deref(), + )?; + validate_scan_range_args(args) +} + +fn validate_single_byte(name: &str, value: Option<&str>) -> std::result::Result<(), String> { + if let Some(value) = value + && value.len() != 1 + { + return Err(format!("{name} must be exactly one byte")); + } + Ok(()) +} + +fn validate_record_delimiter(name: &str, value: Option<&str>) -> std::result::Result<(), String> { + if let Some(value) = value + && value.len() != 1 + && value != "\r\n" + { + return Err(format!("{name} must be exactly one byte or CRLF")); + } + Ok(()) +} + +fn validate_scan_range_args(args: &SqlArgs) -> std::result::Result<(), String> { + if args.scan_start.is_none() && args.scan_end.is_none() { + return Ok(()); + } + if matches!(args.input_format, InputFormatArg::Parquet) { + return Err("ScanRange is not supported for Parquet input".to_string()); + } + if matches!(args.input_format, InputFormatArg::Json) + && matches!(args.json_type, JsonTypeArg::Document) + { + return Err("ScanRange is not supported for JSON document input".to_string()); + } + if args.scan_start.is_some_and(|start| start < 0) || args.scan_end.is_some_and(|end| end < 0) { + return Err("ScanRange start and end must be non-negative".to_string()); + } + if let (Some(start), Some(end)) = (args.scan_start, args.scan_end) + && start > end + { + return Err("ScanRange start must not be greater than end".to_string()); + } + Ok(()) +} + fn exit_code_from_error(error: &rc_core::Error) -> ExitCode { ExitCode::from_i32(error.exit_code()).unwrap_or(ExitCode::GeneralError) } @@ -155,28 +375,70 @@ mod tests { use crate::output::OutputConfig; use rc_core::Error; - #[tokio::test] - async fn sql_empty_query_is_usage_error() { - let args = SqlArgs { - path: "a/b/c".to_string(), - query: " ".to_string(), + fn base_args(path: &str, query: &str) -> SqlArgs { + SqlArgs { + path: path.to_string(), + query: query.to_string(), input_format: InputFormatArg::Csv, output_format: OutputFormatArg::Csv, compression: CompressionArg::None, - }; + csv_file_header_info: CsvFileHeaderInfoArg::None, + csv_input_field_delimiter: None, + csv_input_quote: None, + csv_input_quote_escape: None, + csv_input_comment: None, + csv_output_field_delimiter: None, + csv_output_record_delimiter: None, + csv_output_quote: None, + csv_output_quote_escape: None, + csv_output_quote_fields: QuoteFieldsArg::AsNeeded, + json_type: JsonTypeArg::Lines, + json_output_record_delimiter: None, + scan_start: None, + scan_end: None, + sse_customer_algorithm: None, + sse_customer_key: None, + sse_customer_key_md5: None, + } + } + + #[tokio::test] + async fn sql_empty_query_is_usage_error() { + let args = base_args("a/b/c", " "); let code = execute(args, OutputConfig::default()).await; assert_eq!(code, ExitCode::UsageError); } #[tokio::test] async fn sql_invalid_object_path_is_usage_error() { - let args = SqlArgs { - path: "a/b".to_string(), - query: "SELECT 1".to_string(), - input_format: InputFormatArg::Csv, - output_format: OutputFormatArg::Csv, - compression: CompressionArg::None, - }; + let args = base_args("a/b", "SELECT 1"); + let code = execute(args, OutputConfig::default()).await; + assert_eq!(code, ExitCode::UsageError); + } + + #[tokio::test] + async fn sql_rejects_multi_byte_csv_delimiter() { + let mut args = base_args("a/b/c", "SELECT * FROM S3Object"); + args.csv_input_field_delimiter = Some("||".to_string()); + let code = execute(args, OutputConfig::default()).await; + assert_eq!(code, ExitCode::UsageError); + } + + #[tokio::test] + async fn sql_rejects_scan_range_for_json_document() { + let mut args = base_args("a/b/c", "SELECT * FROM S3Object"); + args.input_format = InputFormatArg::Json; + args.json_type = JsonTypeArg::Document; + args.scan_start = Some(0); + let code = execute(args, OutputConfig::default()).await; + assert_eq!(code, ExitCode::UsageError); + } + + #[tokio::test] + async fn sql_rejects_scan_start_after_end() { + let mut args = base_args("a/b/c", "SELECT * FROM S3Object"); + args.scan_start = Some(20); + args.scan_end = Some(10); let code = execute(args, OutputConfig::default()).await; assert_eq!(code, ExitCode::UsageError); } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 3fc526e..367cefb 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -38,7 +38,12 @@ pub use replication::{ ReplicationRule, ReplicationRuleStatus, }; pub use retry::{RetryBuilder, is_retryable_error, retry_with_backoff}; -pub use select::{SelectCompression, SelectInputFormat, SelectOptions, SelectOutputFormat}; +pub use select::{ + SelectCompression, SelectCsvFileHeaderInfo, SelectCsvInputOptions, SelectCsvOutputOptions, + SelectInputFormat, SelectJsonInputOptions, SelectJsonInputType, SelectJsonOutputOptions, + SelectOptions, SelectOutputFormat, SelectQuoteFields, SelectScanRangeOptions, + SelectSseCustomerOptions, +}; pub use traits::{ BucketNotification, Capabilities, ListOptions, ListResult, NotificationTarget, ObjectInfo, ObjectStore, ObjectVersion, ObjectVersionListResult, diff --git a/crates/core/src/select.rs b/crates/core/src/select.rs index adaede3..c384058 100644 --- a/crates/core/src/select.rs +++ b/crates/core/src/select.rs @@ -26,6 +26,78 @@ pub enum SelectCompression { Bzip2, } +/// CSV header handling for S3 Select input. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SelectCsvFileHeaderInfo { + #[default] + None, + Ignore, + Use, +} + +/// JSON input shape for S3 Select. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SelectJsonInputType { + #[default] + Lines, + Document, +} + +/// CSV output quote behavior. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SelectQuoteFields { + Always, + #[default] + AsNeeded, +} + +/// Supported CSV input serialization options. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectCsvInputOptions { + pub file_header_info: SelectCsvFileHeaderInfo, + pub field_delimiter: Option, + pub quote_character: Option, + pub quote_escape_character: Option, + pub comments: Option, +} + +/// Supported CSV output serialization options. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectCsvOutputOptions { + pub field_delimiter: Option, + pub record_delimiter: Option, + pub quote_character: Option, + pub quote_escape_character: Option, + pub quote_fields: SelectQuoteFields, +} + +/// Supported JSON input serialization options. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectJsonInputOptions { + pub input_type: SelectJsonInputType, +} + +/// Supported JSON output serialization options. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectJsonOutputOptions { + pub record_delimiter: Option, +} + +/// ScanRange request body parameters. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectScanRangeOptions { + pub start: Option, + pub end: Option, +} + +/// SSE-C parameters for encrypted objects. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SelectSseCustomerOptions { + pub algorithm: Option, + pub key: Option, + pub key_md5: Option, +} + /// Options for running an S3 Select query on one object. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SelectOptions { @@ -34,6 +106,12 @@ pub struct SelectOptions { pub input_format: SelectInputFormat, pub output_format: SelectOutputFormat, pub compression: SelectCompression, + pub csv_input: SelectCsvInputOptions, + pub csv_output: SelectCsvOutputOptions, + pub json_input: SelectJsonInputOptions, + pub json_output: SelectJsonOutputOptions, + pub scan_range: SelectScanRangeOptions, + pub sse_customer: SelectSseCustomerOptions, } impl Default for SelectOptions { @@ -43,6 +121,12 @@ impl Default for SelectOptions { input_format: SelectInputFormat::Csv, output_format: SelectOutputFormat::Csv, compression: SelectCompression::None, + csv_input: SelectCsvInputOptions::default(), + csv_output: SelectCsvOutputOptions::default(), + json_input: SelectJsonInputOptions::default(), + json_output: SelectJsonOutputOptions::default(), + scan_range: SelectScanRangeOptions::default(), + sse_customer: SelectSseCustomerOptions::default(), } } } diff --git a/crates/s3/src/select.rs b/crates/s3/src/select.rs index a39d0fd..bf3875b 100644 --- a/crates/s3/src/select.rs +++ b/crates/s3/src/select.rs @@ -2,7 +2,7 @@ use aws_sdk_s3::types::{ CompressionType, CsvInput, CsvOutput, ExpressionType, FileHeaderInfo, InputSerialization, - JsonInput, JsonOutput, JsonType, OutputSerialization, ParquetInput, QuoteFields, + JsonInput, JsonOutput, JsonType, OutputSerialization, ParquetInput, QuoteFields, ScanRange, SelectObjectContentEventStream, }; use aws_smithy_runtime_api::client::orchestrator::HttpResponse; @@ -10,8 +10,9 @@ use aws_smithy_runtime_api::client::result::SdkError; use aws_smithy_types::error::metadata::ProvideErrorMetadata; use aws_smithy_types::event_stream::RawMessage; use rc_core::{ - Error, RemotePath, Result, SelectCompression, SelectInputFormat, SelectOptions, - SelectOutputFormat, + Error, RemotePath, Result, SelectCompression, SelectCsvFileHeaderInfo, SelectInputFormat, + SelectJsonInputType, SelectOptions, SelectOutputFormat, + SelectQuoteFields as RcSelectQuoteFields, }; use tokio::io::{AsyncWrite, AsyncWriteExt}; @@ -23,20 +24,31 @@ pub async fn select_object_content( writer: &mut (dyn AsyncWrite + Send + Unpin), ) -> Result<()> { let input = build_input_serialization(options)?; - let output = build_output_serialization(options); + let output = build_output_serialization(options)?; // aws-sdk-s3 `SelectObjectContent` does not expose object `VersionId`; the current object is used. - let resp = client + let mut request = client .select_object_content() .bucket(&path.bucket) .key(&path.key) .expression(&options.expression) .expression_type(ExpressionType::Sql) .input_serialization(input) - .output_serialization(output) - .send() - .await - .map_err(map_select_initial_error)?; + .output_serialization(output); + if let Some(scan_range) = build_scan_range(options)? { + request = request.scan_range(scan_range); + } + if let Some(algorithm) = options.sse_customer.algorithm.as_deref() { + request = request.sse_customer_algorithm(algorithm); + } + if let Some(key) = options.sse_customer.key.as_deref() { + request = request.sse_customer_key(key); + } + if let Some(key_md5) = options.sse_customer.key_md5.as_deref() { + request = request.sse_customer_key_md5(key_md5); + } + + let resp = request.send().await.map_err(map_select_initial_error)?; let mut events = resp.payload; while let Some(ev) = events.recv().await.map_err(map_select_stream_error)? { @@ -62,6 +74,87 @@ fn compression_type(c: SelectCompression) -> CompressionType { } } +fn csv_file_header_info(info: SelectCsvFileHeaderInfo) -> FileHeaderInfo { + match info { + SelectCsvFileHeaderInfo::None => FileHeaderInfo::None, + SelectCsvFileHeaderInfo::Ignore => FileHeaderInfo::Ignore, + SelectCsvFileHeaderInfo::Use => FileHeaderInfo::Use, + } +} + +fn json_input_type(input_type: SelectJsonInputType) -> JsonType { + match input_type { + SelectJsonInputType::Lines => JsonType::Lines, + SelectJsonInputType::Document => JsonType::Document, + } +} + +fn quote_fields(quote_fields: RcSelectQuoteFields) -> QuoteFields { + match quote_fields { + RcSelectQuoteFields::Always => QuoteFields::Always, + RcSelectQuoteFields::AsNeeded => QuoteFields::Asneeded, + } +} + +fn build_scan_range(options: &SelectOptions) -> Result> { + let scan_range = &options.scan_range; + if scan_range.start.is_none() && scan_range.end.is_none() { + return Ok(None); + } + if matches!(options.input_format, SelectInputFormat::Parquet) { + return Err(Error::General( + "ScanRange is not supported for Parquet input.".to_string(), + )); + } + if matches!(options.input_format, SelectInputFormat::Json) + && matches!(options.json_input.input_type, SelectJsonInputType::Document) + { + return Err(Error::General( + "ScanRange is not supported for JSON document input.".to_string(), + )); + } + if scan_range.start.is_some_and(|start| start < 0) || scan_range.end.is_some_and(|end| end < 0) + { + return Err(Error::General( + "ScanRange start and end must be non-negative.".to_string(), + )); + } + if let (Some(start), Some(end)) = (scan_range.start, scan_range.end) + && start > end + { + return Err(Error::General( + "ScanRange start must not be greater than end.".to_string(), + )); + } + Ok(Some( + ScanRange::builder() + .set_start(scan_range.start) + .set_end(scan_range.end) + .build(), + )) +} + +fn validate_single_byte(name: &str, value: Option<&str>) -> Result<()> { + if let Some(value) = value + && value.len() != 1 + { + return Err(Error::General(format!("{name} must be exactly one byte."))); + } + Ok(()) +} + +fn validate_record_delimiter(name: &str, value: Option<&str>) -> Result<()> { + if let Some(value) = value + && value.len() != 1 + && value != "\r\n" + { + return Err(Error::General(format!( + "{name} must be exactly one byte or CRLF." + ))); + } + Ok(()) +} + fn build_input_serialization(options: &SelectOptions) -> Result { if matches!(options.input_format, SelectInputFormat::Parquet) && !matches!(options.compression, SelectCompression::None) @@ -75,14 +168,43 @@ fn build_input_serialization(options: &SelectOptions) -> Result { - let csv = CsvInput::builder() - .file_header_info(FileHeaderInfo::None) - .build(); + validate_single_byte( + "CSV input field delimiter", + options.csv_input.field_delimiter.as_deref(), + )?; + validate_single_byte( + "CSV input quote character", + options.csv_input.quote_character.as_deref(), + )?; + validate_single_byte( + "CSV input quote escape character", + options.csv_input.quote_escape_character.as_deref(), + )?; + validate_single_byte( + "CSV input comment character", + options.csv_input.comments.as_deref(), + )?; + let mut csv = CsvInput::builder() + .file_header_info(csv_file_header_info(options.csv_input.file_header_info)); + if let Some(delimiter) = options.csv_input.field_delimiter.as_deref() { + csv = csv.field_delimiter(delimiter); + } + if let Some(quote) = options.csv_input.quote_character.as_deref() { + csv = csv.quote_character(quote); + } + if let Some(escape) = options.csv_input.quote_escape_character.as_deref() { + csv = csv.quote_escape_character(escape); + } + if let Some(comments) = options.csv_input.comments.as_deref() { + csv = csv.comments(comments); + } + let csv = csv.build(); b = b.csv(csv); } SelectInputFormat::Json => { - // JSONL: one JSON object per line (S3 Select `Type=LINES`). - let json = JsonInput::builder().r#type(JsonType::Lines).build(); + let json = JsonInput::builder() + .r#type(json_input_type(options.json_input.input_type)) + .build(); b = b.json(json); } SelectInputFormat::Parquet => { @@ -93,21 +215,53 @@ fn build_input_serialization(options: &SelectOptions) -> Result OutputSerialization { +fn build_output_serialization(options: &SelectOptions) -> Result { let mut b = OutputSerialization::builder(); match options.output_format { SelectOutputFormat::Csv => { - let csv = CsvOutput::builder() - .quote_fields(QuoteFields::Asneeded) - .build(); + validate_single_byte( + "CSV output field delimiter", + options.csv_output.field_delimiter.as_deref(), + )?; + validate_record_delimiter( + "CSV output record delimiter", + options.csv_output.record_delimiter.as_deref(), + )?; + validate_single_byte( + "CSV output quote character", + options.csv_output.quote_character.as_deref(), + )?; + validate_single_byte( + "CSV output quote escape character", + options.csv_output.quote_escape_character.as_deref(), + )?; + let mut csv = + CsvOutput::builder().quote_fields(quote_fields(options.csv_output.quote_fields)); + if let Some(delimiter) = options.csv_output.field_delimiter.as_deref() { + csv = csv.field_delimiter(delimiter); + } + if let Some(record_delimiter) = options.csv_output.record_delimiter.as_deref() { + csv = csv.record_delimiter(record_delimiter); + } + if let Some(quote) = options.csv_output.quote_character.as_deref() { + csv = csv.quote_character(quote); + } + if let Some(escape) = options.csv_output.quote_escape_character.as_deref() { + csv = csv.quote_escape_character(escape); + } + let csv = csv.build(); b = b.csv(csv); } SelectOutputFormat::Json => { - let json = JsonOutput::builder().build(); + let mut json = JsonOutput::builder(); + if let Some(record_delimiter) = options.json_output.record_delimiter.as_deref() { + json = json.record_delimiter(record_delimiter); + } + let json = json.build(); b = b.json(json); } } - b.build() + Ok(b.build()) } fn resolve_http_service_error_code<'a, E: ProvideErrorMetadata + ?Sized>( @@ -205,10 +359,16 @@ fn classify_aws_code_missing_metadata(text: &str) -> Error { #[cfg(test)] mod tests { - use super::{build_input_serialization, build_output_serialization, classify_aws_code}; + use super::{ + build_input_serialization, build_output_serialization, build_scan_range, classify_aws_code, + }; use aws_sdk_s3::types::{CompressionType, FileHeaderInfo, JsonType, QuoteFields}; use rc_core::Error; - use rc_core::{SelectCompression, SelectInputFormat, SelectOptions, SelectOutputFormat}; + use rc_core::{ + SelectCompression, SelectCsvInputOptions, SelectCsvOutputOptions, SelectInputFormat, + SelectJsonInputOptions, SelectJsonInputType, SelectOptions, SelectOutputFormat, + SelectScanRangeOptions, + }; #[test] fn classify_maps_no_such_key() { @@ -301,6 +461,7 @@ mod tests { input_format: SelectInputFormat::Parquet, output_format: SelectOutputFormat::Csv, compression: SelectCompression::Gzip, + ..SelectOptions::default() }; let error = build_input_serialization(&options) @@ -315,6 +476,7 @@ mod tests { input_format: SelectInputFormat::Parquet, output_format: SelectOutputFormat::Csv, compression: SelectCompression::None, + ..SelectOptions::default() }; build_input_serialization(&options).expect("parquet without whole-object compression"); @@ -327,6 +489,7 @@ mod tests { input_format: SelectInputFormat::Csv, output_format: SelectOutputFormat::Csv, compression: SelectCompression::Bzip2, + ..SelectOptions::default() }; let input = build_input_serialization(&options).expect("csv input serialization"); @@ -345,6 +508,7 @@ mod tests { input_format: SelectInputFormat::Json, output_format: SelectOutputFormat::Json, compression: SelectCompression::Gzip, + ..SelectOptions::default() }; let input = build_input_serialization(&options).expect("json input serialization"); @@ -363,8 +527,9 @@ mod tests { input_format: SelectInputFormat::Csv, output_format: SelectOutputFormat::Csv, compression: SelectCompression::None, + ..SelectOptions::default() }; - let csv_output = build_output_serialization(&csv_options); + let csv_output = build_output_serialization(&csv_options).expect("csv output"); let csv = csv_output.csv().expect("csv output is configured"); assert_eq!(csv.quote_fields(), Some(&QuoteFields::Asneeded)); assert!(csv_output.json().is_none()); @@ -374,9 +539,93 @@ mod tests { input_format: SelectInputFormat::Json, output_format: SelectOutputFormat::Json, compression: SelectCompression::None, + ..SelectOptions::default() }; - let json_output = build_output_serialization(&json_options); + let json_output = build_output_serialization(&json_options).expect("json output"); assert!(json_output.json().is_some()); assert!(json_output.csv().is_none()); } + + #[test] + fn csv_input_rejects_multi_byte_delimiter() { + let options = SelectOptions { + expression: "SELECT * FROM S3Object".to_string(), + csv_input: SelectCsvInputOptions { + field_delimiter: Some("||".to_string()), + ..SelectCsvInputOptions::default() + }, + ..SelectOptions::default() + }; + + let error = build_input_serialization(&options) + .expect_err("multi-byte CSV input delimiter should be rejected"); + assert!(matches!(error, Error::General(msg) if msg.contains("field delimiter"))); + } + + #[test] + fn csv_output_record_delimiter_allows_crlf() { + let options = SelectOptions { + expression: "SELECT * FROM S3Object".to_string(), + csv_output: SelectCsvOutputOptions { + record_delimiter: Some("\r\n".to_string()), + ..SelectCsvOutputOptions::default() + }, + ..SelectOptions::default() + }; + + let output = build_output_serialization(&options).expect("CRLF record delimiter"); + let csv = output.csv().expect("csv output is configured"); + assert_eq!(csv.record_delimiter(), Some("\r\n")); + } + + #[test] + fn csv_output_rejects_multi_byte_record_delimiter() { + let options = SelectOptions { + expression: "SELECT * FROM S3Object".to_string(), + csv_output: SelectCsvOutputOptions { + record_delimiter: Some("||".to_string()), + ..SelectCsvOutputOptions::default() + }, + ..SelectOptions::default() + }; + + let error = build_output_serialization(&options) + .expect_err("multi-byte CSV output record delimiter should be rejected"); + assert!(matches!(error, Error::General(msg) if msg.contains("record delimiter"))); + } + + #[test] + fn scan_range_rejects_json_document() { + let options = SelectOptions { + expression: "SELECT * FROM S3Object".to_string(), + input_format: SelectInputFormat::Json, + json_input: SelectJsonInputOptions { + input_type: SelectJsonInputType::Document, + }, + scan_range: SelectScanRangeOptions { + start: Some(0), + end: None, + }, + ..SelectOptions::default() + }; + + let error = + build_scan_range(&options).expect_err("scan range should reject JSON document input"); + assert!(matches!(error, Error::General(msg) if msg.contains("JSON document"))); + } + + #[test] + fn scan_range_rejects_start_after_end() { + let options = SelectOptions { + expression: "SELECT * FROM S3Object".to_string(), + scan_range: SelectScanRangeOptions { + start: Some(20), + end: Some(10), + }, + ..SelectOptions::default() + }; + + let error = build_scan_range(&options).expect_err("start after end should be rejected"); + assert!(matches!(error, Error::General(msg) if msg.contains("greater than end"))); + } }