Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RUST-1437 Send endSessions on client shutdown #1216

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ pub mod options;
pub mod session;

use std::{
sync::{atomic::AtomicBool, Mutex as SyncMutex},
sync::{
atomic::{AtomicBool, Ordering},
Mutex as SyncMutex,
},
time::{Duration, Instant},
};

Expand All @@ -26,13 +29,18 @@ use crate::trace::{
COMMAND_TRACING_EVENT_TARGET,
};
use crate::{
bson::doc,
concern::{ReadConcern, WriteConcern},
db::Database,
error::{Error, ErrorKind, Result},
event::command::CommandEvent,
id_set::IdSet,
options::{ClientOptions, DatabaseOptions, ReadPreference, SelectionCriteria, ServerAddress},
sdam::{server_selection, SelectedServer, Topology},
sdam::{
server_selection::{self, attempt_to_select_server},
SelectedServer,
Topology,
},
tracking_arc::TrackingArc,
BoxFuture,
ClientSession,
Expand Down Expand Up @@ -123,6 +131,7 @@ struct ClientInner {
options: ClientOptions,
session_pool: ServerSessionPool,
shutdown: Shutdown,
dropped: AtomicBool,
#[cfg(feature = "in-use-encryption")]
csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
#[cfg(test)]
Expand Down Expand Up @@ -159,6 +168,7 @@ impl Client {
pending_drops: SyncMutex::new(IdSet::new()),
executed: AtomicBool::new(false),
},
dropped: AtomicBool::new(false),
#[cfg(feature = "in-use-encryption")]
csfle: Default::default(),
#[cfg(test)]
Expand Down Expand Up @@ -591,6 +601,41 @@ impl Client {
pub(crate) fn options(&self) -> &ClientOptions {
&self.inner.options
}

/// Ends all sessions contained in this client's session pool on the server.
pub(crate) async fn end_all_sessions(&self) {
// The maximum number of session IDs that should be sent in a single endSessions command.
const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;

let mut watcher = self.inner.topology.watch();
let selection_criteria =
SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });

let session_ids = self.inner.session_pool.get_session_ids().await;
for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
let state = watcher.observe_latest();
let Ok(Some(_)) = attempt_to_select_server(
&selection_criteria,
&state.description,
&state.servers(),
None,
) else {
// If a suitable server is not available, do not proceed with the operation to avoid
// spinning for server_selection_timeout.
return;
};

let end_sessions = doc! {
"endSessions": chunk,
};
let result = self
.database("admin")
.run_command(end_sessions)
.selection_criteria(selection_criteria.clone())
.await;
debug_assert!(result.is_ok());
}
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -625,3 +670,24 @@ impl AsyncDropToken {
Self { tx: self.tx.take() }
}
}

