diff --git a/cold-string/README.md b/cold-string/README.md index 16969fb..a72f237 100644 --- a/cold-string/README.md +++ b/cold-string/README.md @@ -42,16 +42,23 @@ assert_eq!(size_of::(), size_of::()); assert_eq!(align_of::(), 1); assert_eq!(size_of::<(ColdString, u8)>(), size_of::() + 1); -assert_eq!(size_of::>(), size_of::() + 1); +``` + +It has a null-niche, so `Option` is the same size as `ColdString`: +```rust +assert_eq!(size_of::>(), size_of::()); ``` ## How It Works ColdString is an 8-byte tagged pointer (4 bytes on 32-bit machines): + ```rust +use std::ptr::NonNull; + #[repr(packed)] pub struct ColdString { - encoded: *mut u8, + encoded: NonNull, } ``` The 8 bytes encode one of three representations indicated by the 1st byte: @@ -61,6 +68,7 @@ least-significant 2 bits of the address are `00`. On the heap, the UTF-8 charact - `xxxxxxxx`: All 8 bytes are UTF-8. `10xxxxxx` and `11111xxx` are chosen because they cannot be valid first bytes of UTF-8. +An eight character string consisting of only NUL is represented as a tagged NULL pointer. ### Why "Cold"? diff --git a/cold-string/src/lib.rs b/cold-string/src/lib.rs index 3a494a2..87669af 100644 --- a/cold-string/src/lib.rs +++ b/cold-string/src/lib.rs @@ -22,7 +22,9 @@ use core::{ iter::FromIterator, mem, ops::Deref, - ptr, slice, str, + ptr, + ptr::NonNull, + slice, str, }; mod vint; @@ -57,9 +59,11 @@ pub struct ColdString { /// with the LSB bits of the tag byte. The address is always a multiple of 4 (`HEAP_ALIGN`). /// - 11111xxx: xxx is the length in range 0..=7, followed by length UTF-8 bytes. /// - xxxxxxxx (valid UTF-8): 8 UTF-8 bytes. - encoded: *const u8, + encoded: NonNull, } +static EIGHT_NUL: [u8; WIDTH] = [0u8; WIDTH]; + impl ColdString { const TAG_MASK: usize = usize::from_ne_bytes(0b11000000usize.to_le_bytes()); const INLINE_TAG: usize = usize::from_ne_bytes(0b11111000usize.to_le_bytes()); @@ -135,6 +139,13 @@ impl ColdString { } } + #[rustversion::attr(since(1.61), const)] + #[inline] + fn new_eight_nul() -> Self { + // SAFETY: PTR_TAG is non-zero + unsafe { Self::from_inline_buf(Self::PTR_TAG.to_ne_bytes()) } + } + #[inline] const fn inline_buf(s: &str) -> [u8; WIDTH] { debug_assert!(s.len() <= WIDTH); @@ -147,10 +158,12 @@ impl ColdString { buf } + /// SAFETY: b must not be all-zero #[rustversion::attr(since(1.61), const)] #[inline] - fn from_inline_buf(b: [u8; WIDTH]) -> Self { + unsafe fn from_inline_buf(b: [u8; WIDTH]) -> Self { let encoded = ptr::null_mut::().wrapping_add(usize::from_ne_bytes(b)); + let encoded = NonNull::new_unchecked(encoded); Self { encoded } } @@ -161,10 +174,16 @@ impl ColdString { #[inline] fn new_inline(s: &str) -> Self { + if s.as_bytes() == EIGHT_NUL { + return Self::new_eight_nul(); + } let mut buf = Self::inline_buf(s); let start = Self::utf8_start(s.len()); buf[start..s.len() + start].copy_from_slice(s.as_bytes()); - Self::from_inline_buf(buf) + // SAFETY: + // it is checked at the top of the function than s is not all NUL + // and the inline tag is not 0, so shorter strings will also be not all NUL + unsafe { Self::from_inline_buf(buf) } } /// Creates a new inline [`ColdString`] from `&'static str` at compile time. @@ -190,6 +209,14 @@ impl ColdString { "Length for `new_inline_const` must be less than `core::mem::size_of::()`." ); } + if s.len() == WIDTH { + // can't do a slice comparison in const context + let bytes = unsafe { *(s.as_bytes() as *const _ as *const [u8; WIDTH]) }; + let int = usize::from_ne_bytes(bytes); + if int == 0 { + return Self::new_eight_nul(); + } + } let mut buf = Self::inline_buf(s); let start = Self::utf8_start(s.len()); let mut i = 0; @@ -197,18 +224,23 @@ impl ColdString { buf[i + start] = s.as_bytes()[i]; i += 1; } - Self::from_inline_buf(buf) + // SAFETY: + // It is checked at the top of the function than s is not all NUL, + // and the inline tag is not 0, so shorter strings will also be not all NUL. + unsafe { Self::from_inline_buf(buf) } } #[rustversion::attr(since(1.71), const)] #[inline] - unsafe fn ptr(&self) -> *const u8 { - ptr::read_unaligned(ptr::addr_of!(self.encoded)) + fn ptr(&self) -> *const u8 { + // SAFETY: this is always safe since addr_of! will return a pointer valid for reading + let encoded = unsafe { ptr::read_unaligned(ptr::addr_of!(self.encoded)) }; + encoded.as_ptr() } #[inline] fn addr(&self) -> usize { - unsafe { self.ptr().addr() } + self.ptr().addr() } #[inline] @@ -230,6 +262,7 @@ impl ColdString { let layout = Layout::from_size_align(total, HEAP_ALIGN).unwrap(); unsafe { + // SAFETY: the layout size is non-zero, since the smallest VarInt is one byte let ptr = alloc(layout); if ptr.is_null() { alloc::alloc::handle_alloc_error(layout); @@ -244,6 +277,7 @@ impl ColdString { addr |= Self::PTR_TAG; addr }); + let encoded = NonNull::new_unchecked(encoded); Self { encoded } } } @@ -251,14 +285,12 @@ impl ColdString { #[inline] fn heap_ptr(&self) -> *const u8 { debug_assert!(!self.is_inline()); - unsafe { - self.ptr().map_addr(|mut addr| { - addr ^= Self::PTR_TAG; - let addr = addr.rotate_right(6 + Self::ROT); - debug_assert!(addr % HEAP_ALIGN == 0); - addr - }) - } + self.ptr().map_addr(|mut addr| { + addr ^= Self::PTR_TAG; + let addr = addr.rotate_right(6 + Self::ROT); + debug_assert!(addr % HEAP_ALIGN == 0); + addr + }) } #[inline] @@ -293,6 +325,9 @@ impl ColdString { } else { unsafe { let ptr = self.heap_ptr(); + if ptr.is_null() { + return WIDTH; + } let (len, _) = VarInt::read(ptr); len as usize } @@ -313,6 +348,9 @@ impl ColdString { #[inline] unsafe fn decode_heap(&self) -> &[u8] { let ptr = self.heap_ptr(); + if ptr.is_null() { + return &EIGHT_NUL; + } let (len, header) = VarInt::read(ptr); let data = ptr.add(header); slice::from_raw_parts(data, len) @@ -382,11 +420,15 @@ impl Deref for ColdString { impl Drop for ColdString { fn drop(&mut self) { if !self.is_inline() { + let ptr = self.heap_ptr(); + if ptr.is_null() { + return; + } unsafe { - let ptr = self.heap_ptr(); let (len, header) = VarInt::read(ptr); let total = header + len; let layout = Layout::from_size_align(total, HEAP_ALIGN).unwrap(); + // SAFETY: if ptr is non-null then it was allocated by alloc() in new_heap() dealloc(ptr as *mut u8, layout); } } @@ -395,13 +437,12 @@ impl Drop for ColdString { impl Clone for ColdString { fn clone(&self) -> Self { - match self.is_inline() { - true => unsafe { - Self { - encoded: self.ptr(), - } - }, - false => Self::new_heap(self.as_str()), + if self.is_inline() || self.heap_ptr().is_null() { + let ptr = self.ptr(); + let encoded = unsafe { NonNull::new_unchecked(ptr as *mut _) }; + Self { encoded } + } else { + Self::new_heap(self.as_str()) } } } @@ -409,7 +450,7 @@ impl Clone for ColdString { impl PartialEq for ColdString { fn eq(&self, other: &Self) -> bool { match (self.is_inline(), other.is_inline()) { - (true, true) => unsafe { self.ptr() == other.ptr() }, + (true, true) => self.ptr() == other.ptr(), (false, false) => unsafe { self.decode_heap() == other.decode_heap() }, _ => false, } @@ -625,7 +666,9 @@ mod tests { fn assert_correct(s: &str) { let cs = ColdString::new(s); - assert_eq!(s.len() <= mem::size_of::(), cs.is_inline()); + if s.as_bytes() != &[0u8; WIDTH] { + assert_eq!(s.len() <= mem::size_of::(), cs.is_inline()); + } assert_eq!(cs.len(), s.len()); assert_eq!(cs.as_bytes(), s.as_bytes()); assert_eq!(cs.as_str(), s); @@ -741,4 +784,20 @@ mod tests { } } } + + #[test] + fn test_const_8nul_vs_non_const() { + let nul8 = str::from_utf8(&EIGHT_NUL).unwrap(); + let const8 = ColdString::new_inline_const(nul8); + let non_const = ColdString::new(nul8); + let cloned = non_const.clone(); + assert_eq!(const8.ptr(), non_const.ptr()); + assert_eq!(const8.ptr(), cloned.ptr()); + assert_eq!(non_const.heap_ptr(), ptr::null()); + // check that a null pointer will return a str pointing to EIGHT_NUL + assert_eq!( + &const8.as_str().as_bytes()[0] as *const u8, + (&EIGHT_NUL) as *const u8 + ); + } }