diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs
index 4cab51acf1..5e223c12e7 100644
--- a/rust/fory-core/src/buffer.rs
+++ b/rust/fory-core/src/buffer.rs
@@ -499,6 +499,35 @@ impl<'a> Writer<'a> {
             self.write_u64(combined);
         }
     }
+
+    /// Calls `reserve(max_bytes)` and exposes the spare capacity for raw writes.
+    ///
+    /// Returns a pointer to the first byte past the current length, together with that
+    /// length. The buffer length is not changed; commit the written bytes with
+    /// [`Writer::finish_write`].
+    ///
+    /// # Safety
+    /// The caller must write at most `max_bytes` bytes through the returned pointer and must
+    /// not call any other method on this writer, which could move the buffer, until
+    /// `finish_write` has been called.
+    #[inline(always)]
+    pub unsafe fn prepare_write(&mut self, max_bytes: usize) -> (*mut u8, usize) {
+        self.reserve(max_bytes);
+        let len = self.bf.len();
+        (self.bf.as_mut_ptr().add(len), len)
+    }
+
+    /// Sets the buffer length to `new_len`, committing bytes written through the pointer
+    /// returned by [`Writer::prepare_write`].
+    ///
+    /// # Safety
+    /// `new_len` must not exceed the buffer capacity, and every byte below `new_len` must
+    /// have been initialized.
+    #[inline(always)]
+    pub unsafe fn finish_write(&mut self, new_len: usize) {
+        debug_assert!(new_len <= self.bf.capacity());
+        self.bf.set_len(new_len);
+    }
 }
 
 #[derive(Default)]
diff --git a/rust/fory-core/src/lib.rs b/rust/fory-core/src/lib.rs
index 976a760af6..14128c5a69 100644
--- a/rust/fory-core/src/lib.rs
+++ b/rust/fory-core/src/lib.rs
@@ -187,6 +187,7 @@ pub mod row;
 pub mod serializer;
 pub mod types;
 pub use float16::float16 as Float16;
+pub mod unsafe_util;
 pub mod util;
 
 // Re-export paste for use in macros
