Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New deserialization API - preparations #1065

Merged
Merged
13 changes: 13 additions & 0 deletions scylla-cql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::frame::frame_errors::{CqlResponseParseError, FrameError, ParseError};
use crate::frame::protocol_features::ProtocolFeatures;
use crate::frame::value::SerializeValuesError;
use crate::types::deserialize::{DeserializationError, TypeCheckError};
use crate::types::serialize::SerializationError;
use crate::Consistency;
use bytes::Bytes;
Expand Down Expand Up @@ -461,6 +462,18 @@ impl From<SerializationError> for QueryError {
}
}

impl From<DeserializationError> for QueryError {
fn from(value: DeserializationError) -> Self {
Self::InvalidMessage(value.to_string())
}
}

impl From<TypeCheckError> for QueryError {
fn from(value: TypeCheckError) -> Self {
Self::InvalidMessage(value.to_string())
}
}

impl From<ParseError> for QueryError {
fn from(parse_error: ParseError) -> QueryError {
QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error))
Expand Down
6 changes: 5 additions & 1 deletion scylla-cql/src/frame/frame_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use super::TryFromPrimitiveError;
use crate::cql_to_rust::CqlTypeError;
use crate::frame::value::SerializeValuesError;
use crate::types::deserialize::DeserializationError;
use crate::types::deserialize::{DeserializationError, TypeCheckError};
use crate::types::serialize::SerializationError;
use thiserror::Error;

