Skip to content

Commit

Permalink
session: require schema agreement to fetch schema version
Browse files Browse the repository at this point in the history
It can be a nasty source of race bugs when schema version is fetched
without awaiting for schema agreement. It is especially risky when one
awaits schema agreement and subsequently fetches schema version, because
the schema version can change in the meantime and schema may no longer
be in agreement.
To cope with the problem, `await_schema_agreement()` now returns the
agreed schema version, and `fetch_schema_version()` is removed
alltogether.
  • Loading branch information
wprzytula committed Aug 3, 2023
1 parent f90942c commit 4c6f5d5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 90 deletions.
15 changes: 1 addition & 14 deletions docs/source/queries/schema-agreement.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,6 @@ Sometimes after performing queries some nodes have not been updated, so we need
There is a number of methods in `Session` that assist us.
Every method raise `QueryError` if something goes wrong, but they should never raise any errors, unless there is a DB or connection malfunction.

### Checking schema version
`Session::fetch_schema_version` returns an `Uuid` of local node's schema version.

```rust
# extern crate scylla;
# use scylla::Session;
# use std::error::Error;
# async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
println!("Local schema version is: {}", session.fetch_schema_version().await?);
# Ok(())
# }
```

### Awaiting schema agreement

`Session::await_schema_agreement` returns a `Future` that can be `await`ed as long as schema is not in an agreement.
Expand Down Expand Up @@ -62,7 +49,7 @@ If you want to check if schema is in agreement now, without retrying after failu
# use scylla::Session;
# use std::error::Error;
# async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
if session.check_schema_agreement().await? {
if session.check_schema_agreement().await?.is_some() {
println!("SCHEMA AGREED");
} else {
println!("SCHEMA IS NOT IN AGREEMENT");
Expand Down
18 changes: 10 additions & 8 deletions examples/schema_agreement.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use anyhow::{bail, Result};
use scylla::transport::errors::QueryError;
use scylla::transport::session::{IntoTypedRows, Session};
use scylla::SessionBuilder;
use std::env;
Expand All @@ -17,16 +18,17 @@ async fn main() -> Result<()> {
.build()
.await?;

let schema_version = session.fetch_schema_version().await?;
let schema_version = session.await_schema_agreement().await?;

println!("Schema version: {}", schema_version);

session.query("CREATE KEYSPACE IF NOT EXISTS ks WITH REPLICATION = {'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}", &[]).await?;

if session.await_schema_agreement().await? {
println!("Schema is in agreement");
} else {
println!("Schema is NOT in agreement");
}
match session.await_schema_agreement().await {
Ok(_schema_version) => println!("Schema is in agreement in time"),
Err(QueryError::RequestTimeout(_)) => println!("Schema is NOT in agreement in time"),
Err(err) => bail!(err),
};
session
.query(
"CREATE TABLE IF NOT EXISTS ks.t (a int, b int, c text, primary key (a, b))",
Expand Down Expand Up @@ -66,7 +68,7 @@ async fn main() -> Result<()> {
}
println!("Ok.");

let schema_version = session.fetch_schema_version().await?;
let schema_version = session.await_schema_agreement().await?;
println!("Schema version: {}", schema_version);

Ok(())
Expand Down
79 changes: 18 additions & 61 deletions scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::time::Duration;
use tokio::net::lookup_host;
use tokio::time::timeout;
use tracing::warn;
use tracing::{debug, error, trace, trace_span, Instrument};
use tracing::{debug, trace, trace_span, Instrument};
use uuid::Uuid;

use super::cluster::ContactPoint;
Expand All @@ -58,7 +58,7 @@ use crate::frame::value::{
use crate::prepared_statement::{PartitionKeyError, PreparedStatement};
use crate::query::Query;
use crate::routing::Token;
use crate::statement::{Consistency, SerialConsistency};
use crate::statement::Consistency;
use crate::tracing::{TracingEvent, TracingInfo};
use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug};
use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
Expand Down Expand Up @@ -747,8 +747,7 @@ impl Session {
};

self.handle_set_keyspace_response(&response).await?;
self.handle_auto_await_schema_agreement(&query.contents, &response)
.await?;
self.handle_auto_await_schema_agreement(&response).await?;

let result = response.into_query_result()?;
span.record_result_fields(&result);
Expand All @@ -773,18 +772,11 @@ impl Session {

async fn handle_auto_await_schema_agreement(
&self,
contents: &str,
response: &NonErrorQueryResponse,
) -> Result<(), QueryError> {
if self.schema_agreement_automatic_waiting {
if response.as_schema_change().is_some() && !self.await_schema_agreement().await? {
// TODO: The TimeoutError should allow to provide more context.
// For now, print an error to the logs
error!(
"Failed to reach schema agreement after a schema-altering statement: {}",
contents,
);
return Err(QueryError::TimeoutError);
if response.as_schema_change().is_some() {
self.await_schema_agreement().await?;
}

if self.refresh_metadata_on_auto_schema_agreement
Expand Down Expand Up @@ -1099,8 +1091,7 @@ impl Session {
};

self.handle_set_keyspace_response(&response).await?;
self.handle_auto_await_schema_agreement(prepared.get_statement(), &response)
.await?;
self.handle_auto_await_schema_agreement(&response).await?;

let result = response.into_query_result()?;
span.record_result_fields(&result);
Expand Down Expand Up @@ -1848,62 +1839,35 @@ impl Session {
last_error.map(Result::Err)
}

async fn await_schema_agreement_indefinitely(&self) -> Result<(), QueryError> {
while !self.check_schema_agreement().await? {
tokio::time::sleep(self.schema_agreement_interval).await
async fn await_schema_agreement_indefinitely(&self) -> Result<Uuid, QueryError> {
loop {
tokio::time::sleep(self.schema_agreement_interval).await;
if let Some(agreed_version) = self.check_schema_agreement().await? {
return Ok(agreed_version);
}
}
Ok(())
}

pub async fn await_schema_agreement(&self) -> Result<bool, QueryError> {
pub async fn await_schema_agreement(&self) -> Result<Uuid, QueryError> {
timeout(
self.schema_agreement_timeout,
self.await_schema_agreement_indefinitely(),
)
.await
.map_or(Ok(false), |res| res.and(Ok(true)))
.unwrap_or(Err(QueryError::RequestTimeout(
"schema agreement not reached in time".to_owned(),
)))
}

pub async fn check_schema_agreement(&self) -> Result<bool, QueryError> {
pub async fn check_schema_agreement(&self) -> Result<Option<Uuid>, QueryError> {
let connections = self.cluster.get_working_connections().await?;

let handles = connections.iter().map(|c| c.fetch_schema_version());
let versions = try_join_all(handles).await?;

let local_version: Uuid = versions[0];
let in_agreement = versions.into_iter().all(|v| v == local_version);
Ok(in_agreement)
}

pub async fn fetch_schema_version(&self) -> Result<Uuid, QueryError> {
// We ignore custom Consistency that a retry policy could decide to put here, using the default instead.
let info = RoutingInfo::default();
let config = StatementConfig {
is_idempotent: true,
serial_consistency: Some(Some(SerialConsistency::LocalSerial)),
..Default::default()
};

let span = RequestSpan::new_none();

match self
.run_query(
info,
&config,
self.get_default_execution_profile_handle().access(),
|node: Arc<Node>| async move { node.random_connection().await },
|connection: Arc<Connection>, _: Consistency, _: &ExecutionProfileInner| async move {
connection.fetch_schema_version().await
},
&span,
)
.await?
{
RunQueryResult::IgnoredWriteError => Err(QueryError::ProtocolError(
"Retry policy has made the driver ignore fetching schema version query.",
)),
RunQueryResult::Completed(result) => Ok(result),
}
Ok(in_agreement.then_some(local_version))
}

fn calculate_partition_key(
Expand Down Expand Up @@ -2189,13 +2153,6 @@ impl RequestSpan {
}
}

pub(crate) fn new_none() -> Self {
Self {
span: tracing::Span::none(),
speculative_executions: 0.into(),
}
}

pub(crate) fn record_shard_id(&self, conn: &Connection) {
if let Some(info) = conn.get_shard_info() {
self.span.record("shard", info.shard);
Expand Down
8 changes: 1 addition & 7 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1103,16 +1103,10 @@ async fn assert_in_tracing_table(session: &Session, tracing_uuid: Uuid) {
panic!("No rows for tracing with this session id!");
}

#[tokio::test]
async fn test_fetch_schema_version() {
let session = create_new_session_builder().build().await.unwrap();
session.fetch_schema_version().await.unwrap();
}

#[tokio::test]
async fn test_await_schema_agreement() {
let session = create_new_session_builder().build().await.unwrap();
session.await_schema_agreement().await.unwrap();
let _schema_version = session.await_schema_agreement().await.unwrap();
}

#[tokio::test]
Expand Down

0 comments on commit 4c6f5d5

Please sign in to comment.