diff --git a/rust/fory-core/src/unsafe_util.rs b/rust/fory-core/src/unsafe_util.rs
new file mode 100644
index 0000000000..5ac089089a
--- /dev/null
+++ b/rust/fory-core/src/unsafe_util.rs
@@ -0,0 +1,377 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Raw-pointer encoders used by the fast path of generated serialization code.
+//!
+//! Each `put_*_at` function writes one value starting at `ptr` and returns how many bytes
+//! it wrote. Multi-byte values are stored little-endian through unaligned writes, so `ptr`
+//! needs no particular alignment.
+
+use crate::float16::float16;
+
+/// Writes `value` as one byte (`0` or `1`) at `ptr` and returns the byte count, 1.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 1 byte.
+#[inline(always)]
+pub unsafe fn put_bool_at(ptr: *mut u8, value: bool) -> usize {
+    *ptr = value as u8;
+    1
+}
+
+/// Writes `value` as one byte at `ptr` and returns the byte count, 1.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 1 byte.
+#[inline(always)]
+pub unsafe fn put_i8_at(ptr: *mut u8, value: i8) -> usize {
+    *ptr = value as u8;
+    1
+}
+
+/// Writes `value` at `ptr` and returns the byte count, 1.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 1 byte.
+#[inline(always)]
+pub unsafe fn put_u8_at(ptr: *mut u8, value: u8) -> usize {
+    *ptr = value;
+    1
+}
+
+/// Writes `value` as 2 little-endian bytes at `ptr` and returns the byte count, 2.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 2 bytes.
+#[inline(always)]
+pub unsafe fn put_i16_at(ptr: *mut u8, value: i16) -> usize {
+    put_u16_at(ptr, value as u16)
+}
+
+/// Writes `value` as 2 little-endian bytes at `ptr` and returns the byte count, 2.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 2 bytes.
+#[inline(always)]
+pub unsafe fn put_u16_at(ptr: *mut u8, value: u16) -> usize {
+    (ptr as *mut u16).write_unaligned(value.to_le());
+    2
+}
+
+/// Writes `value` as 4 little-endian bytes at `ptr` and returns the byte count, 4.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 4 bytes.
+#[inline(always)]
+pub unsafe fn put_i32_at(ptr: *mut u8, value: i32) -> usize {
+    put_u32_at(ptr, value as u32)
+}
+
+/// Writes `value` as 4 little-endian bytes at `ptr` and returns the byte count, 4.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 4 bytes.
+#[inline(always)]
+pub unsafe fn put_u32_at(ptr: *mut u8, value: u32) -> usize {
+    (ptr as *mut u32).write_unaligned(value.to_le());
+    4
+}
+
+/// Writes `value` as 8 little-endian bytes at `ptr` and returns the byte count, 8.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 8 bytes.
+#[inline(always)]
+pub unsafe fn put_i64_at(ptr: *mut u8, value: i64) -> usize {
+    put_u64_at(ptr, value as u64)
+}
+
+/// Writes `value` as 8 little-endian bytes at `ptr` and returns the byte count, 8.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 8 bytes.
+#[inline(always)]
+pub unsafe fn put_u64_at(ptr: *mut u8, value: u64) -> usize {
+    (ptr as *mut u64).write_unaligned(value.to_le());
+    8
+}
+
+/// Writes the bit pattern of `value` as 2 little-endian bytes and returns the byte count, 2.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 2 bytes.
+#[inline(always)]
+pub unsafe fn put_f16_at(ptr: *mut u8, value: float16) -> usize {
+    put_u16_at(ptr, value.to_bits())
+}
+
+/// Writes the bit pattern of `value` as 4 little-endian bytes and returns the byte count, 4.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 4 bytes.
+#[inline(always)]
+pub unsafe fn put_f32_at(ptr: *mut u8, value: f32) -> usize {
+    // Route through the integer writer so the byte order is little-endian on every target.
+    put_u32_at(ptr, value.to_bits())
+}
+
+/// Writes the bit pattern of `value` as 8 little-endian bytes and returns the byte count, 8.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 8 bytes.
+#[inline(always)]
+pub unsafe fn put_f64_at(ptr: *mut u8, value: f64) -> usize {
+    // Route through the integer writer so the byte order is little-endian on every target.
+    put_u64_at(ptr, value.to_bits())
+}
+
+/// Writes `value` as 16 little-endian bytes at `ptr` and returns the byte count, 16.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 16 bytes.
+#[inline(always)]
+pub unsafe fn put_i128_at(ptr: *mut u8, value: i128) -> usize {
+    put_u128_at(ptr, value as u128)
+}
+
+/// Writes `value` as 16 little-endian bytes at `ptr` and returns the byte count, 16.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 16 bytes.
+#[inline(always)]
+pub unsafe fn put_u128_at(ptr: *mut u8, value: u128) -> usize {
+    (ptr as *mut u128).write_unaligned(value.to_le());
+    16
+}
+
+/// Zigzag-encodes `value`, writes it via [`put_var_uint32_at`] and returns 1 to 5.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 5 bytes.
+#[inline(always)]
+pub unsafe fn put_varint32_at(ptr: *mut u8, value: i32) -> usize {
+    let zigzag = ((value as i64) << 1) ^ ((value as i64) >> 31);
+    put_var_uint32_at(ptr, zigzag as u32)
+}
+
+/// Writes `value` as a variable-length integer of 1 to 5 bytes and returns the byte count.
+///
+/// Each byte carries 7 value bits, lowest group first; every byte that is followed by
+/// another has the continuation bit `0x80` set.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 5 bytes.
+#[inline(always)]
+pub unsafe fn put_var_uint32_at(ptr: *mut u8, value: u32) -> usize {
+    if value < 0x80 {
+        *ptr = value as u8;
+        1
+    } else if value < 0x4000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (value >> 7) as u8;
+        (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2]));
+        2
+    } else if value < 0x200000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (value >> 14) as u8;
+        (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2]));
+        *ptr.add(2) = u3;
+        3
+    } else if value < 0x10000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (value >> 21) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        4
+    } else {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+        let u5 = (value >> 28) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        *ptr.add(4) = u5;
+        5
+    }
+}
+
+/// Zigzag-encodes `value`, writes it via [`put_var_uint64_at`] and returns 1 to 9.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 9 bytes.
+#[inline(always)]
+pub unsafe fn put_varint64_at(ptr: *mut u8, value: i64) -> usize {
+    let zigzag = ((value << 1) ^ (value >> 63)) as u64;
+    put_var_uint64_at(ptr, zigzag)
+}
+
+/// Writes `value` as a variable-length integer of 1 to 9 bytes and returns the byte count.
+///
+/// Bytes carry 7 value bits each, lowest group first, with the continuation bit `0x80` set
+/// on every byte that is followed by another; a ninth byte holds the top 8 bits of `value`.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 9 bytes.
+#[inline(always)]
+pub unsafe fn put_var_uint64_at(ptr: *mut u8, value: u64) -> usize {
+    if value < 0x80 {
+        *ptr = value as u8;
+        return 1;
+    }
+    if value < 0x4000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (value >> 7) as u8;
+        (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2]));
+        return 2;
+    }
+    if value < 0x200000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (value >> 14) as u8;
+        (ptr as *mut u16).write_unaligned(u16::from_ne_bytes([u1, u2]));
+        *ptr.add(2) = u3;
+        return 3;
+    }
+    if value < 0x10000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (value >> 21) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        return 4;
+    }
+    if value < 0x800000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+        let u5 = (value >> 28) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        *ptr.add(4) = u5;
+        return 5;
+    }
+    if value < 0x40000000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+        let u5 = (((value >> 28) as u8) & 0x7F) | 0x80;
+        let u6 = (value >> 35) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        (ptr.add(4) as *mut u16).write_unaligned(u16::from_ne_bytes([u5, u6]));
+        return 6;
+    }
+    if value < 0x2000000000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+        let u5 = (((value >> 28) as u8) & 0x7F) | 0x80;
+        let u6 = (((value >> 35) as u8) & 0x7F) | 0x80;
+        let u7 = (value >> 42) as u8;
+        (ptr as *mut u32).write_unaligned(u32::from_ne_bytes([u1, u2, u3, u4]));
+        (ptr.add(4) as *mut u16).write_unaligned(u16::from_ne_bytes([u5, u6]));
+        *ptr.add(6) = u7;
+        return 7;
+    }
+    if value < 0x100000000000000 {
+        let u1 = ((value as u8) & 0x7F) | 0x80;
+        let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+        let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+        let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+        let u5 = (((value >> 28) as u8) & 0x7F) | 0x80;
+        let u6 = (((value >> 35) as u8) & 0x7F) | 0x80;
+        let u7 = (((value >> 42) as u8) & 0x7F) | 0x80;
+        let u8v = (value >> 49) as u8;
+        (ptr as *mut u64).write_unaligned(u64::from_ne_bytes([u1, u2, u3, u4, u5, u6, u7, u8v]));
+        return 8;
+    }
+    let u1 = ((value as u8) & 0x7F) | 0x80;
+    let u2 = (((value >> 7) as u8) & 0x7F) | 0x80;
+    let u3 = (((value >> 14) as u8) & 0x7F) | 0x80;
+    let u4 = (((value >> 21) as u8) & 0x7F) | 0x80;
+    let u5 = (((value >> 28) as u8) & 0x7F) | 0x80;
+    let u6 = (((value >> 35) as u8) & 0x7F) | 0x80;
+    let u7 = (((value >> 42) as u8) & 0x7F) | 0x80;
+    let u8v = (((value >> 49) as u8) & 0x7F) | 0x80;
+    let u9 = (value >> 56) as u8;
+    (ptr as *mut u64).write_unaligned(u64::from_ne_bytes([u1, u2, u3, u4, u5, u6, u7, u8v]));
+    *ptr.add(8) = u9;
+    9
+}
+
+/// Writes `value` in tagged form and returns the byte count (4 or 9).
+///
+/// Values within `i32::MIN / 2 ..= i32::MAX / 2` are shifted left by one and stored as 4
+/// little-endian bytes (low bit 0); all others are stored as the marker byte `0b1`
+/// followed by the 8 little-endian bytes of `value`.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 9 bytes.
+#[inline(always)]
+pub unsafe fn put_tagged_i64_at(ptr: *mut u8, value: i64) -> usize {
+    const HALF_MIN: i64 = i32::MIN as i64 / 2;
+    const HALF_MAX: i64 = i32::MAX as i64 / 2;
+    if (HALF_MIN..=HALF_MAX).contains(&value) {
+        let v = (value as i32) << 1;
+        (ptr as *mut i32).write_unaligned(v.to_le());
+        4
+    } else {
+        *ptr = 0b1;
+        (ptr.add(1) as *mut i64).write_unaligned(value.to_le());
+        9
+    }
+}
+
+/// Writes `value` in tagged form and returns the byte count (4 or 9).
+///
+/// Values up to `i32::MAX` are shifted left by one and stored as 4 little-endian bytes
+/// (low bit 0); larger values are stored as the marker byte `0b1` followed by the 8
+/// little-endian bytes of `value`.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 9 bytes.
+#[inline(always)]
+pub unsafe fn put_tagged_u64_at(ptr: *mut u8, value: u64) -> usize {
+    if value <= i32::MAX as u64 {
+        let v = (value as u32) << 1;
+        (ptr as *mut u32).write_unaligned(v.to_le());
+        4
+    } else {
+        *ptr = 0b1;
+        (ptr.add(1) as *mut u64).write_unaligned(value.to_le());
+        9
+    }
+}
+
+/// Writes `value` using the encoding selected by the target's pointer width and returns
+/// the byte count: fixed 2 bytes on 16-bit targets, otherwise a variable-length integer.
+///
+/// # Safety
+/// `ptr` must be valid for writes of 9 bytes.
+#[inline(always)]
+pub unsafe fn put_usize_at(ptr: *mut u8, value: usize) -> usize {
+    const SIZE: usize = std::mem::size_of::<usize>();
+    match SIZE {
+        2 => put_u16_at(ptr, value as u16),
+        4 => put_var_uint32_at(ptr, value as u32),
+        8 => put_var_uint64_at(ptr, value as u64),
+        _ => unreachable!(),
+    }
+}
diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs
index 07c40b4959..5c998ab362 100644
--- a/rust/fory-derive/src/object/util.rs
+++ b/rust/fory-derive/src/object/util.rs
@@ -835,6 +835,121 @@ pub(super) fn get_primitive_writer_method_with_encoding(
     get_primitive_writer_method(type_name)
 }
 
