From 7767f173bc3809a7d9eeb7d1cdc0185b4c710253 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 20 Oct 2023 17:34:25 +0200 Subject: [PATCH 1/3] types: introduce read_value Introduce the `read_value` function which is able to read a [value], as specified in the CQL protocol. It will be used in the next commit, in order to make the interface of the SerializedValue iterators more correct. --- scylla-cql/src/frame/types.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 672fe2f97e..5de8124111 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -104,6 +104,23 @@ impl From for ParseError { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum RawValue<'a> { + Null, + Unset, + Value(&'a [u8]), +} + +impl<'a> RawValue<'a> { + #[inline] + pub fn as_value(&self) -> Option<&'a [u8]> { + match self { + RawValue::Value(v) => Some(v), + RawValue::Null | RawValue::Unset => None, + } + } +} + fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { if buf.len() < count { return Err(ParseError::BadIncomingData(format!( @@ -218,6 +235,22 @@ pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { Ok(v) } +pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { + let len = read_int(buf)?; + match len { + -2 => Ok(RawValue::Unset), + -1 => Ok(RawValue::Null), + len if len >= 0 => { + let v = read_raw_bytes(len as usize, buf)?; + Ok(RawValue::Value(v)) + } + len => Err(ParseError::BadIncomingData(format!( + "invalid value length: {}", + len, + ))), + } +} + pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> { let len = read_short_length(buf)?; let v = read_raw_bytes(len, buf)?; From 5e544c9ffc4c4fe20adfa58cbf461dfdf12e1365 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 20 Oct 2023 17:10:16 +0200 Subject: [PATCH 2/3] frame: adjust serialized values iterator to return RawValue Currently, the SerializedValues' `iter()` method treats both null and unset values as None, and `iter_name_value_pairs()` just assumes that values are never null/unset and panics if they are. Make the interface more correct by adjusting both methods to return RawValue. The iterators will be used in the next commit to implement the fallback that allows to implement `SerializeRow`/`SerializeCql` via legacy `ValueList`/`Value` traits. --- scylla-cql/src/frame/value.rs | 11 ++++--- scylla-cql/src/frame/value_tests.rs | 37 ++++++++++++++-------- scylla/src/statement/prepared_statement.rs | 3 +- scylla/src/transport/partitioner.rs | 8 +++-- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index db67b4fab8..4faa8df501 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -16,6 +16,7 @@ use chrono::{DateTime, NaiveDate, NaiveTime, TimeZone, Utc}; use super::response::result::CqlValue; use super::types::vint_encode; +use super::types::RawValue; #[cfg(feature = "secret")] use secrecy::{ExposeSecret, Secret, Zeroize}; @@ -366,7 +367,7 @@ impl SerializedValues { Ok(()) } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator { SerializedValuesIterator { serialized_values: &self.serialized_values, contains_names: self.contains_names, @@ -410,7 +411,7 @@ impl SerializedValues { }) } - pub fn iter_name_value_pairs(&self) -> impl Iterator, &[u8])> { + pub fn iter_name_value_pairs(&self) -> impl Iterator, RawValue)> { let mut buf = &self.serialized_values[..]; (0..self.values_num).map(move |_| { // `unwrap()`s here are safe, as we assume type-safety: if `SerializedValues` exits, @@ -418,7 +419,7 @@ impl SerializedValues { let name = self .contains_names .then(|| types::read_string(&mut buf).unwrap()); - let serialized = types::read_bytes(&mut buf).unwrap(); + let serialized = types::read_value(&mut buf).unwrap(); (name, serialized) }) } @@ -431,7 +432,7 @@ pub struct SerializedValuesIterator<'a> { } impl<'a> Iterator for SerializedValuesIterator<'a> { - type Item = Option<&'a [u8]>; + type Item = RawValue<'a>; fn next(&mut self) -> Option { if self.serialized_values.is_empty() { @@ -443,7 +444,7 @@ impl<'a> Iterator for SerializedValuesIterator<'a> { types::read_short_bytes(&mut self.serialized_values).expect("badly encoded value name"); } - Some(types::read_bytes_opt(&mut self.serialized_values).expect("badly encoded value")) + Some(types::read_value(&mut self.serialized_values).expect("badly encoded value")) } } diff --git a/scylla-cql/src/frame/value_tests.rs b/scylla-cql/src/frame/value_tests.rs index 633e3ce7b8..003ff0116a 100644 --- a/scylla-cql/src/frame/value_tests.rs +++ b/scylla-cql/src/frame/value_tests.rs @@ -1,4 +1,4 @@ -use crate::frame::value::BatchValuesIterator; +use crate::frame::{types::RawValue, value::BatchValuesIterator}; use super::value::{ BatchValues, CqlDate, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError, @@ -421,7 +421,10 @@ fn serialized_values() { values.write_to_request(&mut request); assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 8]); - assert_eq!(values.iter().collect::>(), vec![Some([8].as_ref())]); + assert_eq!( + values.iter().collect::>(), + vec![RawValue::Value([8].as_ref())] + ); } // Add second value @@ -436,7 +439,10 @@ fn serialized_values() { assert_eq!( values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] + vec![ + RawValue::Value([8].as_ref()), + RawValue::Value([0, 16].as_ref()) + ] ); } @@ -468,7 +474,10 @@ fn serialized_values() { assert_eq!( values.iter().collect::>(), - vec![Some([8].as_ref()), Some([0, 16].as_ref())] + vec![ + RawValue::Value([8].as_ref()), + RawValue::Value([0, 16].as_ref()) + ] ); } } @@ -498,9 +507,9 @@ fn slice_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } @@ -515,9 +524,9 @@ fn vec_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } @@ -530,7 +539,7 @@ fn tuple_value_list() { let serialized_vals: Vec = serialized .iter() - .map(|o: Option<&[u8]>| o.unwrap()[0]) + .map(|o: RawValue| o.as_value().unwrap()[0]) .collect(); let expected: Vec = expected.collect(); @@ -604,9 +613,9 @@ fn ref_value_list() { assert_eq!( serialized.iter().collect::>(), vec![ - Some([0, 0, 0, 1].as_ref()), - Some([0, 0, 0, 2].as_ref()), - Some([0, 0, 0, 3].as_ref()) + RawValue::Value([0, 0, 0, 1].as_ref()), + RawValue::Value([0, 0, 0, 2].as_ref()), + RawValue::Value([0, 0, 0, 3].as_ref()) ] ); } diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index 22f34e60a2..58d8b9ea3d 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -1,5 +1,6 @@ use bytes::{Bytes, BytesMut}; use scylla_cql::errors::{BadQuery, QueryError}; +use scylla_cql::frame::types::RawValue; use smallvec::{smallvec, SmallVec}; use std::convert::TryInto; use std::sync::Arc; @@ -399,7 +400,7 @@ impl<'ps> PartitionKey<'ps> { PartitionKeyExtractionError::NoPkIndexValue(pk_index.index, bound_values.len()) })?; // Add it in sequence order to pk_values - if let Some(v) = next_val { + if let RawValue::Value(v) = next_val { let spec = &prepared_metadata.col_specs[pk_index.index as usize]; pk_values[pk_index.sequence as usize] = Some((v, spec)); } diff --git a/scylla/src/transport/partitioner.rs b/scylla/src/transport/partitioner.rs index 9c8a542325..4526715ab2 100644 --- a/scylla/src/transport/partitioner.rs +++ b/scylla/src/transport/partitioner.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use scylla_cql::frame::types::RawValue; use std::num::Wrapping; use crate::{ @@ -343,11 +344,14 @@ pub fn calculate_token_for_partition_key( if serialized_partition_key_values.len() == 1 { let val = serialized_partition_key_values.iter().next().unwrap(); - if let Some(val) = val { + if let RawValue::Value(val) = val { partitioner_hasher.write(val); } } else { - for val in serialized_partition_key_values.iter().flatten() { + for val in serialized_partition_key_values + .iter() + .filter_map(|rv| rv.as_value()) + { let val_len_u16: u16 = val .len() .try_into() From 29a37b492a51537f5ef4ee9ddcc333c436d049a3 Mon Sep 17 00:00:00 2001 From: Piotr Dulikowski Date: Fri, 20 Oct 2023 19:13:37 +0200 Subject: [PATCH 3/3] types: serialize: introduce new helpers for writing values and adjust interfaces Currently, `SerializeRow` and `SerializeCql` traits are just given a mutable reference to a Vec and asked to append their CQL representation to the end. While simple, there are some issues with the interface: - The serialize method has access to the serialized representation of the values that were appended before it. It's not necessary for a correct implementation to have access to it. - Implementors technically can append any byte sequence to the end, but actually are expected to produce a CQL [value] containing the serialized value. While the `SerializeRow` and `SerializeCql` traits are not generally meant to be manually implemented by the users, we can make the interface easier to use and harder to misuse by making it append-only, restricting what the users are allowed to append and requiring the users to append anything by using a dash of type-level magic. Introduce `RowWriter` and `CellWriter` traits which satisfy the above wishes and constraints, and pass them instead of Vec in `SerializeRow` and `SerializeCql`. The new traits have two implementations - a Vec backed one that actually appends the bytes given to it, and a usize-backed one which just measures the length of the output without writing anything. Passing the latter before doing the actual serialization will allow to preallocate the right amount of bytes and then serialize without reallocations. It should be measured whether the reallocation cost always outweighs the calculation cost before implementing this optimization. --- scylla-cql/src/types/serialize/mod.rs | 6 + scylla-cql/src/types/serialize/row.rs | 161 +++++++- scylla-cql/src/types/serialize/value.rs | 110 +++++- scylla-cql/src/types/serialize/writers.rs | 426 ++++++++++++++++++++++ 4 files changed, 688 insertions(+), 15 deletions(-) create mode 100644 scylla-cql/src/types/serialize/writers.rs diff --git a/scylla-cql/src/types/serialize/mod.rs b/scylla-cql/src/types/serialize/mod.rs index 0cda84e252..511a7104b1 100644 --- a/scylla-cql/src/types/serialize/mod.rs +++ b/scylla-cql/src/types/serialize/mod.rs @@ -2,5 +2,11 @@ use std::{any::Any, sync::Arc}; pub mod row; pub mod value; +pub mod writers; + +pub use writers::{ + BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, + CellWriter, CountingWriter, RowWriter, +}; type SerializationError = Arc; diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 2e9832412d..fe91585c8c 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1,20 +1,25 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; + +use thiserror::Error; -use crate::frame::response::result::ColumnSpec; use crate::frame::value::ValueList; +use crate::frame::{response::result::ColumnSpec, types::RawValue}; -use super::SerializationError; +use super::{CellWriter, RowWriter, SerializationError}; +/// Contains information needed to serialize a row. pub struct RowSerializationContext<'a> { columns: &'a [ColumnSpec], } impl<'a> RowSerializationContext<'a> { + /// Returns column/bind marker specifications for given query. #[inline] pub fn columns(&self) -> &'a [ColumnSpec] { self.columns } + /// Looks up and returns a column/bind marker by name. // TODO: change RowSerializationContext to make this faster #[inline] pub fn column_by_name(&self, target: &str) -> Option<&ColumnSpec> { @@ -23,11 +28,25 @@ impl<'a> RowSerializationContext<'a> { } pub trait SerializeRow { + /// Checks if it _might_ be possible to serialize the row according to the + /// information in the context. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Sometimes, a row cannot be fully type checked right away without knowing + /// the exact values of the columns (e.g. when deserializing to `CqlValue`), + /// but it's fine to do full type checking later in `serialize`. fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>; - fn serialize( + + /// Serializes the row according to the information in the given context. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( &self, ctx: &RowSerializationContext<'_>, - out: &mut Vec, + writer: &mut W, ) -> Result<(), SerializationError>; } @@ -38,12 +57,134 @@ impl SerializeRow for T { Ok(()) } - fn serialize( + fn serialize( &self, - _ctx: &RowSerializationContext<'_>, - out: &mut Vec, + ctx: &RowSerializationContext<'_>, + writer: &mut W, ) -> Result<(), SerializationError> { - self.write_to_request(out) - .map_err(|err| Arc::new(err) as SerializationError) + serialize_legacy_row(self, ctx, writer) + } +} + +pub fn serialize_legacy_row( + r: &T, + ctx: &RowSerializationContext<'_>, + writer: &mut impl RowWriter, +) -> Result<(), SerializationError> { + let serialized = + ::serialized(r).map_err(|err| Arc::new(err) as SerializationError)?; + + let mut append_value = |value: RawValue| { + let cell_writer = writer.make_cell_writer(); + let _proof = match value { + RawValue::Null => cell_writer.set_null(), + RawValue::Unset => cell_writer.set_unset(), + RawValue::Value(v) => cell_writer.set_value(v), + }; + }; + + if !serialized.has_names() { + serialized.iter().for_each(append_value); + } else { + let values_by_name = serialized + .iter_name_value_pairs() + .map(|(k, v)| (k.unwrap(), v)) + .collect::>(); + + for col in ctx.columns() { + let val = values_by_name.get(col.name.as_str()).ok_or_else(|| { + Arc::new(ValueListToSerializeRowAdapterError::NoBindMarkerWithName { + name: col.name.clone(), + }) as SerializationError + })?; + append_value(*val); + } + } + + Ok(()) +} + +#[derive(Error, Debug)] +pub enum ValueListToSerializeRowAdapterError { + #[error("There is no bind marker with name {name}, but a value for it was provided")] + NoBindMarkerWithName { name: String }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::frame::value::{MaybeUnset, SerializedValues, ValueList}; + use crate::types::serialize::BufBackedRowWriter; + + use super::{RowSerializationContext, SerializeRow}; + + fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + #[test] + fn test_legacy_fallback() { + let row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut legacy_data = Vec::new(); + <_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let mut new_data_writer = BufBackedRowWriter::new(&mut new_data); + let ctx = RowSerializationContext { columns: &[] }; + <_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap(); + assert_eq!(new_data_writer.value_count(), 4); + + // Skip the value count + assert_eq!(&legacy_data[2..], new_data); + } + + #[test] + fn test_legacy_fallback_with_names() { + let sorted_row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut sorted_row_data = Vec::new(); + <_ as ValueList>::write_to_request(&sorted_row, &mut sorted_row_data).unwrap(); + + let mut unsorted_row = SerializedValues::new(); + unsorted_row.add_named_value("a", &1i32).unwrap(); + unsorted_row.add_named_value("b", &"Ala ma kota").unwrap(); + unsorted_row + .add_named_value("d", &MaybeUnset::Unset::) + .unwrap(); + unsorted_row.add_named_value("c", &None::).unwrap(); + + let mut unsorted_row_data = Vec::new(); + let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data); + let ctx = RowSerializationContext { + columns: &[ + col_spec("a", ColumnType::Int), + col_spec("b", ColumnType::Text), + col_spec("c", ColumnType::BigInt), + col_spec("d", ColumnType::Ascii), + ], + }; + <_ as SerializeRow>::serialize(&unsorted_row, &ctx, &mut unsorted_row_data_writer).unwrap(); + assert_eq!(unsorted_row_data_writer.value_count(), 4); + + // Skip the value count + assert_eq!(&sorted_row_data[2..], unsorted_row_data); } } diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 43eb9ef738..25d605d13d 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1,13 +1,32 @@ use std::sync::Arc; +use thiserror::Error; + use crate::frame::response::result::ColumnType; use crate::frame::value::Value; -use super::SerializationError; +use super::{CellWriter, SerializationError}; pub trait SerializeCql { + /// Given a CQL type, checks if it _might_ be possible to serialize to that type. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Some types cannot be type checked without knowing the exact value, + /// this is the case e.g. for `CqlValue`. It's also fine to do it later in + /// `serialize`. fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError>; - fn serialize(&self, typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError>; + + /// Serializes the value to given CQL type. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result; } impl SerializeCql for T { @@ -15,8 +34,89 @@ impl SerializeCql for T { Ok(()) } - fn serialize(&self, _typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError> { - self.serialize(buf) - .map_err(|err| Arc::new(err) as SerializationError) + fn serialize( + &self, + _typ: &ColumnType, + writer: W, + ) -> Result { + serialize_legacy_value(self, writer) + } +} + +pub fn serialize_legacy_value( + v: &T, + writer: W, +) -> Result { + // It's an inefficient and slightly tricky but correct implementation. + let mut buf = Vec::new(); + ::serialize(v, &mut buf).map_err(|err| Arc::new(err) as SerializationError)?; + + // Analyze the output. + // All this dance shows how unsafe our previous interface was... + if buf.len() < 4 { + return Err(Arc::new(ValueToSerializeCqlAdapterError::TooShort { + size: buf.len(), + })); + } + + let (len_bytes, contents) = buf.split_at(4); + let len = i32::from_be_bytes(len_bytes.try_into().unwrap()); + match len { + -2 => Ok(writer.set_unset()), + -1 => Ok(writer.set_null()), + len if len >= 0 => { + if contents.len() != len as usize { + Err(Arc::new( + ValueToSerializeCqlAdapterError::DeclaredVsActualSizeMismatch { + declared: len as usize, + actual: contents.len(), + }, + )) + } else { + Ok(writer.set_value(contents)) + } + } + _ => Err(Arc::new( + ValueToSerializeCqlAdapterError::InvalidDeclaredSize { size: len }, + )), + } +} + +#[derive(Error, Debug)] +pub enum ValueToSerializeCqlAdapterError { + #[error("Output produced by the Value trait is too short to be considered a value: {size} < 4 minimum bytes")] + TooShort { size: usize }, + + #[error("Mismatch between the declared value size vs. actual size: {declared} != {actual}")] + DeclaredVsActualSizeMismatch { declared: usize, actual: usize }, + + #[error("Invalid declared value size: {size}")] + InvalidDeclaredSize { size: i32 }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::ColumnType; + use crate::frame::value::{MaybeUnset, Value}; + use crate::types::serialize::BufBackedCellWriter; + + use super::SerializeCql; + + fn check_compat(v: V) { + let mut legacy_data = Vec::new(); + ::serialize(&v, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let new_data_writer = BufBackedCellWriter::new(&mut new_data); + ::serialize(&v, &ColumnType::Int, new_data_writer).unwrap(); + + assert_eq!(legacy_data, new_data); + } + + #[test] + fn test_legacy_fallback() { + check_compat(123i32); + check_compat(None::); + check_compat(MaybeUnset::Unset::); } } diff --git a/scylla-cql/src/types/serialize/writers.rs b/scylla-cql/src/types/serialize/writers.rs new file mode 100644 index 0000000000..cafd5442fc --- /dev/null +++ b/scylla-cql/src/types/serialize/writers.rs @@ -0,0 +1,426 @@ +//! Contains types and traits used for safe serialization of values for a CQL statement. + +/// An interface that facilitates writing values for a CQL query. +pub trait RowWriter { + type CellWriter<'a>: CellWriter + where + Self: 'a; + + /// Appends a new value to the sequence and returns an object that allows + /// to fill it in. + fn make_cell_writer(&mut self) -> Self::CellWriter<'_>; +} + +/// Represents a handle to a CQL value that needs to be written into. +/// +/// The writer can either be transformed into a ready value right away +/// (via [`set_null`](CellWriter::set_null), +/// [`set_unset`](CellWriter::set_unset) +/// or [`set_value`](CellWriter::set_value) or transformed into +/// the [`CellWriter::ValueBuilder`] in order to gradually initialize +/// the value when the contents are not available straight away. +/// +/// After the value is fully initialized, the handle is consumed and +/// a [`WrittenCellProof`](CellWriter::WrittenCellProof) object is returned +/// in its stead. This is a type-level proof that the value was fully initialized +/// and is used in [`SerializeCql::serialize`](`super::value::SerializeCql::serialize`) +/// in order to enforce the implementor to fully initialize the provided handle +/// to CQL value. +/// +/// Dropping this type without calling any of its methods will result +/// in nothing being written. +pub trait CellWriter { + /// The type of the value builder, returned by the [`CellWriter::set_value`] + /// method. + type ValueBuilder: CellValueBuilder; + + /// An object that serves as a proof that the cell was fully initialized. + /// + /// This type is returned by [`set_null`](CellWriter::set_null), + /// [`set_unset`](CellWriter::set_unset), + /// [`set_value`](CellWriter::set_value) + /// and also [`CellValueBuilder::finish`] - generally speaking, after + /// the value is fully initialized and the `CellWriter` is destroyed. + /// + /// The purpose of this type is to enforce the contract of + /// [`SerializeCql::serialize`](super::value::SerializeCql::serialize): either + /// the method succeeds and returns a proof that it serialized itself + /// into the given value, or it fails and returns an error or panics. + /// The exact type of [`WrittenCellProof`](CellWriter::WrittenCellProof) + /// is not important as the value is not used at all - it's only + /// a compile-time check. + type WrittenCellProof; + + /// Sets this value to be null, consuming this object. + fn set_null(self) -> Self::WrittenCellProof; + + /// Sets this value to represent an unset value, consuming this object. + fn set_unset(self) -> Self::WrittenCellProof; + + /// Sets this value to a non-zero, non-unset value with given contents. + /// + /// Prefer this to [`into_value_builder`](CellWriter::into_value_builder) + /// if you have all of the contents of the value ready up front (e.g. for + /// fixed size types). + fn set_value(self, contents: &[u8]) -> Self::WrittenCellProof; + + /// Turns this writter into a [`CellValueBuilder`] which can be used + /// to gradually initialize the CQL value. + /// + /// This method should be used if you don't have all of the data + /// up front, e.g. when serializing compound types such as collections + /// or UDTs. + fn into_value_builder(self) -> Self::ValueBuilder; +} + +/// Allows appending bytes to a non-null, non-unset cell. +/// +/// This object needs to be dropped in order for the value to be correctly +/// serialized. Failing to drop this value will result in a payload that will +/// not be parsed by the database correctly, but otherwise should not cause +/// data to be misinterpreted. +pub trait CellValueBuilder { + type SubCellWriter<'a>: CellWriter + where + Self: 'a; + + type WrittenCellProof; + + /// Appends raw bytes to this cell. + fn append_bytes(&mut self, bytes: &[u8]); + + /// Appends a sub-value to the end of the current contents of the cell + /// and returns an object that allows to fill it in. + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_>; + + /// Finishes serializing the value. + fn finish(self) -> Self::WrittenCellProof; +} + +/// A row writer backed by a buffer (vec). +pub struct BufBackedRowWriter<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Number of values written so far. + value_count: u16, +} + +impl<'buf> BufBackedRowWriter<'buf> { + /// Creates a new row writer based on an existing Vec. + /// + /// The newly created row writer will append data to the end of the vec. + #[inline] + pub fn new(buf: &'buf mut Vec) -> Self { + Self { + buf, + value_count: 0, + } + } + + /// Returns the number of values that were written so far. + #[inline] + pub fn value_count(&self) -> u16 { + self.value_count + } +} + +impl<'buf> RowWriter for BufBackedRowWriter<'buf> { + type CellWriter<'a> = BufBackedCellWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + self.value_count = self + .value_count + .checked_add(1) + .expect("tried to serialize too many values for a query (more than u16::MAX)"); + BufBackedCellWriter::new(self.buf) + } +} + +/// A cell writer backed by a buffer (vec). +pub struct BufBackedCellWriter<'buf> { + buf: &'buf mut Vec, +} + +impl<'buf> BufBackedCellWriter<'buf> { + /// Creates a new cell writer based on an existing Vec. + /// + /// The newly created row writer will append data to the end of the vec. + #[inline] + pub fn new(buf: &'buf mut Vec) -> Self { + BufBackedCellWriter { buf } + } +} + +impl<'buf> CellWriter for BufBackedCellWriter<'buf> { + type ValueBuilder = BufBackedCellValueBuilder<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + self.buf.extend_from_slice(&(-1i32).to_be_bytes()); + } + + #[inline] + fn set_unset(self) { + self.buf.extend_from_slice(&(-2i32).to_be_bytes()); + } + + #[inline] + fn set_value(self, bytes: &[u8]) { + let value_len: i32 = bytes + .len() + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf.extend_from_slice(&value_len.to_be_bytes()); + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + BufBackedCellValueBuilder::new(self.buf) + } +} + +/// A cell value builder backed by a buffer (vec). +pub struct BufBackedCellValueBuilder<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Starting position of the value in the buffer. + starting_pos: usize, +} + +impl<'buf> BufBackedCellValueBuilder<'buf> { + #[inline] + fn new(buf: &'buf mut Vec) -> Self { + // "Length" of a [bytes] frame can either be a non-negative i32, + // -1 (null) or -1 (not set). Push an invalid value here. It will be + // overwritten eventually either by set_null, set_unset or Drop. + // If the CellSerializer is not dropped as it should, this will trigger + // an error on the DB side and the serialized data + // won't be misinterpreted. + let starting_pos = buf.len(); + buf.extend_from_slice(&(-3i32).to_be_bytes()); + BufBackedCellValueBuilder { buf, starting_pos } + } +} + +impl<'buf> CellValueBuilder for BufBackedCellValueBuilder<'buf> { + type SubCellWriter<'a> = BufBackedCellWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + BufBackedCellWriter::new(self.buf) + } + + #[inline] + fn finish(self) { + // TODO: Should this panic, or should we catch this error earlier? + // Vec will panic anyway if we overflow isize, so at least this + // behavior is consistent with what the stdlib does. + let value_len: i32 = (self.buf.len() - self.starting_pos - 4) + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf[self.starting_pos..self.starting_pos + 4] + .copy_from_slice(&value_len.to_be_bytes()); + } +} + +/// A writer that does not actually write anything, just counts the bytes. +/// +/// It can serve as a: +/// +/// - [`RowWriter`] +/// - [`CellWriter`] +/// - [`CellValueBuilder`] +pub struct CountingWriter<'buf> { + buf: &'buf mut usize, +} + +impl<'buf> CountingWriter<'buf> { + /// Creates a new writer which increments the counter under given reference + /// when bytes are appended. + #[inline] + fn new(buf: &'buf mut usize) -> Self { + CountingWriter { buf } + } +} + +impl<'buf> RowWriter for CountingWriter<'buf> { + type CellWriter<'a> = CountingWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellWriter for CountingWriter<'buf> { + type ValueBuilder = CountingWriter<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + *self.buf += 4; + } + + #[inline] + fn set_unset(self) { + *self.buf += 4; + } + + #[inline] + fn set_value(self, contents: &[u8]) { + *self.buf += 4 + contents.len(); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + *self.buf += 4; + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellValueBuilder for CountingWriter<'buf> { + type SubCellWriter<'a> = CountingWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + *self.buf += bytes.len(); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + CountingWriter::new(self.buf) + } + + #[inline] + fn finish(self) -> Self::WrittenCellProof {} +} + +#[cfg(test)] +mod tests { + use super::{ + BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, CellWriter, CountingWriter, + RowWriter, + }; + + // We want to perform the same computation for both buf backed writer + // and counting writer, but Rust does not support generic closures. + // This trait comes to the rescue. + trait CellSerializeCheck { + fn check(&self, writer: W); + } + + fn check_cell_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let writer = BufBackedCellWriter::new(&mut data); + c.check(writer); + + let mut byte_count = 0usize; + let counting_writer = CountingWriter::new(&mut byte_count); + c.check(counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_cell_writer() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let mut sub_writer = writer.into_value_builder(); + sub_writer.make_sub_writer().set_null(); + sub_writer.make_sub_writer().set_value(&[1, 2, 3, 4]); + sub_writer.make_sub_writer().set_unset(); + sub_writer.finish(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 0, 0, 0, 16, // Length of inner data is 16 + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ); + } + + #[test] + fn test_poisoned_appender() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let _ = writer.into_value_builder(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 253, // Invalid value + ] + ); + } + + trait RowSerializeCheck { + fn check(&self, writer: &mut W); + } + + fn check_row_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let mut writer = BufBackedRowWriter::new(&mut data); + c.check(&mut writer); + + let mut byte_count = 0usize; + let mut counting_writer = CountingWriter::new(&mut byte_count); + c.check(&mut counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_row_writer() { + struct Check; + impl RowSerializeCheck for Check { + fn check(&self, writer: &mut W) { + writer.make_cell_writer().set_null(); + writer.make_cell_writer().set_value(&[1, 2, 3, 4]); + writer.make_cell_writer().set_unset(); + } + } + + let data = check_row_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ) + } +}