Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 42 additions & 38 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,9 +974,11 @@ fn append_struct_fields_field_major(
let num_rows = row_end - row_start;
let num_fields = fields.len();

// First pass: Build struct validity and collect which structs are null
// We use a Vec<bool> for simplicity; could use a bitset for better memory
// First pass: Build struct validity and collect nested struct addresses.
// This reads each parent row once, avoiding N*F re-reads in the field loop.
let mut struct_is_null = Vec::with_capacity(num_rows);
let mut nested_addrs: Vec<jlong> = Vec::with_capacity(num_rows);
let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);

for i in row_start..row_end {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
Expand All @@ -986,23 +988,27 @@ fn append_struct_fields_field_major(

if is_null {
struct_builder.append_null();
nested_addrs.push(0);
nested_sizes.push(0);
} else {
struct_builder.append(true);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_addrs.push(nested_row.get_row_addr());
nested_sizes.push(nested_row.get_row_size());
}
}

// Helper macro for processing primitive fields
// Helper macro for processing primitive fields using pre-collected addresses
macro_rules! process_field {
($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx);
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for (row_idx, i) in (row_start..row_end).enumerate() {
for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
// Struct is null, field is also null
field_builder.append_null();
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);

if nested_row.is_null_at($field_idx) {
field_builder.append_null();
Expand Down Expand Up @@ -1058,13 +1064,13 @@ fn append_struct_fields_field_major(
}
DataType::Binary => {
let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx);
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for (row_idx, i) in (row_start..row_end).enumerate() {
for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
field_builder.append_null();
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);

if nested_row.is_null_at(field_idx) {
field_builder.append_null();
Expand All @@ -1076,13 +1082,13 @@ fn append_struct_fields_field_major(
}
DataType::Utf8 => {
let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx);
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for (row_idx, i) in (row_start..row_end).enumerate() {
for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
field_builder.append_null();
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);

if nested_row.is_null_at(field_idx) {
field_builder.append_null();
Expand All @@ -1096,13 +1102,13 @@ fn append_struct_fields_field_major(
let p = *p;
let field_builder =
get_field_builder!(struct_builder, Decimal128Builder, field_idx);
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for (row_idx, i) in (row_start..row_end).enumerate() {
for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
field_builder.append_null();
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);

if nested_row.is_null_at(field_idx) {
field_builder.append_null();
Expand All @@ -1116,57 +1122,55 @@ fn append_struct_fields_field_major(
DataType::Struct(nested_fields) => {
let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx);

// Collect nested struct addresses and sizes in one pass, building validity
let mut nested_addresses: Vec<jlong> = Vec::with_capacity(num_rows);
let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
// Collect nested-nested struct addresses and sizes in one pass
let mut nested_nested_addrs: Vec<jlong> = Vec::with_capacity(num_rows);
let mut nested_nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
let mut nested_is_null: Vec<bool> = Vec::with_capacity(num_rows);
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for (row_idx, i) in (row_start..row_end).enumerate() {
for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
// Parent struct is null, nested struct is also null
nested_builder.append_null();
nested_is_null.push(true);
nested_addresses.push(0);
nested_sizes.push(0);
nested_nested_addrs.push(0);
nested_nested_sizes.push(0);
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let parent_struct = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);

if parent_struct.is_null_at(field_idx) {
if nested_row.is_null_at(field_idx) {
nested_builder.append_null();
nested_is_null.push(true);
nested_addresses.push(0);
nested_sizes.push(0);
nested_nested_addrs.push(0);
nested_nested_sizes.push(0);
} else {
nested_builder.append(true);
nested_is_null.push(false);
// Get nested struct address and size
let nested_row =
parent_struct.get_struct(field_idx, nested_fields.len());
nested_addresses.push(nested_row.get_row_addr());
nested_sizes.push(nested_row.get_row_size());
let inner = nested_row.get_struct(field_idx, nested_fields.len());
nested_nested_addrs.push(inner.get_row_addr());
nested_nested_sizes.push(inner.get_row_size());
}
}
}

// Recursively process nested struct fields in field-major order
append_nested_struct_fields_field_major(
&nested_addresses,
&nested_sizes,
&nested_nested_addrs,
&nested_nested_sizes,
&nested_is_null,
nested_builder,
nested_fields,
)?;
}
// For list and map, fall back to append_field since they have variable-length elements
dt @ (DataType::List(_) | DataType::Map(_, _)) => {
for (row_idx, i) in (row_start..row_end).enumerate() {
let mut nested_row = SparkUnsafeRow::new_with_num_fields(num_fields);

for row_idx in 0..num_rows {
if struct_is_null[row_idx] {
let null_row = SparkUnsafeRow::default();
append_field(dt, struct_builder, &null_row, field_idx)?;
} else {
read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
let nested_row = parent_row.get_struct(column_idx, num_fields);
nested_row.point_to(nested_addrs[row_idx], nested_sizes[row_idx]);
append_field(dt, struct_builder, &nested_row, field_idx)?;
}
}
Expand Down
Loading