+/// Returns an upper bound on the bytes the matching `put_*_at` helper writes for a field.
+///
+/// Fixed-width encodings report their exact size; variable-length and tagged encodings
+/// report their worst case (5 bytes for 32-bit values, 9 bytes for 64-bit values).
+///
+/// Panics if `type_name` is not one of the primitive type names handled below.
+pub(super) fn get_max_primitive_bytes(
+    type_name: &str,
+    meta: &super::field_meta::ForyFieldMeta,
+) -> usize {
+    use fory_core::types::TypeId;
+
+    if type_name == "i32" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 {
+                return 4;
+            }
+        }
+        return 5;
+    }
+
+    if type_name == "u32" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT32 as i16 {
+                return 4;
+            }
+        }
+        return 5;
+    }
+
+    if type_name == "u64" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT64 as i16 {
+                return 8;
+            } else if type_id == TypeId::TAGGED_UINT64 as i16 {
+                return 9;
+            }
+        }
+        return 9;
+    }
+
+    if type_name == "i64" {
+        return 9;
+    }
+
+    match type_name {
+        "bool" | "i8" | "u8" => 1,
+        "i16" | "u16" | "float16" => 2,
+        "f32" => 4,
+        "f64" => 8,
+        "i128" | "u128" => 16,
+        "isize" | "usize" => 9,
+        _ => panic!("unsupported primitive type for max bytes: {type_name}"),
+    }
+}
+
+/// Returns the name of the `fory_core::unsafe_util` function that encodes a field of
+/// `type_name`, honoring the encoding selected by `meta.type_id`.
+///
+/// Panics if `type_name` has no `put_*_at` helper.
+pub(super) fn get_put_at_method_with_encoding(
+    type_name: &str,
+    meta: &super::field_meta::ForyFieldMeta,
+) -> &'static str {
+    use fory_core::types::TypeId;
+
+    if type_name == "i32" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 {
+                return "put_i32_at";
+            }
+        }
+        return "put_varint32_at";
+    }
+
+    if type_name == "u32" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT32 as i16 {
+                return "put_u32_at";
+            }
+        }
+        return "put_var_uint32_at";
+    }
+
+    if type_name == "u64" {
+        if let Some(type_id) = meta.type_id {
+            if type_id == TypeId::INT32 as i16 || type_id == TypeId::UINT64 as i16 {
+                return "put_u64_at";
+            } else if type_id == TypeId::TAGGED_UINT64 as i16 {
+                return "put_tagged_u64_at";
+            }
+        }
+        return "put_var_uint64_at";
+    }
+
+    if type_name == "i64" {
+        return "put_varint64_at";
+    }
+
+    match type_name {
+        "bool" => "put_bool_at",
+        "i8" => "put_i8_at",
+        "u8" => "put_u8_at",
+        "i16" => "put_i16_at",
+        "u16" => "put_u16_at",
+        "f32" => "put_f32_at",
+        "f64" => "put_f64_at",
+        "float16" => "put_f16_at",
+        "i128" => "put_i128_at",
+        "u128" => "put_u128_at",
+        "usize" => "put_usize_at",
+        _ => panic!("unsupported primitive type for put_at: {type_name}"),
+    }
+}
+
 /// Get the reader method name for a primitive numeric type
 /// Panics if type_name is not a primitive type
 pub(super) fn get_primitive_reader_method(type_name: &str) -> &'static str {