impl Drop for Client {
fn drop(&mut self) {
if !self.inner.shutdown.executed.load(Ordering::SeqCst)
&& !self.inner.dropped.load(Ordering::SeqCst)
&& TrackingArc::strong_count(&self.inner) == 1
{
// We need an owned copy of the client to move into the spawned future. However, if this
// call to drop completes before the spawned future completes, the number of strong
// references to the inner client will again be 1 when the cloned client drops, and thus
// end_all_sessions will be called continuously until the runtime shuts down. Storing a
// flag indicating whether end_all_sessions has already been called breaks
// this cycle.
self.inner.dropped.store(true, Ordering::SeqCst);
let client = self.clone();
crate::runtime::spawn(async move {
client.end_all_sessions().await;
});
}
}
}
5 changes: 5 additions & 0 deletions src/client/action/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ impl Action for crate::action::Shutdown {
.extract();
join_all(pending).await;
}
// If shutdown has already been called on a different copy of the client, don't call
// end_all_sessions again.
if !self.client.inner.shutdown.executed.load(Ordering::SeqCst) {
self.client.end_all_sessions().await;
}
self.client.inner.topology.shutdown().await;
// This has to happen last to allow pending cleanup to execute commands.
self.client
Expand Down
2 changes: 1 addition & 1 deletion src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ impl Drop for ClientSession {
#[derive(Clone, Debug)]
pub(crate) struct ServerSession {
/// The id of the server session to which this corresponds.
id: Document,
pub(crate) id: Document,

/// The last time an operation was executed with this session.
last_use: std::time::Instant,
Expand Down
7 changes: 6 additions & 1 deletion src/client/session/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::{collections::VecDeque, time::Duration};
use tokio::sync::Mutex;

use super::ServerSession;
#[cfg(test)]
use crate::bson::Document;

#[derive(Debug)]
Expand Down Expand Up @@ -68,4 +67,10 @@ impl ServerSessionPool {
pub(crate) async fn contains(&self, id: &Document) -> bool {
self.pool.lock().await.iter().any(|s| &s.id == id)
}

/// Returns a list of the IDs of the sessions contained in the pool.
pub(crate) async fn get_session_ids(&self) -> Vec<Document> {
let sessions = self.pool.lock().await;
sessions.iter().map(|session| session.id.clone()).collect()
}
}
1 change: 0 additions & 1 deletion src/gridfs/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ impl GridFsUploadStream {
}

impl Drop for GridFsUploadStream {
// TODO RUST-1493: pre-create this task
fn drop(&mut self) {
if !matches!(self.state, State::Closed) {
let chunks = self.bucket.chunks().clone();
Expand Down
60 changes: 58 additions & 2 deletions src/test/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
use crate::{
bson::{doc, Bson},
error::{CommandError, Error, ErrorKind},
event::{cmap::CmapEvent, sdam::SdamEvent},
event::{cmap::CmapEvent, command::CommandEvent, sdam::SdamEvent},
hello::LEGACY_HELLO_COMMAND_NAME,
options::{AuthMechanism, ClientOptions, Credential, ServerAddress},
runtime,
Expand All @@ -15,7 +15,7 @@ use crate::{
get_client_options,
log_uncaptured,
util::{
event_buffer::EventBuffer,
event_buffer::{EventBuffer, EventStream},
fail_point::{FailPoint, FailPointMode},
TestClient,
},
Expand Down Expand Up @@ -930,3 +930,59 @@ async fn warm_connection_pool() {
// Validate that a command executes.
client.list_database_names().await.unwrap();
}

async fn get_end_session_event_count<'a>(event_stream: &mut EventStream<'a, Event>) -> usize {
event_stream
.collect(Duration::from_millis(500), |event| match event {
Event::Command(CommandEvent::Started(command_started_event)) => {
command_started_event.command_name == "endSessions"
}
_ => false,
})
.await
.len()
}

#[tokio::test]
async fn end_sessions_on_drop() {
let client1 = Client::for_test().monitor_events().await;
let client2 = client1.clone();
let events = client1.events.clone();
let mut event_stream = events.stream();

// Run an operation to populate the session pool.
client1
.database("db")
.collection::<Document>("coll")
.find(doc! {})
.await
.unwrap();

drop(client1);
assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);

drop(client2);
assert_eq!(get_end_session_event_count(&mut event_stream).await, 1);
}

#[tokio::test]
async fn end_sessions_on_shutdown() {
let client1 = Client::for_test().monitor_events().await;
let client2 = client1.clone();
let events = client1.events.clone();
let mut event_stream = events.stream();

// Run an operation to populate the session pool.
client1
.database("db")
.collection::<Document>("coll")
.find(doc! {})
.await
.unwrap();

client1.into_client().shutdown().await;
assert_eq!(get_end_session_event_count(&mut event_stream).await, 1);

client2.into_client().shutdown().await;
assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);
}
36 changes: 0 additions & 36 deletions src/test/spec/json/connection-monitoring-and-pooling/README.rst

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,73 @@
}
}
},
{
"level": "debug",
"component": "connection",
"data": {
"message": "Connection checkout started",
"serverHost": {
"$$type": "string"
},
"serverPort": {
"$$type": [
"int",
"long"
]
}
}
},
{
"level": "debug",
"component": "connection",
"data": {
"message": "Connection checked out",
"driverConnectionId": {
"$$type": [
"int",
"long"
]
},
"serverHost": {
"$$type": "string"
},
"serverPort": {
"$$type": [
"int",
"long"
]
},
"durationMS": {
"$$type": [
"double",
"int",
"long"
]
}
}
},
{
"level": "debug",
"component": "connection",
"data": {
"message": "Connection checked in",
"driverConnectionId": {
"$$type": [
"int",
"long"
]
},
"serverHost": {
"$$type": "string"
},
"serverPort": {
"$$type": [
"int",
"long"
]
}
}
},
{
"level": "debug",
"component": "connection",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ tests:
serverHost: { $$type: string }
serverPort: { $$type: [int, long] }

# The next three expected logs are for ending a session.
- level: debug
component: connection
data:
message: "Connection checkout started"
serverHost: { $$type: string }
serverPort: { $$type: [int, long] }

- level: debug
component: connection
data:
message: "Connection checked out"
driverConnectionId: { $$type: [int, long] }
serverHost: { $$type: string }
serverPort: { $$type: [int, long] }
durationMS: { $$type: [double, int, long] }

- level: debug
component: connection
data:
message: "Connection checked in"
driverConnectionId: { $$type: [int, long] }
serverHost: { $$type: string }
serverPort: { $$type: [int, long] }

- level: debug
component: connection
data:
Expand Down
8 changes: 7 additions & 1 deletion src/test/spec/unified_runner/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2182,7 +2182,13 @@ impl TestOperation for Close {
Entity::Client(_) => {
let client = entities.get_mut(id).unwrap().as_mut_client();
let closed_client_topology_id = client.topology_id;
client.client = None;
client
.client
.take()
.unwrap()
.shutdown()
.immediate(true)
.await;

let mut entities_to_remove = vec![];
for (key, value) in entities.iter() {
Expand Down
Loading