types: serialize: constrain the new serialization traits to make them easier and safer to use #855

Merged · 3 commits · Nov 24, 2023
33 changes: 33 additions & 0 deletions scylla-cql/src/frame/types.rs
@@ -104,6 +104,23 @@ impl From<std::array::TryFromSliceError> for ParseError {
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RawValue<'a> {
Null,
Unset,
Value(&'a [u8]),
}

impl<'a> RawValue<'a> {
#[inline]
pub fn as_value(&self) -> Option<&'a [u8]> {
match self {
RawValue::Value(v) => Some(v),
RawValue::Null | RawValue::Unset => None,
}
}
}

fn read_raw_bytes<'a>(count: usize, buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
if buf.len() < count {
return Err(ParseError::BadIncomingData(format!(
@@ -218,6 +235,22 @@ pub fn read_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
Ok(v)
}

pub fn read_value<'a>(buf: &mut &'a [u8]) -> Result<RawValue<'a>, ParseError> {
let len = read_int(buf)?;
match len {
-2 => Ok(RawValue::Unset),
-1 => Ok(RawValue::Null),
len if len >= 0 => {
let v = read_raw_bytes(len as usize, buf)?;
Ok(RawValue::Value(v))
}
len => Err(ParseError::BadIncomingData(format!(
"invalid value length: {}",
len,
))),
}
}

pub fn read_short_bytes<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ParseError> {
let len = read_short_length(buf)?;
let v = read_raw_bytes(len, buf)?;
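As an aside (not part of the diff): a minimal sketch of how `read_value` and `RawValue` fit together, assuming `read_int` reads a 4-byte big-endian `i32` length as it does elsewhere in the frame code, and that these items remain publicly reachable under `scylla_cql::frame::types`:

```rust
use scylla_cql::frame::types::{read_value, RawValue};

fn classify(mut buf: &[u8]) {
    // [0, 0, 0, 1, 8]          -> length 1,  RawValue::Value(&[8])
    // [0xff, 0xff, 0xff, 0xff] -> length -1, RawValue::Null
    // [0xff, 0xff, 0xff, 0xfe] -> length -2, RawValue::Unset
    match read_value(&mut buf).expect("well-formed [value]") {
        RawValue::Value(bytes) => println!("value: {:?}", bytes),
        RawValue::Null => println!("null"),
        RawValue::Unset => println!("unset"),
    }
}
```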
11 changes: 6 additions & 5 deletions scylla-cql/src/frame/value.rs
@@ -16,6 +16,7 @@ use chrono::{DateTime, NaiveDate, NaiveTime, TimeZone, Utc};

use super::response::result::CqlValue;
use super::types::vint_encode;
use super::types::RawValue;

#[cfg(feature = "secret")]
use secrecy::{ExposeSecret, Secret, Zeroize};
@@ -366,7 +367,7 @@ impl SerializedValues {
Ok(())
}

- pub fn iter(&self) -> impl Iterator<Item = Option<&[u8]>> {
+ pub fn iter(&self) -> impl Iterator<Item = RawValue> {
SerializedValuesIterator {
serialized_values: &self.serialized_values,
contains_names: self.contains_names,
@@ -410,15 +411,15 @@ impl SerializedValues {
})
}

- pub fn iter_name_value_pairs(&self) -> impl Iterator<Item = (Option<&str>, &[u8])> {
+ pub fn iter_name_value_pairs(&self) -> impl Iterator<Item = (Option<&str>, RawValue)> {
let mut buf = &self.serialized_values[..];
(0..self.values_num).map(move |_| {
// `unwrap()`s here are safe, as we assume type-safety: if `SerializedValues` exists,
// we have a guarantee that the layout of the serialized values is valid.
let name = self
.contains_names
.then(|| types::read_string(&mut buf).unwrap());
- let serialized = types::read_bytes(&mut buf).unwrap();
+ let serialized = types::read_value(&mut buf).unwrap();
(name, serialized)
})
}
@@ -431,7 +432,7 @@ pub struct SerializedValuesIterator<'a> {
}

impl<'a> Iterator for SerializedValuesIterator<'a> {
- type Item = Option<&'a [u8]>;
+ type Item = RawValue<'a>;

fn next(&mut self) -> Option<Self::Item> {
if self.serialized_values.is_empty() {
@@ -443,7 +444,7 @@ impl<'a> Iterator for SerializedValuesIterator<'a> {
types::read_short_bytes(&mut self.serialized_values).expect("badly encoded value name");
}

- Some(types::read_bytes_opt(&mut self.serialized_values).expect("badly encoded value"))
+ Some(types::read_value(&mut self.serialized_values).expect("badly encoded value"))
}
}

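A short, hypothetical usage sketch of the new iterator item type (it reuses `SerializedValues::add_named_value` and `MaybeUnset` exactly as the tests in this PR do; the module paths are assumed to be public): with `RawValue`, an unset value is no longer conflated with a null when iterating.

```rust
use scylla_cql::frame::types::RawValue;
use scylla_cql::frame::value::{MaybeUnset, SerializedValues};

fn main() {
    let mut values = SerializedValues::new();
    values.add_named_value("a", &1i32).unwrap();
    values
        .add_named_value("b", &MaybeUnset::Unset::<String>)
        .unwrap();

    // `iter` now yields `RawValue`, so unset markers can be counted separately from nulls.
    let unset_count = values
        .iter()
        .filter(|v| matches!(v, RawValue::Unset))
        .count();
    assert_eq!(unset_count, 1);
}
```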
37 changes: 23 additions & 14 deletions scylla-cql/src/frame/value_tests.rs
@@ -1,4 +1,4 @@
- use crate::frame::value::BatchValuesIterator;
+ use crate::frame::{types::RawValue, value::BatchValuesIterator};

use super::value::{
BatchValues, CqlDate, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError,
@@ -421,7 +421,10 @@ fn serialized_values() {
values.write_to_request(&mut request);
assert_eq!(request, vec![0, 1, 0, 0, 0, 1, 8]);

- assert_eq!(values.iter().collect::<Vec<_>>(), vec![Some([8].as_ref())]);
+ assert_eq!(
+ values.iter().collect::<Vec<_>>(),
+ vec![RawValue::Value([8].as_ref())]
+ );
}

// Add second value
@@ -436,7 +439,10 @@ fn serialized_values() {

assert_eq!(
values.iter().collect::<Vec<_>>(),
- vec![Some([8].as_ref()), Some([0, 16].as_ref())]
+ vec![
+ RawValue::Value([8].as_ref()),
+ RawValue::Value([0, 16].as_ref())
+ ]
);
}

@@ -468,7 +474,10 @@ fn serialized_values() {

assert_eq!(
values.iter().collect::<Vec<_>>(),
- vec![Some([8].as_ref()), Some([0, 16].as_ref())]
+ vec![
+ RawValue::Value([8].as_ref()),
+ RawValue::Value([0, 16].as_ref())
+ ]
);
}
}
@@ -498,9 +507,9 @@ fn slice_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
- Some([0, 0, 0, 1].as_ref()),
- Some([0, 0, 0, 2].as_ref()),
- Some([0, 0, 0, 3].as_ref())
+ RawValue::Value([0, 0, 0, 1].as_ref()),
+ RawValue::Value([0, 0, 0, 2].as_ref()),
+ RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
@@ -515,9 +524,9 @@ fn vec_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
- Some([0, 0, 0, 1].as_ref()),
- Some([0, 0, 0, 2].as_ref()),
- Some([0, 0, 0, 3].as_ref())
+ RawValue::Value([0, 0, 0, 1].as_ref()),
+ RawValue::Value([0, 0, 0, 2].as_ref()),
+ RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
@@ -530,7 +539,7 @@ fn tuple_value_list() {

let serialized_vals: Vec<u8> = serialized
.iter()
- .map(|o: Option<&[u8]>| o.unwrap()[0])
+ .map(|o: RawValue| o.as_value().unwrap()[0])
.collect();

let expected: Vec<u8> = expected.collect();
@@ -604,9 +613,9 @@ fn ref_value_list() {
assert_eq!(
serialized.iter().collect::<Vec<_>>(),
vec![
- Some([0, 0, 0, 1].as_ref()),
- Some([0, 0, 0, 2].as_ref()),
- Some([0, 0, 0, 3].as_ref())
+ RawValue::Value([0, 0, 0, 1].as_ref()),
+ RawValue::Value([0, 0, 0, 2].as_ref()),
+ RawValue::Value([0, 0, 0, 3].as_ref())
]
);
}
6 changes: 6 additions & 0 deletions scylla-cql/src/types/serialize/mod.rs
@@ -2,5 +2,11 @@ use std::{any::Any, sync::Arc};

pub mod row;
pub mod value;
pub mod writers;

pub use writers::{
BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder,
CellWriter, CountingWriter, RowWriter,
};

type SerializationError = Arc<dyn Any + Send + Sync>;
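A rough sketch of how the re-exported writer types compose, based only on the calls visible in this PR (`BufBackedRowWriter::new`, `value_count`, `make_cell_writer`, and the `set_null`/`set_unset`/`set_value` cell methods); the exact signatures live in the new `writers` module, which this diff does not show, so treat the details below as assumptions:

```rust
use scylla_cql::types::serialize::{BufBackedRowWriter, CellWriter, RowWriter};

fn write_three_cells(buf: &mut Vec<u8>) {
    let mut writer = BufBackedRowWriter::new(buf);

    // Each cell is exactly one of: a concrete value, a null, or an unset marker.
    let _ = writer.make_cell_writer().set_value(&1i32.to_be_bytes());
    let _ = writer.make_cell_writer().set_null();
    let _ = writer.make_cell_writer().set_unset();

    // The writer keeps track of how many values were appended.
    assert_eq!(writer.value_count(), 3);
}
```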
161 changes: 151 additions & 10 deletions scylla-cql/src/types/serialize/row.rs
@@ -1,20 +1,25 @@
- use std::sync::Arc;
+ use std::{collections::HashMap, sync::Arc};

use thiserror::Error;

- use crate::frame::response::result::ColumnSpec;
use crate::frame::value::ValueList;
+ use crate::frame::{response::result::ColumnSpec, types::RawValue};

- use super::SerializationError;
+ use super::{CellWriter, RowWriter, SerializationError};

/// Contains information needed to serialize a row.
pub struct RowSerializationContext<'a> {
columns: &'a [ColumnSpec],
}

impl<'a> RowSerializationContext<'a> {
/// Returns column/bind marker specifications for given query.
#[inline]
pub fn columns(&self) -> &'a [ColumnSpec] {
self.columns
}

/// Looks up and returns a column/bind marker by name.
// TODO: change RowSerializationContext to make this faster
#[inline]
pub fn column_by_name(&self, target: &str) -> Option<&ColumnSpec> {
@@ -23,11 +28,25 @@ impl<'a> RowSerializationContext<'a> {
}

pub trait SerializeRow {
/// Checks if it _might_ be possible to serialize the row according to the
/// information in the context.
///
/// This function is intended to serve as an optimization in the future,
/// if we were ever to introduce prepared statements parametrized by types.
///
/// Sometimes, a row cannot be fully type checked right away without knowing
/// the exact values of the columns (e.g. when deserializing to `CqlValue`),
/// but it's fine to do full type checking later in `serialize`.
fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>;
- fn serialize(

+ /// Serializes the row according to the information in the given context.
+ ///
+ /// The function may assume that `preliminary_type_check` was called,
+ /// though it must not do anything unsafe if this assumption does not hold.
+ fn serialize<W: RowWriter>(
&self,
ctx: &RowSerializationContext<'_>,
- out: &mut Vec<u8>,
+ writer: &mut W,
) -> Result<(), SerializationError>;
}

@@ -38,12 +57,134 @@ impl<T: ValueList> SerializeRow for T {
Ok(())
}

- fn serialize(
+ fn serialize<W: RowWriter>(
&self,
- _ctx: &RowSerializationContext<'_>,
- out: &mut Vec<u8>,
+ ctx: &RowSerializationContext<'_>,
+ writer: &mut W,
) -> Result<(), SerializationError> {
- self.write_to_request(out)
- .map_err(|err| Arc::new(err) as SerializationError)
+ serialize_legacy_row(self, ctx, writer)
}
}

pub fn serialize_legacy_row<T: ValueList>(
r: &T,
ctx: &RowSerializationContext<'_>,
writer: &mut impl RowWriter,
) -> Result<(), SerializationError> {
let serialized =
<T as ValueList>::serialized(r).map_err(|err| Arc::new(err) as SerializationError)?;

let mut append_value = |value: RawValue| {
let cell_writer = writer.make_cell_writer();
let _proof = match value {
RawValue::Null => cell_writer.set_null(),
RawValue::Unset => cell_writer.set_unset(),
RawValue::Value(v) => cell_writer.set_value(v),
};
};

if !serialized.has_names() {
serialized.iter().for_each(append_value);
} else {
let values_by_name = serialized
.iter_name_value_pairs()
.map(|(k, v)| (k.unwrap(), v))
.collect::<HashMap<_, _>>();

for col in ctx.columns() {
let val = values_by_name.get(col.name.as_str()).ok_or_else(|| {
Arc::new(ValueListToSerializeRowAdapterError::NoBindMarkerWithName {
name: col.name.clone(),
}) as SerializationError
})?;
append_value(*val);
}
}

Ok(())
}

#[derive(Error, Debug)]
pub enum ValueListToSerializeRowAdapterError {
#[error("There is no bind marker with name {name}, but a value for it was provided")]
NoBindMarkerWithName { name: String },
}

#[cfg(test)]
mod tests {
use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec};
use crate::frame::value::{MaybeUnset, SerializedValues, ValueList};
use crate::types::serialize::BufBackedRowWriter;

use super::{RowSerializationContext, SerializeRow};

fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec {
ColumnSpec {
table_spec: TableSpec {
ks_name: "ks".to_string(),
table_name: "tbl".to_string(),
},
name: name.to_string(),
typ,
}
}

#[test]
fn test_legacy_fallback() {
let row = (
1i32,
"Ala ma kota",
None::<i64>,
MaybeUnset::Unset::<String>,
);

let mut legacy_data = Vec::new();
<_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap();

let mut new_data = Vec::new();
let mut new_data_writer = BufBackedRowWriter::new(&mut new_data);
let ctx = RowSerializationContext { columns: &[] };
<_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap();
assert_eq!(new_data_writer.value_count(), 4);

// Skip the value count
assert_eq!(&legacy_data[2..], new_data);
}

#[test]
fn test_legacy_fallback_with_names() {
let sorted_row = (
1i32,
"Ala ma kota",
None::<i64>,
MaybeUnset::Unset::<String>,
);

let mut sorted_row_data = Vec::new();
<_ as ValueList>::write_to_request(&sorted_row, &mut sorted_row_data).unwrap();

let mut unsorted_row = SerializedValues::new();
unsorted_row.add_named_value("a", &1i32).unwrap();
unsorted_row.add_named_value("b", &"Ala ma kota").unwrap();
unsorted_row
.add_named_value("d", &MaybeUnset::Unset::<String>)
.unwrap();
unsorted_row.add_named_value("c", &None::<i64>).unwrap();

let mut unsorted_row_data = Vec::new();
let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data);
let ctx = RowSerializationContext {
columns: &[
col_spec("a", ColumnType::Int),
col_spec("b", ColumnType::Text),
col_spec("c", ColumnType::BigInt),
col_spec("d", ColumnType::Ascii),
],
};
<_ as SerializeRow>::serialize(&unsorted_row, &ctx, &mut unsorted_row_data_writer).unwrap();
assert_eq!(unsorted_row_data_writer.value_count(), 4);

// Skip the value count
assert_eq!(&sorted_row_data[2..], unsorted_row_data);
}
}
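Finally, a hypothetical example of what a manual `SerializeRow` implementation could look like under the constrained trait. The struct, its column layout, and the byte encodings are illustrative assumptions only (CQL ints as 4 big-endian bytes and text as raw UTF-8, matching the byte patterns in the tests above); it also assumes the `row` and `serialize` paths are publicly exported and that `SerializationError` is still just an `Arc<dyn Any + Send + Sync>` at this point, so any `'static` payload serves as an error.

```rust
use std::sync::Arc;

use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::{RowSerializationContext, SerializeRow};
use scylla_cql::types::serialize::{CellWriter, RowWriter, SerializationError};

// Hypothetical row type, used only for illustration.
struct UserRow {
    id: i32,
    name: String,
}

impl SerializeRow for UserRow {
    fn preliminary_type_check(
        ctx: &RowSerializationContext<'_>,
    ) -> Result<(), SerializationError> {
        // A cheap sanity check against the column metadata; full checking
        // could equally well be deferred to `serialize`.
        let cols = ctx.columns();
        let ok = cols.len() == 2
            && matches!(cols[0].typ, ColumnType::Int)
            && matches!(cols[1].typ, ColumnType::Text);
        if ok {
            Ok(())
        } else {
            Err(Arc::new("UserRow expects (int, text) columns".to_string()) as SerializationError)
        }
    }

    fn serialize<W: RowWriter>(
        &self,
        _ctx: &RowSerializationContext<'_>,
        writer: &mut W,
    ) -> Result<(), SerializationError> {
        // One cell per column, in column order.
        let _ = writer.make_cell_writer().set_value(&self.id.to_be_bytes());
        let _ = writer.make_cell_writer().set_value(self.name.as_bytes());
        Ok(())
    }
}
```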