diff --git a/quickwit/quickwit-ingest/src/ingest_v2/broadcast.rs b/quickwit/quickwit-ingest/src/ingest_v2/broadcast.rs index a4ff25f6264..b0bc4c37602 100644 --- a/quickwit/quickwit-ingest/src/ingest_v2/broadcast.rs +++ b/quickwit/quickwit-ingest/src/ingest_v2/broadcast.rs @@ -18,7 +18,6 @@ // along with this program. If not, see . use std::collections::{BTreeMap, BTreeSet}; -use std::sync::Weak; use std::time::Duration; use bytesize::ByteSize; @@ -30,12 +29,11 @@ use quickwit_common::tower::Rate; use quickwit_proto::ingest::ShardState; use quickwit_proto::types::{split_queue_id, NodeId, QueueId, ShardId, SourceUid}; use serde::{Deserialize, Serialize, Serializer}; -use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::{debug, warn}; -use super::ingester::IngesterState; use super::metrics::INGEST_V2_METRICS; +use super::state::WeakIngesterState; use crate::RateMibPerSec; const BROADCAST_INTERVAL_PERIOD: Duration = if cfg!(test) { @@ -149,11 +147,11 @@ impl LocalShardsSnapshot { /// broadcasts it to other nodes via Chitchat. pub(super) struct BroadcastLocalShardsTask { cluster: Cluster, - weak_state: Weak>, + weak_state: WeakIngesterState, } impl BroadcastLocalShardsTask { - pub fn spawn(cluster: Cluster, weak_state: Weak>) -> JoinHandle<()> { + pub fn spawn(cluster: Cluster, weak_state: WeakIngesterState) -> JoinHandle<()> { let mut broadcaster = Self { cluster, weak_state, @@ -163,7 +161,7 @@ impl BroadcastLocalShardsTask { async fn snapshot_local_shards(&self) -> Option { let state = self.weak_state.upgrade()?; - let mut state_guard = state.write().await; + let mut state_guard = state.lock_partially().await; let mut per_source_shard_infos: BTreeMap = BTreeMap::new(); @@ -331,14 +329,14 @@ pub async fn setup_local_shards_update_listener( #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use mrecordlog::MultiRecordLog; use quickwit_cluster::{create_cluster_for_test, ChannelTransport}; use quickwit_common::rate_limiter::{RateLimiter, RateLimiterSettings}; - use quickwit_proto::ingest::ingester::{IngesterStatus, ObservationMessage}; + use quickwit_proto::ingest::ingester::ObservationMessage; use quickwit_proto::ingest::ShardState; use quickwit_proto::types::{queue_id, Position}; use tokio::sync::watch; @@ -346,6 +344,7 @@ mod tests { use super::*; use crate::ingest_v2::models::IngesterShard; use crate::ingest_v2::rate_meter::RateMeter; + use crate::ingest_v2::state::IngesterState; #[test] fn test_shard_info_serde() { @@ -473,16 +472,8 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); - let weak_state = Arc::downgrade(&state); + let state = IngesterState::new(mrecordlog, observation_tx); + let weak_state = state.weak(); let task = BroadcastLocalShardsTask { cluster, weak_state, @@ -490,7 +481,7 @@ mod tests { let previous_snapshot = task.snapshot_local_shards().await.unwrap(); assert!(previous_snapshot.per_source_shard_infos.is_empty()); - let mut state_guard = state.write().await; + let mut state_guard = state.lock_partially().await; let queue_id_01 = queue_id("test-index:0", 
"test-source", &ShardId::from(1)); let shard = diff --git a/quickwit/quickwit-ingest/src/ingest_v2/fetch.rs b/quickwit/quickwit-ingest/src/ingest_v2/fetch.rs index 364dae139a2..cbf94bb6ed3 100644 --- a/quickwit/quickwit-ingest/src/ingest_v2/fetch.rs +++ b/quickwit/quickwit-ingest/src/ingest_v2/fetch.rs @@ -25,6 +25,7 @@ use std::sync::Arc; use bytes::{BufMut, BytesMut}; use futures::StreamExt; +use mrecordlog::MultiRecordLog; use quickwit_common::retry::RetryParams; use quickwit_common::ServiceStream; use quickwit_proto::ingest::ingester::{ @@ -36,7 +37,6 @@ use tokio::sync::{mpsc, watch, RwLock}; use tokio::task::JoinHandle; use tracing::{debug, error, warn}; -use super::ingester::IngesterState; use super::models::ShardStatus; use crate::{with_lock_metrics, ClientId, IngesterPool}; @@ -51,7 +51,7 @@ pub(super) struct FetchStreamTask { queue_id: QueueId, /// The position of the next record fetched. from_position_inclusive: u64, - state: Arc>, + mrecordlog: Arc>, fetch_message_tx: mpsc::Sender>, /// This channel notifies the fetch task when new records are available. This way the fetch /// task does not need to grab the lock and poll the mrecordlog queue unnecessarily. @@ -75,7 +75,7 @@ impl FetchStreamTask { pub fn spawn( open_fetch_stream_request: OpenFetchStreamRequest, - state: Arc>, + mrecordlog: Arc>, shard_status_rx: watch::Receiver, batch_num_bytes: usize, ) -> (ServiceStream>, JoinHandle<()>) { @@ -92,7 +92,7 @@ impl FetchStreamTask { index_uid: open_fetch_stream_request.index_uid.into(), source_id: open_fetch_stream_request.source_id, from_position_inclusive, - state, + mrecordlog, fetch_message_tx, shard_status_rx, batch_num_bytes, @@ -126,11 +126,11 @@ impl FetchStreamTask { let mut mrecord_buffer = BytesMut::with_capacity(self.batch_num_bytes); let mut mrecord_lengths = Vec::new(); - let state_guard = with_lock_metrics!(self.state.read().await, "fetch", "read"); + let mrecordlog_guard = + with_lock_metrics!(self.mrecordlog.read().await, "fetch", "read"); - let Ok(mrecords) = state_guard - .mrecordlog - .range(&self.queue_id, self.from_position_inclusive..) + let Ok(mrecords) = + mrecordlog_guard.range(&self.queue_id, self.from_position_inclusive..) else { // The queue was dropped. break; @@ -144,7 +144,7 @@ impl FetchStreamTask { mrecord_lengths.push(mrecord.len() as u32); } // Drop the lock while we send the message. 
- drop(state_guard); + drop(mrecordlog_guard); if !mrecord_lengths.is_empty() { let from_position_exclusive = if self.from_position_inclusive == 0 { @@ -593,9 +593,7 @@ pub(super) mod tests { use bytes::Bytes; use mrecordlog::MultiRecordLog; - use quickwit_proto::ingest::ingester::{ - IngesterServiceClient, IngesterStatus, ObservationMessage, - }; + use quickwit_proto::ingest::ingester::IngesterServiceClient; use quickwit_proto::ingest::ShardState; use quickwit_proto::types::queue_id; use tokio::time::timeout; @@ -620,7 +618,9 @@ pub(super) mod tests { #[tokio::test] async fn test_fetch_task_happy_path() { let tempdir = tempfile::tempdir().unwrap(); - let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); + let mrecordlog = Arc::new(RwLock::new( + MultiRecordLog::open(tempdir.path()).await.unwrap(), + )); let client_id = "test-client".to_string(); let index_uid = "test-index:0".to_string(); let source_id = "test-source".to_string(); @@ -631,38 +631,23 @@ pub(super) mod tests { shard_id: Some(ShardId::from(1)), from_position_exclusive: Some(Position::Beginning), }; - let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); let (shard_status_tx, shard_status_rx) = watch::channel(ShardStatus::default()); let (mut fetch_stream, fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - state.clone(), + mrecordlog.clone(), shard_status_rx, 1024, ); let queue_id = queue_id(&index_uid, &source_id, &ShardId::from(1)); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog - .create_queue(&queue_id) - .await - .unwrap(); - state_guard - .mrecordlog + mrecordlog_guard.create_queue(&queue_id).await.unwrap(); + mrecordlog_guard .append_record(&queue_id, None, MRecord::new_doc("test-doc-foo").encode()) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let fetch_message = timeout(Duration::from_millis(100), fetch_stream.next()) .await @@ -704,14 +689,13 @@ pub(super) mod tests { .await .unwrap_err(); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog + mrecordlog_guard .append_record(&queue_id, None, MRecord::new_doc("test-doc-bar").encode()) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let shard_status = (ShardState::Open, Position::offset(1u64)); shard_status_tx.send(shard_status.clone()).unwrap(); @@ -744,7 +728,7 @@ pub(super) mod tests { "\0\0test-doc-bar" ); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; let mrecords = [ MRecord::new_doc("test-doc-baz").encode(), @@ -752,12 +736,11 @@ pub(super) mod tests { ] .into_iter(); - state_guard - .mrecordlog + mrecordlog_guard .append_records(&queue_id, None, mrecords) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let shard_status = (ShardState::Open, Position::offset(3u64)); shard_status_tx.send(shard_status).unwrap(); @@ -811,7 +794,9 @@ pub(super) mod tests { #[tokio::test] async fn test_fetch_task_eof_at_beginning() { let tempdir = tempfile::tempdir().unwrap(); - let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); + let mrecordlog = Arc::new(RwLock::new( + 
MultiRecordLog::open(tempdir.path()).await.unwrap(), + )); let client_id = "test-client".to_string(); let index_uid = "test-index:0".to_string(); let source_id = "test-source".to_string(); @@ -822,33 +807,19 @@ pub(super) mod tests { shard_id: Some(ShardId::from(1)), from_position_exclusive: Some(Position::Beginning), }; - let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); let (shard_status_tx, shard_status_rx) = watch::channel(ShardStatus::default()); let (mut fetch_stream, fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - state.clone(), + mrecordlog.clone(), shard_status_rx, 1024, ); let queue_id = queue_id(&index_uid, &source_id, &ShardId::from(1)); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog - .create_queue(&queue_id) - .await - .unwrap(); - drop(state_guard); + mrecordlog_guard.create_queue(&queue_id).await.unwrap(); + drop(mrecordlog_guard); timeout(Duration::from_millis(100), fetch_stream.next()) .await @@ -875,7 +846,9 @@ pub(super) mod tests { #[tokio::test] async fn test_fetch_task_from_position_exclusive() { let tempdir = tempfile::tempdir().unwrap(); - let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); + let mrecordlog = Arc::new(RwLock::new( + MultiRecordLog::open(tempdir.path()).await.unwrap(), + )); let client_id = "test-client".to_string(); let index_uid = "test-index:0".to_string(); let source_id = "test-source".to_string(); @@ -886,46 +859,31 @@ pub(super) mod tests { shard_id: Some(ShardId::from(1)), from_position_exclusive: Some(Position::offset(0u64)), }; - let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); let (shard_status_tx, shard_status_rx) = watch::channel(ShardStatus::default()); let (mut fetch_stream, _fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - state.clone(), + mrecordlog.clone(), shard_status_rx, 1024, ); let queue_id = queue_id(&index_uid, &source_id, &ShardId::from(1)); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog - .create_queue(&queue_id) - .await - .unwrap(); - drop(state_guard); + mrecordlog_guard.create_queue(&queue_id).await.unwrap(); + drop(mrecordlog_guard); timeout(Duration::from_millis(100), fetch_stream.next()) .await .unwrap_err(); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog + mrecordlog_guard .append_record(&queue_id, None, MRecord::new_doc("test-doc-foo").encode()) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let shard_status = (ShardState::Open, Position::offset(0u64)); shard_status_tx.send(shard_status).unwrap(); @@ -934,14 +892,13 @@ pub(super) mod tests { .await .unwrap_err(); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog + 
mrecordlog_guard .append_record(&queue_id, None, MRecord::new_doc("test-doc-bar").encode()) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let shard_status = (ShardState::Open, Position::offset(1u64)); shard_status_tx.send(shard_status).unwrap(); @@ -981,7 +938,9 @@ pub(super) mod tests { #[tokio::test] async fn test_fetch_task_error() { let tempdir = tempfile::tempdir().unwrap(); - let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); + let mrecordlog = Arc::new(RwLock::new( + MultiRecordLog::open(tempdir.path()).await.unwrap(), + )); let client_id = "test-client".to_string(); let index_uid = "test-index:0".to_string(); let source_id = "test-source".to_string(); @@ -992,20 +951,10 @@ pub(super) mod tests { shard_id: Some(ShardId::from(1)), from_position_exclusive: Some(Position::Beginning), }; - let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); let (_shard_status_tx, shard_status_rx) = watch::channel(ShardStatus::default()); let (mut fetch_stream, fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - state.clone(), + mrecordlog.clone(), shard_status_rx, 1024, ); @@ -1022,7 +971,9 @@ pub(super) mod tests { #[tokio::test] async fn test_fetch_task_batch_num_bytes() { let tempdir = tempfile::tempdir().unwrap(); - let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); + let mrecordlog = Arc::new(RwLock::new( + MultiRecordLog::open(tempdir.path()).await.unwrap(), + )); let client_id = "test-client".to_string(); let index_uid = "test-index:0".to_string(); let source_id = "test-source".to_string(); @@ -1033,32 +984,18 @@ pub(super) mod tests { shard_id: Some(ShardId::from(1)), from_position_exclusive: Some(Position::Beginning), }; - let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); let (shard_status_tx, shard_status_rx) = watch::channel(ShardStatus::default()); let (mut fetch_stream, _fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - state.clone(), + mrecordlog.clone(), shard_status_rx, 30, ); let queue_id = queue_id(&index_uid, &source_id, &ShardId::from(1)); - let mut state_guard = state.write().await; + let mut mrecordlog_guard = mrecordlog.write().await; - state_guard - .mrecordlog - .create_queue(&queue_id) - .await - .unwrap(); + mrecordlog_guard.create_queue(&queue_id).await.unwrap(); let records = [ Bytes::from_static(b"test-doc-foo"), @@ -1067,12 +1004,11 @@ pub(super) mod tests { ] .into_iter(); - state_guard - .mrecordlog + mrecordlog_guard .append_records(&queue_id, None, records) .await .unwrap(); - drop(state_guard); + drop(mrecordlog_guard); let shard_status = (ShardState::Open, Position::offset(2u64)); shard_status_tx.send(shard_status).unwrap(); diff --git a/quickwit/quickwit-ingest/src/ingest_v2/ingester.rs b/quickwit/quickwit-ingest/src/ingest_v2/ingester.rs index 2e4a1ad5072..49e8c06a878 100644 --- a/quickwit/quickwit-ingest/src/ingest_v2/ingester.rs +++ 
b/quickwit/quickwit-ingest/src/ingest_v2/ingester.rs @@ -21,14 +21,14 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt; use std::path::Path; -use std::sync::{Arc, Weak}; use std::time::Duration; use async_trait::async_trait; use bytesize::ByteSize; +use fnv::FnvHashMap; use futures::stream::FuturesUnordered; use futures::StreamExt; -use mrecordlog::error::{CreateQueueError, DeleteQueueError, TruncateError}; +use mrecordlog::error::CreateQueueError; use mrecordlog::MultiRecordLog; use quickwit_cluster::Cluster; use quickwit_common::pubsub::{EventBroker, EventSubscriber}; @@ -49,7 +49,7 @@ use quickwit_proto::ingest::ingester::{ }; use quickwit_proto::ingest::{CommitTypeV2, IngestV2Error, IngestV2Result, Shard, ShardState}; use quickwit_proto::types::{queue_id, NodeId, Position, QueueId}; -use tokio::sync::{watch, RwLock}; +use tokio::sync::watch; use tracing::{debug, error, info, warn}; use super::fetch::FetchStreamTask; @@ -61,13 +61,14 @@ use super::mrecordlog_utils::{ use super::rate_meter::RateMeter; use super::replication::{ ReplicationClient, ReplicationStreamTask, ReplicationStreamTaskHandle, ReplicationTask, - ReplicationTaskHandle, SYN_REPLICATION_STREAM_CAPACITY, + SYN_REPLICATION_STREAM_CAPACITY, }; +use super::state::{IngesterState, InnerIngesterState, WeakIngesterState}; use super::IngesterPool; use crate::ingest_v2::broadcast::BroadcastLocalShardsTask; use crate::ingest_v2::mrecordlog_utils::queue_position_range; use crate::metrics::INGEST_METRICS; -use crate::{estimate_size, with_lock_metrics, with_request_metrics, FollowerId, LeaderId}; +use crate::{estimate_size, with_lock_metrics, with_request_metrics, FollowerId}; /// Duration after which persist requests time out with /// [`quickwit_proto::ingest::IngestV2Error::Timeout`]. @@ -81,7 +82,7 @@ pub(super) const PERSIST_REQUEST_TIMEOUT: Duration = if cfg!(any(test, feature = pub struct Ingester { self_node_id: NodeId, ingester_pool: IngesterPool, - state: Arc>, + state: IngesterState, disk_capacity: ByteSize, memory_capacity: ByteSize, rate_limiter_settings: RateLimiterSettings, @@ -97,18 +98,6 @@ impl fmt::Debug for Ingester { } } -pub(super) struct IngesterState { - pub mrecordlog: MultiRecordLog, - pub shards: HashMap, - pub rate_trackers: HashMap, - // Replication stream opened with followers. - pub replication_streams: HashMap, - // Replication tasks running for each replication stream opened with leaders. - pub replication_tasks: HashMap, - pub status: IngesterStatus, - pub observation_tx: watch::Sender>, -} - impl Ingester { pub async fn try_new( cluster: Cluster, @@ -142,19 +131,11 @@ impl Ingester { }; let (observation_tx, observation_rx) = watch::channel(Ok(observe_message)); - let inner = IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - }; + let state = IngesterState::new(mrecordlog, observation_tx); let ingester = Self { self_node_id, ingester_pool, - state: Arc::new(RwLock::new(inner)), + state, disk_capacity, memory_capacity, rate_limiter_settings, @@ -168,14 +149,14 @@ impl Ingester { ); ingester.init().await?; - let weak_state = Arc::downgrade(&ingester.state); + let weak_state = ingester.state.weak(); BroadcastLocalShardsTask::spawn(cluster, weak_state); Ok(ingester) } /// Checks whether the ingester is fully decommissioned and updates its status accordingly. 
- fn check_decommissioning_status(&self, state: &mut IngesterState) { + fn check_decommissioning_status(&self, state: &mut InnerIngesterState) { if state.status != IngesterStatus::Decommissioning { return; } @@ -197,7 +178,7 @@ impl Ingester { /// the write-ahead log. Empty queues are deleted, while non-empty queues are recovered. /// However, the corresponding shards are closed and become read-only. async fn init(&self) -> IngestV2Result<()> { - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_fully().await; let queue_ids: Vec = state_guard .mrecordlog @@ -262,7 +243,8 @@ impl Ingester { /// - initialize the replica shard. async fn init_primary_shard( &self, - state: &mut IngesterState, + state: &mut InnerIngesterState, + mrecordlog: &mut MultiRecordLog, shard: Shard, ) -> IngestV2Result<()> { let queue_id = shard.queue_id(); @@ -275,7 +257,7 @@ impl Ingester { let Entry::Vacant(entry) = state.shards.entry(queue_id.clone()) else { return Ok(()); }; - match state.mrecordlog.create_queue(&queue_id).await { + match mrecordlog.create_queue(&queue_id).await { Ok(_) => {} Err(CreateQueueError::AlreadyExists) => panic!("queue should not exist"), Err(CreateQueueError::IoError(io_error)) => { @@ -327,7 +309,7 @@ impl Ingester { async fn init_replication_stream( &self, - replication_streams: &mut HashMap, + replication_streams: &mut FnvHashMap, leader_id: NodeId, follower_id: NodeId, ) -> IngestV2Result { @@ -384,7 +366,7 @@ impl Ingester { } pub fn subscribe(&self, event_broker: &EventBroker) { - let weak_ingester_state = WeakIngesterState(Arc::downgrade(&self.state)); + let weak_ingester_state = self.state.weak(); event_broker .subscribe::(weak_ingester_state) @@ -416,7 +398,7 @@ impl Ingester { let force_commit = commit_type == CommitTypeV2::Force; let leader_id: NodeId = persist_request.leader_id.into(); - let mut state_guard = with_lock_metrics!(self.state.write().await, "persist", "write"); + let mut state_guard = with_lock_metrics!(self.state.lock_fully().await, "persist", "write"); if state_guard.status != IngesterStatus::Ready { persist_failures.reserve_exact(persist_request.subrequests.len()); @@ -735,7 +717,7 @@ impl Ingester { let leader_id: NodeId = open_replication_stream_request.leader_id.into(); let follower_id: NodeId = open_replication_stream_request.follower_id.into(); - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_partially().await; if state_guard.status != IngesterStatus::Ready { return Err(IngestV2Error::Internal("node decommissioned".to_string())); @@ -776,7 +758,7 @@ impl Ingester { let queue_id = open_fetch_stream_request.queue_id(); let shard_status_rx = self .state - .read() + .lock_partially() .await .shards .get(&queue_id) @@ -785,9 +767,10 @@ impl Ingester { })? 
.shard_status_rx .clone(); + let mrecordlog = self.state.mrecordlog(); let (service_stream, _fetch_task_handle) = FetchStreamTask::spawn( open_fetch_stream_request, - self.state.clone(), + mrecordlog, shard_status_rx, FetchStreamTask::DEFAULT_BATCH_NUM_BYTES, ); @@ -806,14 +789,16 @@ impl Ingester { &mut self, init_shards_request: InitShardsRequest, ) -> IngestV2Result { - let mut state_guard = with_lock_metrics!(self.state.write().await, "init_shards", "write"); + let mut state_guard = + with_lock_metrics!(self.state.lock_fully().await, "init_shards", "write"); if state_guard.status != IngesterStatus::Ready { return Err(IngestV2Error::Internal("node decommissioned".to_string())); } for shard in init_shards_request.shards { - self.init_primary_shard(&mut state_guard, shard).await?; + self.init_primary_shard(&mut state_guard.inner, &mut state_guard.mrecordlog, shard) + .await?; } Ok(InitShardsResponse {}) } @@ -829,7 +814,7 @@ impl Ingester { ))); } let mut state_guard = - with_lock_metrics!(self.state.write().await, "truncate_shards", "write"); + with_lock_metrics!(self.state.lock_fully().await, "truncate_shards", "write"); for subrequest in truncate_shards_request.subrequests { let queue_id = subrequest.queue_id(); @@ -862,7 +847,8 @@ impl Ingester { &mut self, close_shards_request: CloseShardsRequest, ) -> IngestV2Result { - let mut state_guard = with_lock_metrics!(self.state.write().await, "close_shards", "write"); + let mut state_guard = + with_lock_metrics!(self.state.lock_partially().await, "close_shards", "write"); for shard_ids in close_shards_request.shards { for queue_id in shard_ids.queue_ids() { @@ -876,7 +862,7 @@ impl Ingester { } async fn ping_inner(&mut self, ping_request: PingRequest) -> IngestV2Result { - let state_guard = self.state.read().await; + let state_guard = self.state.lock_partially().await; if state_guard.status != IngesterStatus::Ready { return Err(IngestV2Error::Internal("node decommissioned".to_string())); @@ -910,7 +896,7 @@ impl Ingester { _decommission_request: DecommissionRequest, ) -> IngestV2Result { info!("decommissioning ingester"); - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_partially().await; for shard in state_guard.shards.values_mut() { shard.shard_state = ShardState::Closed; @@ -1008,7 +994,7 @@ impl IngesterService for Ingester { }) }) .collect(); - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_fully().await; let remove_queue_ids: HashSet = state_guard .shards .keys() @@ -1069,70 +1055,13 @@ impl IngesterService for Ingester { } } -impl IngesterState { - /// Truncates the shard identified by `queue_id` up to `truncate_up_to_position_inclusive` only - /// if the current truncation position of the shard is smaller. - async fn truncate_shard( - &mut self, - queue_id: &QueueId, - truncate_up_to_position_inclusive: &Position, - ) { - // TODO: Replace with if-let-chains when stabilized. 
- let Some(truncate_up_to_offset_inclusive) = truncate_up_to_position_inclusive.as_u64() - else { - return; - }; - let Some(shard) = self.shards.get_mut(queue_id) else { - return; - }; - if shard.truncation_position_inclusive >= *truncate_up_to_position_inclusive { - return; - } - match self - .mrecordlog - .truncate(queue_id, truncate_up_to_offset_inclusive) - .await - { - Ok(_) => { - shard.truncation_position_inclusive = truncate_up_to_position_inclusive.clone(); - } - Err(TruncateError::MissingQueue(_)) => { - error!("failed to truncate shard `{queue_id}`: WAL queue not found"); - self.shards.remove(queue_id); - self.rate_trackers.remove(queue_id); - info!("deleted dangling shard `{queue_id}`"); - } - Err(TruncateError::IoError(io_error)) => { - error!("failed to truncate shard `{queue_id}`: {io_error}"); - } - }; - } - - /// Deletes the shard identified by `queue_id` from the ingester state. It removes the - /// mrecordlog queue first and then removes the associated in-memory shard and rate trackers. - async fn delete_shard(&mut self, queue_id: &QueueId) { - match self.mrecordlog.delete_queue(queue_id).await { - Ok(_) | Err(DeleteQueueError::MissingQueue(_)) => { - self.shards.remove(queue_id); - self.rate_trackers.remove(queue_id); - info!("deleted shard `{queue_id}`"); - } - Err(DeleteQueueError::IoError(io_error)) => { - error!("failed to delete shard `{queue_id}`: {io_error}"); - } - }; - } -} - -struct WeakIngesterState(Weak>); - #[async_trait] impl EventSubscriber for WeakIngesterState { async fn handle_event(&mut self, shard_positions_update: ShardPositionsUpdate) { - let Some(state) = self.0.upgrade() else { + let Some(state) = self.upgrade() else { return; }; - let mut state_guard = with_lock_metrics!(state.write().await, "gc_shards", "write"); + let mut state_guard = with_lock_metrics!(state.lock_fully().await, "gc_shards", "write"); let index_uid = shard_positions_update.source_uid.index_uid; let source_id = shard_positions_update.source_uid.source_id; @@ -1319,7 +1248,7 @@ mod tests { #[tokio::test] async fn test_ingester_init() { let (_ingester_ctx, ingester) = IngesterForTest::default().build().await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); state_guard @@ -1378,7 +1307,7 @@ mod tests { ingester.init().await.unwrap(); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); let solo_shard_02 = state_guard.shards.get(&queue_id_02).unwrap(); @@ -1398,7 +1327,7 @@ mod tests { async fn test_ingester_broadcasts_local_shards() { let (ingester_ctx, ingester) = IngesterForTest::default().build().await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); let shard = @@ -1429,7 +1358,7 @@ mod tests { assert_eq!(shard_info.shard_state, ShardState::Open); assert_eq!(shard_info.ingestion_rate, 0); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; state_guard .shards .get_mut(&queue_id_01) @@ -1447,7 +1376,7 @@ mod tests { let shard_info = shard_infos.iter().next().unwrap(); assert_eq!(shard_info.shard_state, ShardState::Closed); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; 
state_guard.shards.remove(&queue_id_01).unwrap(); drop(state_guard); @@ -1528,7 +1457,7 @@ mod tests { Some(Position::offset(2u64)) ); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 2); let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -1624,7 +1553,7 @@ mod tests { let (_ingester_ctx, mut ingester) = IngesterForTest::default().build().await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let queue_id = queue_id("test-index:0", "test-source", &ShardId::from(1)); let solo_shard = IngesterShard::new_solo(ShardState::Open, Position::Beginning, Position::Beginning); @@ -1667,7 +1596,7 @@ mod tests { assert_eq!(persist_failure.shard_id(), ShardId::from(1)); assert_eq!(persist_failure.reason(), PersistFailureReason::ShardClosed,); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; let shard = state_guard.shards.get(&queue_id).unwrap(); shard.assert_is_closed(); @@ -1678,7 +1607,7 @@ mod tests { async fn test_ingester_persist_deletes_dangling_shard() { let (_ingester_ctx, mut ingester) = IngesterForTest::default().build().await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let queue_id = queue_id("test-index:0", "test-source", &ShardId::from(1)); let solo_shard = IngesterShard::new_solo(ShardState::Open, Position::Beginning, Position::Beginning); @@ -1718,7 +1647,7 @@ mod tests { PersistFailureReason::ShardNotFound ); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 0); assert_eq!(state_guard.rate_trackers.len(), 0); } @@ -1812,7 +1741,7 @@ mod tests { Some(Position::offset(2u64)) ); - let leader_state_guard = leader.state.read().await; + let leader_state_guard = leader.state.lock_fully().await; assert_eq!(leader_state_guard.shards.len(), 2); let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -1843,7 +1772,7 @@ mod tests { ], ); - let follower_state_guard = follower.state.read().await; + let follower_state_guard = follower.state.lock_fully().await; assert_eq!(follower_state_guard.shards.len(), 2); let replica_shard_01 = follower_state_guard.shards.get(&queue_id_01).unwrap(); @@ -1995,7 +1924,7 @@ mod tests { Some(Position::offset(1u64)) ); - let leader_state_guard = leader.state.read().await; + let leader_state_guard = leader.state.lock_fully().await; assert_eq!(leader_state_guard.shards.len(), 2); let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -2022,7 +1951,7 @@ mod tests { &[(0, "\0\0test-doc-110"), (1, "\0\0test-doc-111")], ); - let follower_state_guard = follower.state.read().await; + let follower_state_guard = follower.state.lock_fully().await; assert_eq!(follower_state_guard.shards.len(), 2); let replica_shard_01 = follower_state_guard.shards.get(&queue_id_01).unwrap(); @@ -2056,7 +1985,7 @@ mod tests { IngesterShard::new_solo(ShardState::Closed, Position::Beginning, Position::Beginning); ingester .state - .write() + .lock_fully() .await .shards .insert(queue_id_01.clone(), solo_shard); @@ -2084,7 +2013,7 @@ mod tests { assert_eq!(persist_failure.shard_id(), ShardId::from(1)); assert_eq!(persist_failure.reason(), PersistFailureReason::ShardClosed); - let state_guard = ingester.state.read().await; + let state_guard = 
ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); let solo_shard_01 = state_guard.shards.get(&queue_id_01).unwrap(); @@ -2104,7 +2033,7 @@ mod tests { .build() .await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let primary_shard = Shard { index_uid: "test-index:0".to_string(), @@ -2115,7 +2044,11 @@ mod tests { ..Default::default() }; ingester - .init_primary_shard(&mut state_guard, primary_shard) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + primary_shard, + ) .await .unwrap(); @@ -2144,7 +2077,7 @@ mod tests { assert_eq!(persist_failure.shard_id(), ShardId::from(1)); assert_eq!(persist_failure.reason(), PersistFailureReason::RateLimited); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -2166,7 +2099,7 @@ mod tests { .build() .await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let primary_shard = Shard { index_uid: "test-index:0".to_string(), @@ -2177,7 +2110,11 @@ mod tests { ..Default::default() }; ingester - .init_primary_shard(&mut state_guard, primary_shard) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + primary_shard, + ) .await .unwrap(); @@ -2209,7 +2146,7 @@ mod tests { PersistFailureReason::ResourceExhausted ); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -2253,7 +2190,7 @@ mod tests { .into_open_response() .unwrap(); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert!(state_guard.replication_tasks.contains_key("test-leader")); } @@ -2285,10 +2222,10 @@ mod tests { }; let queue_id = queue_id("test-index:0", "test-source", &ShardId::from(1)); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester - .init_primary_shard(&mut state_guard, shard) + .init_primary_shard(&mut state_guard.inner, &mut state_guard.mrecordlog, shard) .await .unwrap(); @@ -2330,7 +2267,7 @@ mod tests { ); assert_eq!(mrecord_batch.mrecord_lengths, [14]); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let records = [MRecord::new_doc("test-doc-bar").encode()].into_iter(); @@ -2390,14 +2327,22 @@ mod tests { }; let queue_id_02 = queue_id("test-index:0", "test-source", &ShardId::from(2)); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester - .init_primary_shard(&mut state_guard, shard_01) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_01, + ) .await .unwrap(); ingester - .init_primary_shard(&mut state_guard, shard_02) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_02, + ) .await .unwrap(); @@ -2457,7 +2402,7 @@ mod tests { .await .unwrap(); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); assert!(state_guard.shards.contains_key(&queue_id_01)); @@ -2476,7 +2421,7 @@ mod tests { let queue_id = 
queue_id("test-index:0", "test-source", &ShardId::from(1)); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; let solo_shard = IngesterShard::new_solo(ShardState::Open, Position::Beginning, Position::Beginning); state_guard.shards.insert(queue_id.clone(), solo_shard); @@ -2503,7 +2448,7 @@ mod tests { .await .unwrap(); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 0); assert_eq!(state_guard.rate_trackers.len(), 0); } @@ -2533,20 +2478,28 @@ mod tests { shard_17.shard_id(), ); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester - .init_primary_shard(&mut state_guard, shard_17) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_17, + ) .await .unwrap(); ingester - .init_primary_shard(&mut state_guard, shard_18) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_18, + ) .await .unwrap(); drop(state_guard); { - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 2); } @@ -2560,7 +2513,7 @@ mod tests { ingester.retain_shards(retain_shard_request).await.unwrap(); { - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); assert!(state_guard.shards.contains_key(&queue_id_17)); } @@ -2580,9 +2533,9 @@ mod tests { }; let queue_id = queue_id("test-index:0", "test-source", &ShardId::from(1)); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester - .init_primary_shard(&mut state_guard, shard) + .init_primary_shard(&mut state_guard.inner, &mut state_guard.mrecordlog, shard) .await .unwrap(); drop(state_guard); @@ -2617,7 +2570,7 @@ mod tests { .await .unwrap(); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_partially().await; let shard = state_guard.shards.get(&queue_id).unwrap(); shard.assert_is_closed(); @@ -2643,7 +2596,7 @@ mod tests { assert_eq!(observation.node_id, ingester_ctx.node_id); assert_eq!(observation.status(), IngesterStatus::Ready); - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; let observe_message = ObservationMessage { node_id: ingester_ctx.node_id.to_string(), status: IngesterStatus::Decommissioning as i32, @@ -2667,7 +2620,7 @@ mod tests { #[tokio::test] async fn test_check_decommissioning_status() { let (_ingester_ctx, ingester) = IngesterForTest::default().build().await; - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester.check_decommissioning_status(&mut state_guard); assert_eq!(state_guard.status, IngesterStatus::Ready); @@ -2730,14 +2683,22 @@ mod tests { }; let queue_id_02 = queue_id("test-index:0", "test-source", &ShardId::from(2)); - let mut state_guard = ingester.state.write().await; + let mut state_guard = ingester.state.lock_fully().await; ingester - .init_primary_shard(&mut state_guard, shard_01) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_01, + ) .await .unwrap(); ingester - .init_primary_shard(&mut state_guard, shard_02) + .init_primary_shard( + &mut state_guard.inner, + &mut state_guard.mrecordlog, + shard_02, + ) .await 
.unwrap(); @@ -2782,7 +2743,7 @@ mod tests { // Yield so that the event is processed. yield_now().await; - let state_guard = ingester.state.read().await; + let state_guard = ingester.state.lock_fully().await; assert_eq!(state_guard.shards.len(), 1); assert!(state_guard.shards.contains_key(&queue_id_01)); diff --git a/quickwit/quickwit-ingest/src/ingest_v2/mod.rs b/quickwit/quickwit-ingest/src/ingest_v2/mod.rs index 4fe92ba156f..ac7a82afb0a 100644 --- a/quickwit/quickwit-ingest/src/ingest_v2/mod.rs +++ b/quickwit/quickwit-ingest/src/ingest_v2/mod.rs @@ -28,6 +28,7 @@ mod rate_meter; mod replication; mod router; mod routing_table; +mod state; #[cfg(test)] mod test_utils; mod workbench; diff --git a/quickwit/quickwit-ingest/src/ingest_v2/replication.rs b/quickwit/quickwit-ingest/src/ingest_v2/replication.rs index a903c0ac33b..112c6f6b474 100644 --- a/quickwit/quickwit-ingest/src/ingest_v2/replication.rs +++ b/quickwit/quickwit-ingest/src/ingest_v2/replication.rs @@ -19,7 +19,6 @@ use std::iter::once; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; use std::time::Duration; use bytesize::ByteSize; @@ -34,14 +33,14 @@ use quickwit_proto::ingest::ingester::{ use quickwit_proto::ingest::{CommitTypeV2, IngestV2Error, IngestV2Result, Shard, ShardState}; use quickwit_proto::types::{NodeId, Position}; use tokio::sync::mpsc::error::TryRecvError; -use tokio::sync::{mpsc, oneshot, RwLock}; +use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tracing::{error, warn}; -use super::ingester::IngesterState; use super::models::IngesterShard; use super::mrecord::MRecord; use super::mrecordlog_utils::check_enough_capacity; +use super::state::IngesterState; use crate::ingest_v2::metrics::INGEST_V2_METRICS; use crate::metrics::INGEST_METRICS; use crate::{estimate_size, with_request_metrics}; @@ -409,7 +408,7 @@ impl ReplicationClient { pub(super) struct ReplicationTask { leader_id: NodeId, follower_id: NodeId, - state: Arc>, + state: IngesterState, syn_replication_stream: ServiceStream, ack_replication_stream_tx: mpsc::UnboundedSender>, current_replication_seqno: ReplicationSeqNo, @@ -421,7 +420,7 @@ impl ReplicationTask { pub fn spawn( leader_id: NodeId, follower_id: NodeId, - state: Arc>, + state: IngesterState, syn_replication_stream: ServiceStream, ack_replication_stream_tx: mpsc::UnboundedSender>, disk_capacity: ByteSize, @@ -463,7 +462,7 @@ impl ReplicationTask { }; let queue_id = replica_shard.queue_id(); - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_fully().await; state_guard .mrecordlog @@ -516,7 +515,7 @@ impl ReplicationTask { let mut replicate_successes = Vec::with_capacity(replicate_request.subrequests.len()); let mut replicate_failures = Vec::new(); - let mut state_guard = self.state.write().await; + let mut state_guard = self.state.lock_fully().await; if state_guard.status != IngesterStatus::Ready { replicate_failures.reserve_exact(replicate_request.subrequests.len()); @@ -738,7 +737,6 @@ impl Drop for ReplicationTaskHandle { #[cfg(test)] mod tests { - use std::collections::HashMap; use mrecordlog::MultiRecordLog; use quickwit_proto::ingest::ingester::{ @@ -1014,15 +1012,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - 
replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); + let state = IngesterState::new(mrecordlog, observation_tx); let (syn_replication_stream_tx, syn_replication_stream) = ServiceStream::new_bounded(SYN_REPLICATION_STREAM_CAPACITY); let (ack_replication_stream_tx, mut ack_replication_stream) = @@ -1110,7 +1100,7 @@ mod tests { let init_replica_response = into_init_replica_response(ack_replication_message); assert_eq!(init_replica_response.replication_seqno, 2); - let state_guard = state.read().await; + let state_guard = state.lock_fully().await; let queue_id_01 = queue_id("test-index:0", "test-source", &ShardId::from(1)); @@ -1216,7 +1206,7 @@ mod tests { Position::offset(1u64) ); - let state_guard = state.read().await; + let state_guard = state.lock_fully().await; state_guard .mrecordlog @@ -1273,7 +1263,7 @@ mod tests { Position::offset(1u64) ); - let state_guard = state.read().await; + let state_guard = state.lock_fully().await; state_guard.mrecordlog.assert_records_eq( &queue_id_01, @@ -1289,15 +1279,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); + let state = IngesterState::new(mrecordlog, observation_tx); let (syn_replication_stream_tx, syn_replication_stream) = ServiceStream::new_bounded(SYN_REPLICATION_STREAM_CAPACITY); let (ack_replication_stream_tx, mut ack_replication_stream) = @@ -1324,7 +1306,7 @@ mod tests { Position::Beginning, ); state - .write() + .lock_fully() .await .shards .insert(queue_id_01.clone(), replica_shard); @@ -1374,15 +1356,7 @@ mod tests { let tempdir = tempfile::tempdir().unwrap(); let mrecordlog = MultiRecordLog::open(tempdir.path()).await.unwrap(); let (observation_tx, _observation_rx) = watch::channel(Ok(ObservationMessage::default())); - let state = Arc::new(RwLock::new(IngesterState { - mrecordlog, - shards: HashMap::new(), - rate_trackers: HashMap::new(), - replication_streams: HashMap::new(), - replication_tasks: HashMap::new(), - status: IngesterStatus::Ready, - observation_tx, - })); + let state = IngesterState::new(mrecordlog, observation_tx); let (syn_replication_stream_tx, syn_replication_stream) = ServiceStream::new_bounded(SYN_REPLICATION_STREAM_CAPACITY); let (ack_replication_stream_tx, mut ack_replication_stream) = @@ -1409,7 +1383,7 @@ mod tests { Position::Beginning, ); state - .write() + .lock_fully() .await .shards .insert(queue_id_01.clone(), replica_shard); diff --git a/quickwit/quickwit-ingest/src/ingest_v2/state.rs b/quickwit/quickwit-ingest/src/ingest_v2/state.rs new file mode 100644 index 00000000000..f1fb26fd849 --- /dev/null +++ b/quickwit/quickwit-ingest/src/ingest_v2/state.rs @@ -0,0 +1,212 @@ +// Copyright (C) 2024 Quickwit, Inc. +// +// Quickwit is offered under the AGPL v3.0 and as commercial software. +// For commercial licensing, contact us at hello@quickwit.io. 
+//
+// AGPL:
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+use std::ops::{Deref, DerefMut};
+use std::sync::{Arc, Weak};
+
+use fnv::FnvHashMap;
+use mrecordlog::error::{DeleteQueueError, TruncateError};
+use mrecordlog::MultiRecordLog;
+use quickwit_common::rate_limiter::RateLimiter;
+use quickwit_proto::ingest::ingester::{IngesterStatus, ObservationMessage};
+use quickwit_proto::ingest::IngestV2Result;
+use quickwit_proto::types::{Position, QueueId};
+use tokio::sync::{watch, Mutex, MutexGuard, RwLock, RwLockWriteGuard};
+use tracing::{error, info};
+
+use super::models::IngesterShard;
+use super::rate_meter::RateMeter;
+use super::replication::{ReplicationStreamTaskHandle, ReplicationTaskHandle};
+use crate::{FollowerId, LeaderId};
+
+/// Stores the state of the ingester and attempts to prevent deadlocks by exposing an API that
+/// guarantees that the internal data structures are always locked in the same order: first,
+/// `inner`, then `mrecordlog`.
+///
+/// `lock_partially` locks `inner` only, while `lock_fully` locks both `inner` and `mrecordlog`. Use
+/// the former when you only need to access the in-memory state of the ingester and the latter when
+/// you need to access both the in-memory state AND the WAL.
+#[derive(Clone)]
+pub(super) struct IngesterState {
+    // `inner` is a mutex because it's almost always accessed mutably.
+    inner: Arc<Mutex<InnerIngesterState>>,
+    mrecordlog: Arc<RwLock<MultiRecordLog>>,
+}
+
+pub(super) struct InnerIngesterState {
+    pub shards: FnvHashMap<QueueId, IngesterShard>,
+    pub rate_trackers: FnvHashMap<QueueId, (RateLimiter, RateMeter)>,
+    // Replication stream opened with followers.
+    pub replication_streams: FnvHashMap<FollowerId, ReplicationStreamTaskHandle>,
+    // Replication tasks running for each replication stream opened with leaders.
+    pub replication_tasks: FnvHashMap<LeaderId, ReplicationTaskHandle>,
+    pub status: IngesterStatus,
+    pub observation_tx: watch::Sender<IngestV2Result<ObservationMessage>>,
+}
+
+impl IngesterState {
+    pub fn new(
+        mrecordlog: MultiRecordLog,
+        observation_tx: watch::Sender<IngestV2Result<ObservationMessage>>,
+    ) -> Self {
+        let inner = InnerIngesterState {
+            shards: Default::default(),
+            rate_trackers: Default::default(),
+            replication_streams: Default::default(),
+            replication_tasks: Default::default(),
+            status: IngesterStatus::Ready,
+            observation_tx,
+        };
+        let inner = Arc::new(Mutex::new(inner));
+        let mrecordlog = Arc::new(RwLock::new(mrecordlog));
+        Self { inner, mrecordlog }
+    }
+
+    pub async fn lock_partially(&self) -> PartiallyLockedIngesterState<'_> {
+        PartiallyLockedIngesterState {
+            inner: self.inner.lock().await,
+        }
+    }
+
+    pub async fn lock_fully(&self) -> FullyLockedIngesterState<'_> {
+        FullyLockedIngesterState {
+            inner: self.inner.lock().await,
+            mrecordlog: self.mrecordlog.write().await,
+        }
+    }
+
+    // Leaks the mrecordlog lock for use in fetch tasks. It's safe to do so because fetch tasks
+    // never attempt to lock the inner state.
+    pub fn mrecordlog(&self) -> Arc<RwLock<MultiRecordLog>> {
+        self.mrecordlog.clone()
+    }
+
+    pub fn weak(&self) -> WeakIngesterState {
+        WeakIngesterState {
+            inner: Arc::downgrade(&self.inner),
+            mrecordlog: Arc::downgrade(&self.mrecordlog),
+        }
+    }
+}
+
+pub(super) struct PartiallyLockedIngesterState<'a> {
+    pub inner: MutexGuard<'a, InnerIngesterState>,
+}
+
+impl Deref for PartiallyLockedIngesterState<'_> {
+    type Target = InnerIngesterState;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl DerefMut for PartiallyLockedIngesterState<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.inner
+    }
+}
+
+pub(super) struct FullyLockedIngesterState<'a> {
+    pub inner: MutexGuard<'a, InnerIngesterState>,
+    pub mrecordlog: RwLockWriteGuard<'a, MultiRecordLog>,
+}
+
+impl Deref for FullyLockedIngesterState<'_> {
+    type Target = InnerIngesterState;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
+impl DerefMut for FullyLockedIngesterState<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.inner
+    }
+}
+
+impl FullyLockedIngesterState<'_> {
+    /// Truncates the shard identified by `queue_id` up to `truncate_up_to_position_inclusive` only
+    /// if the current truncation position of the shard is smaller.
+    pub async fn truncate_shard(
+        &mut self,
+        queue_id: &QueueId,
+        truncate_up_to_position_inclusive: &Position,
+    ) {
+        // TODO: Replace with if-let-chains when stabilized.
+        let Some(truncate_up_to_offset_inclusive) = truncate_up_to_position_inclusive.as_u64()
+        else {
+            return;
+        };
+        let Some(shard) = self.inner.shards.get_mut(queue_id) else {
+            return;
+        };
+        if shard.truncation_position_inclusive >= *truncate_up_to_position_inclusive {
+            return;
+        }
+        match self
+            .mrecordlog
+            .truncate(queue_id, truncate_up_to_offset_inclusive)
+            .await
+        {
+            Ok(_) => {
+                shard.truncation_position_inclusive = truncate_up_to_position_inclusive.clone();
+            }
+            Err(TruncateError::MissingQueue(_)) => {
+                error!("failed to truncate shard `{queue_id}`: WAL queue not found");
+                self.shards.remove(queue_id);
+                self.rate_trackers.remove(queue_id);
+                info!("deleted dangling shard `{queue_id}`");
+            }
+            Err(TruncateError::IoError(io_error)) => {
+                error!("failed to truncate shard `{queue_id}`: {io_error}");
+            }
+        };
+    }
+
+    /// Deletes the shard identified by `queue_id` from the ingester state. It removes the
+    /// mrecordlog queue first and then removes the associated in-memory shard and rate trackers.
+    pub async fn delete_shard(&mut self, queue_id: &QueueId) {
+        match self.mrecordlog.delete_queue(queue_id).await {
+            Ok(_) | Err(DeleteQueueError::MissingQueue(_)) => {
+                self.shards.remove(queue_id);
+                self.rate_trackers.remove(queue_id);
+                info!("deleted shard `{queue_id}`");
+            }
+            Err(DeleteQueueError::IoError(io_error)) => {
+                error!("failed to delete shard `{queue_id}`: {io_error}");
+            }
+        };
+    }
+}
+
+#[derive(Clone)]
+pub(super) struct WeakIngesterState {
+    inner: Weak<Mutex<InnerIngesterState>>,
+    mrecordlog: Weak<RwLock<MultiRecordLog>>,
+}
+
+impl WeakIngesterState {
+    pub fn upgrade(&self) -> Option<IngesterState> {
+        let inner = self.inner.upgrade()?;
+        let mrecordlog = self.mrecordlog.upgrade()?;
+        Some(IngesterState { inner, mrecordlog })
+    }
+}
diff --git a/quickwit/quickwit-storage/src/storage_resolver.rs b/quickwit/quickwit-storage/src/storage_resolver.rs
index 108bdd3a5b6..781617016d2 100644
--- a/quickwit/quickwit-storage/src/storage_resolver.rs
+++ b/quickwit/quickwit-storage/src/storage_resolver.rs
@@ -109,7 +109,7 @@ impl StorageResolver {
             builder = builder.register(UnsupportedStorage::new(
                 StorageBackend::Azure,
-                "Quickwit was compiled without the `azure` feature.",
+                "Quickwit was compiled without the `azure` feature",
             ))
         }
         #[cfg(feature = "gcs")]
@@ -124,12 +124,12 @@ impl StorageResolver {
             builder = builder.register(UnsupportedStorage::new(
                 StorageBackend::Google,
-                "Quickwit was compiled without the `gcs` feature.",
+                "Quickwit was compiled without the `gcs` feature",
             ))
         }
         builder
            .build()
-           .expect("Storage factory and config backends should match.")
+           .expect("storage factory and config backends should match")
    }

    /// Returns a [`StorageResolver`] for testing purposes. Unlike
@@ -140,7 +140,7 @@ impl StorageResolver {
            .register(RamStorageFactory::default())
            .register(LocalFileStorageFactory)
            .build()
-           .expect("Storage factory and config backends should match.")
+           .expect("storage factory and config backends should match")
    }
 }
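
Illustrative sketch, not part of the patch above: the `IngesterState` doc comment fixes the lock order as `inner` first, then `mrecordlog`, so callers inside `ingest_v2` reach for `lock_partially` when only the in-memory state is needed and for `lock_fully` when the WAL must be touched as well. The helper name `close_and_delete_shard` and its placement in a sibling module of `state.rs` are assumptions; the calls themselves mirror the API introduced above.

use quickwit_proto::ingest::ShardState;
use quickwit_proto::types::QueueId;

use super::state::IngesterState;

// Hypothetical helper: closes a shard in memory, then deletes it from memory and the WAL.
async fn close_and_delete_shard(state: &IngesterState, queue_id: &QueueId) {
    // Only the in-memory state is touched here, so `lock_partially` (the `inner` mutex alone)
    // is enough. The guard is dropped at the end of the block, before re-locking below.
    {
        let mut state_guard = state.lock_partially().await;
        if let Some(shard) = state_guard.shards.get_mut(queue_id) {
            shard.shard_state = ShardState::Closed;
        }
    }
    // The WAL is needed too, so `lock_fully` is used: it acquires `inner` first and `mrecordlog`
    // second, the single order every caller must follow to avoid deadlocks.
    let mut state_guard = state.lock_fully().await;
    state_guard.delete_shard(queue_id).await;
}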