Expand Down Expand Up @@ -46,6 +46,8 @@ pub enum ParseError {
#[error(transparent)]
DeserializationError(#[from] DeserializationError),
#[error(transparent)]
DeserializationTypeCheckError(#[from] TypeCheckError),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
Expand Down Expand Up @@ -216,6 +218,8 @@ pub enum PreparedParseError {
ResultMetadataParseError(ResultMetadataParseError),
#[error("Invalid prepared metadata: {0}")]
PreparedMetadataParseError(ResultMetadataParseError),
#[error("Non-zero paging state in result metadata: {0:?}")]
NonZeroPagingState(Arc<[u8]>),
}

/// An error type returned when deserialization
Expand Down
8 changes: 8 additions & 0 deletions scylla-cql/src/frame/request/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ impl PagingStateResponse {
Self::NoMorePages => ControlFlow::Break(()),
}
}

/// Swaps the paging state response with PagingStateResponse::NoMorePages.
///
/// Only for use in driver's inner code, as an optimisation.
#[doc(hidden)]
pub fn take(&mut self) -> Self {
std::mem::replace(self, Self::NoMorePages)
}
}

/// The state of a paged query, i.e. where to resume fetching result rows
Expand Down
11 changes: 8 additions & 3 deletions scylla-cql/src/frame/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub mod event;
pub mod result;
pub mod supported;

use std::sync::Arc;

pub use error::Error;
pub use supported::Supported;

Expand Down Expand Up @@ -65,17 +67,20 @@ impl Response {
pub fn deserialize(
features: &ProtocolFeatures,
opcode: ResponseOpcode,
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: bytes::Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> Result<Response, CqlResponseParseError> {
let buf = &mut &*buf_bytes;
let response = match opcode {
ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?),
ResponseOpcode::Ready => Response::Ready,
ResponseOpcode::Authenticate => {
Response::Authenticate(authenticate::Authenticate::deserialize(buf)?)
}
ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?),
ResponseOpcode::Result => Response::Result(result::deserialize(buf, cached_metadata)?),
ResponseOpcode::Result => {
Response::Result(result::deserialize(buf_bytes, cached_metadata)?)
}
ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?),
ResponseOpcode::AuthChallenge => {
Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?)
Expand Down
81 changes: 48 additions & 33 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::types::deserialize::value::{
use crate::types::deserialize::{DeserializationError, FrameSlice};
use bytes::{Buf, Bytes};
use std::borrow::Cow;
use std::sync::Arc;
use std::{net::IpAddr, result::Result as StdResult, str};
use uuid::Uuid;

Expand Down Expand Up @@ -431,7 +432,6 @@ pub struct ColumnSpec {
#[derive(Debug, Clone)]
pub struct ResultMetadata {
col_count: usize,
pub paging_state: PagingStateResponse,
pub col_specs: Vec<ColumnSpec>,
}

Expand All @@ -440,10 +440,18 @@ impl ResultMetadata {
pub fn mock_empty() -> Self {
Self {
col_count: 0,
paging_state: PagingStateResponse::NoMorePages,
col_specs: Vec::new(),
}
}

#[inline]
#[doc(hidden)]
pub fn new_for_test(col_count: usize, col_specs: Vec<ColumnSpec>) -> Self {
Self {
col_count,
col_specs,
}
}
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -478,7 +486,8 @@ impl Row {

#[derive(Debug)]
pub struct Rows {
pub metadata: ResultMetadata,
pub metadata: Arc<ResultMetadata>,
pub paging_state_response: PagingStateResponse,
pub rows_count: usize,
pub rows: Vec<Row>,
/// Original size of the serialized rows.
Expand Down Expand Up @@ -620,7 +629,9 @@ fn deser_col_specs(
Ok(col_specs)
}

fn deser_result_metadata(buf: &mut &[u8]) -> StdResult<ResultMetadata, ResultMetadataParseError> {
fn deser_result_metadata(
buf: &mut &[u8],
) -> StdResult<(ResultMetadata, PagingStateResponse), ResultMetadataParseError> {
let flags = types::read_int(buf)
.map_err(|err| ResultMetadataParseError::FlagsParseError(err.into()))?;
let global_tables_spec = flags & 0x0001 != 0;
Expand All @@ -635,27 +646,23 @@ fn deser_result_metadata(buf: &mut &[u8]) -> StdResult<ResultMetadata, ResultMet
.transpose()?;
let paging_state = PagingStateResponse::new_from_raw_bytes(raw_paging_state);

if no_metadata {
return Ok(ResultMetadata {
col_count,
paging_state,
col_specs: vec![],
});
}

let global_table_spec = if global_tables_spec {
Some(deser_table_spec(buf)?)
let col_specs = if no_metadata {
vec![]
} else {
None
};
let global_table_spec = if global_tables_spec {
Some(deser_table_spec(buf)?)
} else {
None
};

let col_specs = deser_col_specs(buf, &global_table_spec, col_count)?;
deser_col_specs(buf, &global_table_spec, col_count)?
};

Ok(ResultMetadata {
let metadata = ResultMetadata {
col_count,
paging_state,
col_specs,
})
};
Ok((metadata, paging_state))
}

fn deser_prepared_metadata(
Expand Down Expand Up @@ -859,17 +866,14 @@ pub fn deser_cql_value(
}

fn deser_rows(
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> StdResult<Rows, RowsParseError> {
let server_metadata = deser_result_metadata(buf)?;
let buf = &mut &*buf_bytes;
let (server_metadata, paging_state_response) = deser_result_metadata(buf)?;

let metadata = match cached_metadata {
Some(cached) => ResultMetadata {
col_count: cached.col_count,
paging_state: server_metadata.paging_state,
col_specs: cached.col_specs.clone(),
},
Some(cached) => Arc::clone(cached),
None => {
// No cached_metadata provided. Server is supposed to provide the result metadata.
if server_metadata.col_count != server_metadata.col_specs.len() {
Expand All @@ -878,7 +882,7 @@ fn deser_rows(
col_specs_count: server_metadata.col_specs.len(),
});
}
server_metadata
Arc::new(server_metadata)
}
};

Expand All @@ -899,6 +903,7 @@ fn deser_rows(

Ok(Rows {
metadata,
paging_state_response,
rows_count,
rows,
serialized_size: original_size - buf.len(),
Expand All @@ -919,8 +924,17 @@ fn deser_prepared(buf: &mut &[u8]) -> StdResult<Prepared, PreparedParseError> {
buf.advance(id_len);
let prepared_metadata =
deser_prepared_metadata(buf).map_err(PreparedParseError::PreparedMetadataParseError)?;
let result_metadata =
let (result_metadata, paging_state_response) =
deser_result_metadata(buf).map_err(PreparedParseError::ResultMetadataParseError)?;
if let PagingStateResponse::HasMorePages { state } = paging_state_response {
return Err(PreparedParseError::NonZeroPagingState(
state
.as_bytes_slice()
.cloned()
.unwrap_or_else(|| Arc::from([])),
));
}

Ok(Prepared {
id,
prepared_metadata,
Expand All @@ -935,16 +949,17 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult<SchemaChange, SchemaChangeE
}

pub fn deserialize(
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> StdResult<Result, CqlResultParseError> {
let buf = &mut &*buf_bytes;
use self::Result::*;
Ok(
match types::read_int(buf)
.map_err(|err| CqlResultParseError::ResultIdParseError(err.into()))?
{
0x0001 => Void,
0x0002 => Rows(deser_rows(buf, cached_metadata)?),
0x0002 => Rows(deser_rows(buf_bytes.slice_ref(buf), cached_metadata)?),
0x0003 => SetKeyspace(deser_set_keyspace(buf)?),
0x0004 => Prepared(deser_prepared(buf)?),
0x0005 => SchemaChange(deser_schema_change(buf)?),
Expand Down
2 changes: 2 additions & 0 deletions scylla-cql/src/types/deserialize/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError};
use std::marker::PhantomData;

/// Iterates over the whole result, returning rows.
#[derive(Debug)]
pub struct RowIterator<'frame> {
specs: &'frame [ColumnSpec],
remaining: usize,
Expand Down Expand Up @@ -76,6 +77,7 @@ impl<'frame> Iterator for RowIterator<'frame> {

/// A typed version of [RowIterator] which deserializes the rows before
/// returning them.
#[derive(Debug)]
pub struct TypedRowIterator<'frame, R> {
inner: RowIterator<'frame>,
_phantom: PhantomData<R>,
Expand Down
Loading
Loading