Skip to content

Commit

Permalink
ksmbd: add master connection
Browse files Browse the repository at this point in the history
Signed-off-by: Namjae Jeon <[email protected]>
  • Loading branch information
namjaejeon committed Oct 2, 2023
1 parent b1dd612 commit 7fe97f7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 27 deletions.
67 changes: 47 additions & 20 deletions connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ static DEFINE_MUTEX(init_lock);

static struct ksmbd_conn_ops default_conn_ops;

LIST_HEAD(conn_list);
LIST_HEAD(global_conn_list);
DECLARE_RWSEM(conn_list_lock);

/**
Expand All @@ -36,7 +36,7 @@ DECLARE_RWSEM(conn_list_lock);
void ksmbd_conn_free(struct ksmbd_conn *conn)
{
down_write(&conn_list_lock);
list_del(&conn->conns_list);
list_del(&conn->conn_entry);
up_write(&conn_list_lock);

xa_destroy(&conn->sessions);
Expand Down Expand Up @@ -80,7 +80,8 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)

init_waitqueue_head(&conn->req_running_q);
init_waitqueue_head(&conn->r_count_q);
INIT_LIST_HEAD(&conn->conns_list);
INIT_LIST_HEAD(&conn->conn_entry);
INIT_LIST_HEAD(&conn->bind_conn_entry);
INIT_LIST_HEAD(&conn->requests);
INIT_LIST_HEAD(&conn->async_requests);
spin_lock_init(&conn->request_lock);
Expand All @@ -94,7 +95,7 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)
init_rwsem(&conn->session_lock);

down_write(&conn_list_lock);
list_add(&conn->conns_list, &conn_list);
list_add(&conn->conn_entry, &global_conn_list);
up_write(&conn_list_lock);
return conn;
}
Expand All @@ -105,7 +106,7 @@ bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c)
bool ret = false;

down_read(&conn_list_lock);
list_for_each_entry(t, &conn_list, conns_list) {
list_for_each_entry(t, &global_conn_list, conn_entry) {
if (memcmp(t->ClientGUID, c->ClientGUID, SMB2_CLIENT_GUID_SIZE))
continue;

Expand Down Expand Up @@ -171,37 +172,63 @@ void ksmbd_conn_unlock(struct ksmbd_conn *conn)
mutex_unlock(&conn->srv_mutex);
}

void ksmbd_all_conn_set_status(u64 sess_id, u32 status)
struct ksmbd_conn *ksmbd_find_master_conn(struct ksmbd_conn *bind_conn,
u64 sess_id)
{
struct ksmbd_conn *conn;
struct ksmbd_conn *master_conn;

down_read(&conn_list_lock);
list_for_each_entry(conn, &conn_list, conns_list) {
if (conn->binding || xa_load(&conn->sessions, sess_id))
WRITE_ONCE(conn->status, status);
list_for_each_entry(master_conn, &global_conn_list, conn_entry) {
if (bind_conn == master_conn)
continue;
if (xa_load(&master_conn->sessions, sess_id)) {
up_read(&conn_list_lock);
return master_conn;
}
}
up_read(&conn_list_lock);

return NULL;
}

void ksmbd_all_conn_set_status(struct ksmbd_conn *conn, u64 sess_id, u32 status)
{
struct ksmbd_conn *master_conn =
conn->master_conn ? conn->master_conn : conn;
struct ksmbd_conn *bind_conn;

ksmbd_conn_lock(master_conn);
list_for_each_entry(bind_conn, &master_conn->bind_conn_list,
bind_conn_entry)
WRITE_ONCE(bind_conn->status, status);
WRITE_ONCE(master_conn->status, status);
ksmbd_conn_unlock(master_conn);
}

void ksmbd_conn_wait_idle(struct ksmbd_conn *conn, u64 sess_id)
{
struct ksmbd_conn *master_conn =
conn->master_conn ? conn->master_conn : conn;
struct ksmbd_conn *bind_conn;

wait_event(conn->req_running_q, atomic_read(&conn->req_running) < 2);

down_read(&conn_list_lock);
list_for_each_entry(bind_conn, &conn_list, conns_list) {
ksmbd_conn_lock(master_conn);
list_for_each_entry(bind_conn, &master_conn->bind_conn_list,
bind_conn_entry) {
if (bind_conn == conn)
continue;

if ((bind_conn->binding || xa_load(&bind_conn->sessions, sess_id)) &&
!ksmbd_conn_releasing(bind_conn) &&
atomic_read(&bind_conn->req_running)) {
if (!ksmbd_conn_releasing(bind_conn) &&
atomic_read(&bind_conn->req_running))
wait_event(bind_conn->req_running_q,
atomic_read(&bind_conn->req_running) == 0);
}
atomic_read(&bind_conn->req_running) == 0);
}
up_read(&conn_list_lock);

if (conn->master_conn)
wait_event(master_conn->req_running_q,
atomic_read(&master_conn->req_running) == 0);
ksmbd_conn_unlock(master_conn);
}

int ksmbd_conn_write(struct ksmbd_work *work)
Expand Down Expand Up @@ -482,7 +509,7 @@ static void stop_sessions(void)

again:
down_read(&conn_list_lock);
list_for_each_entry(conn, &conn_list, conns_list) {
list_for_each_entry(conn, &global_conn_list, conn_entry) {
struct task_struct *task;

t = conn->transport;
Expand All @@ -499,7 +526,7 @@ static void stop_sessions(void)
}
up_read(&conn_list_lock);

if (!list_empty(&conn_list)) {
if (!list_empty(&global_conn_list)) {
schedule_timeout_interruptible(HZ / 10); /* 100ms */
goto again;
}
Expand Down
12 changes: 8 additions & 4 deletions connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct ksmbd_conn {
struct ksmbd_transport *transport;
struct nls_table *local_nls;
struct unicode_map *um;
struct list_head conns_list;
struct list_head conn_entry;
struct rw_semaphore session_lock;
/* smb session 1 per user */
struct xarray sessions;
Expand Down Expand Up @@ -105,6 +105,9 @@ struct ksmbd_conn {
bool signing_negotiated;
__le16 signing_algorithm;
bool binding;
void *master_conn;
struct list_head bind_conn_entry;
struct list_head bind_conn_list;
};

struct ksmbd_conn_ops {
Expand Down Expand Up @@ -141,7 +144,7 @@ struct ksmbd_transport {
#define KSMBD_TCP_SEND_TIMEOUT (5 * HZ)
#define KSMBD_TCP_PEER_SOCKADDR(c) ((struct sockaddr *)&((c)->peer_addr))

extern struct list_head conn_list;
extern struct list_head global_conn_list;
extern struct rw_semaphore conn_list_lock;

bool ksmbd_conn_alive(struct ksmbd_conn *conn);
Expand All @@ -166,6 +169,9 @@ int ksmbd_conn_transport_init(void);
void ksmbd_conn_transport_destroy(void);
void ksmbd_conn_lock(struct ksmbd_conn *conn);
void ksmbd_conn_unlock(struct ksmbd_conn *conn);
struct ksmbd_conn *ksmbd_find_master_conn(struct ksmbd_conn *bind_conn,
u64 sess_id);
void ksmbd_all_conn_set_status(struct ksmbd_conn *conn, u64 sess_id, u32 status);

/*
* WARNING
Expand Down Expand Up @@ -227,6 +233,4 @@ static inline void ksmbd_conn_set_releasing(struct ksmbd_conn *conn)
{
WRITE_ONCE(conn->status, KSMBD_SESS_RELEASING);
}

void ksmbd_all_conn_set_status(u64 sess_id, u32 status);
#endif /* __CONNECTION_H__ */
18 changes: 15 additions & 3 deletions smb2pdu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,8 @@ int smb2_sess_setup(struct ksmbd_work *work)
} else if (conn->dialect >= SMB30_PROT_ID &&
(server_conf.flags & KSMBD_GLOBAL_FLAG_SMB3_MULTICHANNEL) &&
req->Flags & SMB2_SESSION_REQ_FLAG_BINDING) {
struct ksmbd_conn *master_conn;

u64 sess_id = le64_to_cpu(req->hdr.SessionId);

sess = ksmbd_session_lookup_slowpath(sess_id);
Expand Down Expand Up @@ -1767,6 +1769,16 @@ int smb2_sess_setup(struct ksmbd_work *work)
goto out_err;
}

master_conn = ksmbd_find_master_conn(conn, sess_id);
if (!master_conn) {
rc = -ENOENT;
goto out_err;
}

conn->master_conn = master_conn;
ksmbd_conn_lock(master_conn);
list_add(&conn->bind_conn_entry, &master_conn->bind_conn_list);
ksmbd_conn_unlock(master_conn);
conn->binding = true;
} else if ((conn->dialect < SMB30_PROT_ID ||
server_conf.flags & KSMBD_GLOBAL_FLAG_SMB3_MULTICHANNEL) &&
Expand Down Expand Up @@ -2192,7 +2204,7 @@ int smb2_session_logoff(struct ksmbd_work *work)
return -ENOENT;
}
sess_id = le64_to_cpu(req->hdr.SessionId);
ksmbd_all_conn_set_status(sess_id, KSMBD_SESS_NEED_RECONNECT);
ksmbd_all_conn_set_status(conn, sess_id, KSMBD_SESS_NEED_RECONNECT);
ksmbd_conn_unlock(conn);

ksmbd_close_session_fds(work);
Expand All @@ -2215,7 +2227,7 @@ int smb2_session_logoff(struct ksmbd_work *work)

ksmbd_free_user(sess->user);
sess->user = NULL;
ksmbd_all_conn_set_status(sess_id, KSMBD_SESS_NEED_NEGOTIATE);
ksmbd_all_conn_set_status(conn, sess_id, KSMBD_SESS_NEED_NEGOTIATE);

rsp->StructureSize = cpu_to_le16(4);
err = ksmbd_iov_pin_rsp(work, rsp, sizeof(struct smb2_logoff_rsp));
Expand Down Expand Up @@ -7549,7 +7561,7 @@ int smb2_lock(struct ksmbd_work *work)
nolock = 1;
/* check locks in connection list */
down_read(&conn_list_lock);
list_for_each_entry(conn, &conn_list, conns_list) {
list_for_each_entry(conn, &global_conn_list, conn_entry) {
spin_lock(&conn->llist_lock);
list_for_each_entry_safe(cmp_lock, tmp2, &conn->lock_list, clist) {
if (file_inode(cmp_lock->fl->fl_file) !=
Expand Down

0 comments on commit 7fe97f7

Please sign in to comment.