diff --git a/rust/fory-derive/src/object/write.rs b/rust/fory-derive/src/object/write.rs
index 8300e8784b..1a4031c33e 100644
--- a/rust/fory-derive/src/object/write.rs
+++ b/rust/fory-derive/src/object/write.rs
@@ -19,10 +19,10 @@ use super::field_meta::parse_field_meta;
 use super::util::{
     classify_trait_object_field, create_wrapper_types_arc, create_wrapper_types_rc,
     determine_field_ref_mode, extract_type_name, gen_struct_version_hash_ts, get_field_accessor,
-    get_field_name, get_filtered_source_fields_iter, get_option_inner_primitive_name,
-    get_primitive_writer_method_with_encoding, get_struct_name, get_type_id_by_type_ast,
-    is_debug_enabled, is_direct_primitive_type, is_option_encoding_primitive, FieldRefMode,
-    StructField,
+    get_field_name, get_filtered_source_fields_iter, get_max_primitive_bytes,
+    get_option_inner_primitive_name, get_primitive_writer_method_with_encoding,
+    get_put_at_method_with_encoding, get_struct_name, get_type_id_by_type_ast, is_debug_enabled,
+    is_direct_primitive_type, is_option_encoding_primitive, FieldRefMode, StructField,
 };
 use crate::util::SourceField;
 use fory_core::types::TypeId;
