diff --git a/connection.c b/connection.c index 5ca98caf0..55f5d5a12 100644 --- a/connection.c +++ b/connection.c @@ -22,7 +22,7 @@ static DEFINE_MUTEX(init_lock); static struct ksmbd_conn_ops default_conn_ops; -LIST_HEAD(conn_list); +LIST_HEAD(global_conn_list); DECLARE_RWSEM(conn_list_lock); /** @@ -36,7 +36,7 @@ DECLARE_RWSEM(conn_list_lock); void ksmbd_conn_free(struct ksmbd_conn *conn) { down_write(&conn_list_lock); - list_del(&conn->conns_list); + list_del(&conn->conn_entry); up_write(&conn_list_lock); xa_destroy(&conn->sessions); @@ -80,7 +80,8 @@ struct ksmbd_conn *ksmbd_conn_alloc(void) init_waitqueue_head(&conn->req_running_q); init_waitqueue_head(&conn->r_count_q); - INIT_LIST_HEAD(&conn->conns_list); + INIT_LIST_HEAD(&conn->conn_entry); + INIT_LIST_HEAD(&conn->bind_conn_entry); INIT_LIST_HEAD(&conn->requests); INIT_LIST_HEAD(&conn->async_requests); spin_lock_init(&conn->request_lock); @@ -94,7 +95,7 @@ struct ksmbd_conn *ksmbd_conn_alloc(void) init_rwsem(&conn->session_lock); down_write(&conn_list_lock); - list_add(&conn->conns_list, &conn_list); + list_add(&conn->conn_entry, &global_conn_list); up_write(&conn_list_lock); return conn; } @@ -105,7 +106,7 @@ bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c) bool ret = false; down_read(&conn_list_lock); - list_for_each_entry(t, &conn_list, conns_list) { + list_for_each_entry(t, &global_conn_list, conn_entry) { if (memcmp(t->ClientGUID, c->ClientGUID, SMB2_CLIENT_GUID_SIZE)) continue; @@ -171,37 +172,63 @@ void ksmbd_conn_unlock(struct ksmbd_conn *conn) mutex_unlock(&conn->srv_mutex); } -void ksmbd_all_conn_set_status(u64 sess_id, u32 status) +struct ksmbd_conn *ksmbd_find_master_conn(struct ksmbd_conn *bind_conn, + u64 sess_id) { - struct ksmbd_conn *conn; + struct ksmbd_conn *master_conn; down_read(&conn_list_lock); - list_for_each_entry(conn, &conn_list, conns_list) { - if (conn->binding || xa_load(&conn->sessions, sess_id)) - WRITE_ONCE(conn->status, status); + list_for_each_entry(master_conn, &global_conn_list, conn_entry) { + if (bind_conn == master_conn) + continue; + if (xa_load(&master_conn->sessions, sess_id)) { + up_read(&conn_list_lock); + return master_conn; + } } up_read(&conn_list_lock); + + return NULL; +} + +void ksmbd_all_conn_set_status(struct ksmbd_conn *conn, u64 sess_id, u32 status) +{ + struct ksmbd_conn *master_conn = + conn->master_conn ? conn->master_conn : conn; + struct ksmbd_conn *bind_conn; + + ksmbd_conn_lock(master_conn); + list_for_each_entry(bind_conn, &master_conn->bind_conn_list, + bind_conn_entry) + WRITE_ONCE(bind_conn->status, status); + WRITE_ONCE(master_conn->status, status); + ksmbd_conn_unlock(master_conn); } void ksmbd_conn_wait_idle(struct ksmbd_conn *conn, u64 sess_id) { + struct ksmbd_conn *master_conn = + conn->master_conn ? conn->master_conn : conn; struct ksmbd_conn *bind_conn; wait_event(conn->req_running_q, atomic_read(&conn->req_running) < 2); - down_read(&conn_list_lock); - list_for_each_entry(bind_conn, &conn_list, conns_list) { + ksmbd_conn_lock(master_conn); + list_for_each_entry(bind_conn, &master_conn->bind_conn_list, + bind_conn_entry) { if (bind_conn == conn) continue; - if ((bind_conn->binding || xa_load(&bind_conn->sessions, sess_id)) && - !ksmbd_conn_releasing(bind_conn) && - atomic_read(&bind_conn->req_running)) { + if (!ksmbd_conn_releasing(bind_conn) && + atomic_read(&bind_conn->req_running)) wait_event(bind_conn->req_running_q, - atomic_read(&bind_conn->req_running) == 0); - } + atomic_read(&bind_conn->req_running) == 0); } - up_read(&conn_list_lock); + + if (conn->master_conn) + wait_event(master_conn->req_running_q, + atomic_read(&master_conn->req_running) == 0); + ksmbd_conn_unlock(master_conn); } int ksmbd_conn_write(struct ksmbd_work *work) @@ -482,7 +509,7 @@ static void stop_sessions(void) again: down_read(&conn_list_lock); - list_for_each_entry(conn, &conn_list, conns_list) { + list_for_each_entry(conn, &global_conn_list, conn_entry) { struct task_struct *task; t = conn->transport; @@ -499,7 +526,7 @@ static void stop_sessions(void) } up_read(&conn_list_lock); - if (!list_empty(&conn_list)) { + if (!list_empty(&global_conn_list)) { schedule_timeout_interruptible(HZ / 10); /* 100ms */ goto again; } diff --git a/connection.h b/connection.h index 3c005246a..178e0b61b 100644 --- a/connection.h +++ b/connection.h @@ -49,7 +49,7 @@ struct ksmbd_conn { struct ksmbd_transport *transport; struct nls_table *local_nls; struct unicode_map *um; - struct list_head conns_list; + struct list_head conn_entry; struct rw_semaphore session_lock; /* smb session 1 per user */ struct xarray sessions; @@ -105,6 +105,9 @@ struct ksmbd_conn { bool signing_negotiated; __le16 signing_algorithm; bool binding; + void *master_conn; + struct list_head bind_conn_entry; + struct list_head bind_conn_list; }; struct ksmbd_conn_ops { @@ -141,7 +144,7 @@ struct ksmbd_transport { #define KSMBD_TCP_SEND_TIMEOUT (5 * HZ) #define KSMBD_TCP_PEER_SOCKADDR(c) ((struct sockaddr *)&((c)->peer_addr)) -extern struct list_head conn_list; +extern struct list_head global_conn_list; extern struct rw_semaphore conn_list_lock; bool ksmbd_conn_alive(struct ksmbd_conn *conn); @@ -166,6 +169,9 @@ int ksmbd_conn_transport_init(void); void ksmbd_conn_transport_destroy(void); void ksmbd_conn_lock(struct ksmbd_conn *conn); void ksmbd_conn_unlock(struct ksmbd_conn *conn); +struct ksmbd_conn *ksmbd_find_master_conn(struct ksmbd_conn *bind_conn, + u64 sess_id); +void ksmbd_all_conn_set_status(struct ksmbd_conn *conn, u64 sess_id, u32 status); /* * WARNING @@ -227,6 +233,4 @@ static inline void ksmbd_conn_set_releasing(struct ksmbd_conn *conn) { WRITE_ONCE(conn->status, KSMBD_SESS_RELEASING); } - -void ksmbd_all_conn_set_status(u64 sess_id, u32 status); #endif /* __CONNECTION_H__ */ diff --git a/smb2pdu.c b/smb2pdu.c index f4bfbaf16..f9fb39ab8 100644 --- a/smb2pdu.c +++ b/smb2pdu.c @@ -1717,6 +1717,8 @@ int smb2_sess_setup(struct ksmbd_work *work) } else if (conn->dialect >= SMB30_PROT_ID && (server_conf.flags & KSMBD_GLOBAL_FLAG_SMB3_MULTICHANNEL) && req->Flags & SMB2_SESSION_REQ_FLAG_BINDING) { + struct ksmbd_conn *master_conn; + u64 sess_id = le64_to_cpu(req->hdr.SessionId); sess = ksmbd_session_lookup_slowpath(sess_id); @@ -1767,6 +1769,16 @@ int smb2_sess_setup(struct ksmbd_work *work) goto out_err; } + master_conn = ksmbd_find_master_conn(conn, sess_id); + if (!master_conn) { + rc = -ENOENT; + goto out_err; + } + + conn->master_conn = master_conn; + ksmbd_conn_lock(master_conn); + list_add(&conn->bind_conn_entry, &master_conn->bind_conn_list); + ksmbd_conn_unlock(master_conn); conn->binding = true; } else if ((conn->dialect < SMB30_PROT_ID || server_conf.flags & KSMBD_GLOBAL_FLAG_SMB3_MULTICHANNEL) && @@ -2192,7 +2204,7 @@ int smb2_session_logoff(struct ksmbd_work *work) return -ENOENT; } sess_id = le64_to_cpu(req->hdr.SessionId); - ksmbd_all_conn_set_status(sess_id, KSMBD_SESS_NEED_RECONNECT); + ksmbd_all_conn_set_status(conn, sess_id, KSMBD_SESS_NEED_RECONNECT); ksmbd_conn_unlock(conn); ksmbd_close_session_fds(work); @@ -2215,7 +2227,7 @@ int smb2_session_logoff(struct ksmbd_work *work) ksmbd_free_user(sess->user); sess->user = NULL; - ksmbd_all_conn_set_status(sess_id, KSMBD_SESS_NEED_NEGOTIATE); + ksmbd_all_conn_set_status(conn, sess_id, KSMBD_SESS_NEED_NEGOTIATE); rsp->StructureSize = cpu_to_le16(4); err = ksmbd_iov_pin_rsp(work, rsp, sizeof(struct smb2_logoff_rsp)); @@ -7549,7 +7561,7 @@ int smb2_lock(struct ksmbd_work *work) nolock = 1; /* check locks in connection list */ down_read(&conn_list_lock); - list_for_each_entry(conn, &conn_list, conns_list) { + list_for_each_entry(conn, &global_conn_list, conn_entry) { spin_lock(&conn->llist_lock); list_for_each_entry_safe(cmp_lock, tmp2, &conn->lock_list, clist) { if (file_inode(cmp_lock->fl->fl_file) !=