@@ -280,15 +280,10 @@ fn gen_write_field_impl(
                         <#ty as fory_core::Serializer>::fory_write_data(&#value_ts, context)?;
                     }
                 } else {
-                    // Numeric primitives: use direct buffer methods
-                    // For u32/u64, consider encoding attributes
                     let writer_method =
                         get_primitive_writer_method_with_encoding(&type_name, &meta);
                     let writer_ident =
                         syn::Ident::new(writer_method, proc_macro2::Span::call_site());
-                    // For primitives:
-                    // - use_self=true: #value_ts is `self.field`, which is T (copy happens automatically)
-                    // - use_self=false: #value_ts is `field` from pattern match on &self, which is &T
                     let value_expr = if use_self {
                         quote! { #value_ts }
                     } else {
@@ -359,18 +354,92 @@ fn gen_write_field_impl(
 
 pub fn gen_write_data(source_fields: &[SourceField<'_>]) -> TokenStream {
     let fields: Vec<&Field> = source_fields.iter().map(|sf| sf.field).collect();
-    let write_fields_ts: Vec<_> = get_filtered_source_fields_iter(source_fields)
-        .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true))
-        .collect();
-
+    let filtered: Vec<_> = get_filtered_source_fields_iter(source_fields).collect();
     let version_hash_ts = gen_struct_version_hash_ts(&fields);
-    quote! {
-        if context.is_check_struct_version() {
-            let version_hash: i32 = #version_hash_ts;
-            context.writer.write_i32(version_hash);
+
+    // Leading fields that can be written back-to-back through a raw pointer: ref mode
+    // `FieldRefMode::None` and a numeric primitive type with a `put_*_at` helper. `String`
+    // and `isize` have no such helper, so they end the run and use the per-field path.
+    // The fast path is skipped entirely when debug support is enabled.
+    let fast_count = if !is_debug_enabled() {
+        filtered
+            .iter()
+            .take_while(|sf| {
+                let type_name = extract_type_name(&sf.field.ty);
+                determine_field_ref_mode(sf.field) == FieldRefMode::None
+                    && is_direct_primitive_type(&sf.field.ty)
+                    && type_name != "String"
+                    && type_name != "isize"
+            })
+            .count()
+    } else {
+        0
+    };
+
+    if fast_count > 0 {
+        let (fast, rest) = filtered.split_at(fast_count);
+
+        // Worst-case encoded size of the whole run, requested from the writer once.
+        let max_bytes: usize = fast
+            .iter()
+            .map(|sf| {
+                let tn = extract_type_name(&sf.field.ty);
+                let meta = parse_field_meta(sf.field).unwrap_or_default();
+                get_max_primitive_bytes(&tn, &meta)
+            })
+            .sum();
+
+        // One `offset += put_*_at(ptr.add(offset), field)` statement per fast field.
+        let put_stmts: Vec<TokenStream> = fast
+            .iter()
+            .map(|sf| {
+                let value_ts = get_field_accessor(sf.field, sf.original_index, true);
+                let tn = extract_type_name(&sf.field.ty);
+                let meta = parse_field_meta(sf.field).unwrap_or_default();
+                let method = get_put_at_method_with_encoding(&tn, &meta);
+                let method_ident = syn::Ident::new(method, proc_macro2::Span::call_site());
+                quote! {
+                    offset += fory_core::unsafe_util::#method_ident(ptr.add(offset), #value_ts);
+                }
+            })
+            .collect();
+
+        let remaining_ts: Vec<_> = rest
+            .iter()
+            .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true))
+            .collect();
+
+        quote! {
+            if context.is_check_struct_version() {
+                let version_hash: i32 = #version_hash_ts;
+                context.writer.write_i32(version_hash);
+            }
+            // SAFETY: `prepare_write` is asked for `max_bytes`, the summed worst case of the
+            // `put_*_at` calls below, and the writer is not touched again until
+            // `finish_write` commits exactly the bytes written.
+            unsafe {
+                let (ptr, base_len) = context.writer.prepare_write(#max_bytes);
+                let mut offset = 0usize;
+                #(#put_stmts)*
+                context.writer.finish_write(base_len + offset);
+            }
+            #(#remaining_ts)*
+            Ok(())
+        }
+    } else {
+        let write_fields_ts: Vec<_> = filtered
+            .iter()
+            .map(|sf| gen_write_field_with_index(sf.field, sf.original_index, true))
+            .collect();
+
+        quote! {
+            if context.is_check_struct_version() {
+                let version_hash: i32 = #version_hash_ts;
+                context.writer.write_i32(version_hash);
+            }
+            #(#write_fields_ts)*
+            Ok(())
         }
-        #(#write_fields_ts)*
-        Ok(())
     }
 }
 