Skip to content

Commit

Permalink
msgq: Use Shared Memory Mutex and Condition Variables for Synchroniza…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
deanlee committed May 31, 2024
1 parent 51cbf62 commit a095ab3
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 134 deletions.
4 changes: 2 additions & 2 deletions SConscript
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ messaging_objects = env.SharedObject([
messaging = env.Library('messaging', messaging_objects)
Depends('messaging/impl_zmq.cc', services_h)

env.Program('messaging/bridge', ['messaging/bridge.cc'], LIBS=[messaging, 'zmq', common])
env.Program('messaging/bridge', ['messaging/bridge.cc'], LIBS=[messaging, 'zmq', 'pthread', common])
Depends('messaging/bridge.cc', services_h)

messaging_python = envCython.Program('messaging/messaging_pyx.so', 'messaging/messaging_pyx.pyx', LIBS=envCython["LIBS"]+[messaging, "zmq", common])
Expand Down Expand Up @@ -69,7 +69,7 @@ envCython.Program('visionipc/visionipc_pyx.so', 'visionipc/visionipc_pyx.pyx',
LIBS=vipc_libs, FRAMEWORKS=vipc_frameworks)

if GetOption('extras'):
env.Program('messaging/test_runner', ['messaging/test_runner.cc', 'messaging/msgq_tests.cc'], LIBS=[messaging, common])
env.Program('messaging/test_runner', ['messaging/test_runner.cc', 'messaging/msgq_tests.cc'], LIBS=[messaging, common, 'pthread'])

env.Program('visionipc/test_runner', ['visionipc/test_runner.cc', 'visionipc/visionipc_tests.cc'],
LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks)
Expand Down
39 changes: 8 additions & 31 deletions messaging/impl_msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a
return 0;
}


Message * MSGQSubSocket::receive(bool non_blocking){
msgq_do_exit = 0;

Expand All @@ -90,49 +89,27 @@ Message * MSGQSubSocket::receive(bool non_blocking){
prev_handler_sigterm = std::signal(SIGTERM, sig_handler);
}

int rc = 0;
msgq_msg_t msg;

MSGQMessage *r = NULL;

int rc = msgq_msg_recv(&msg, q);

// Hack to implement blocking read with a poller. Don't use this
while (!non_blocking && rc == 0 && msgq_do_exit == 0){
msgq_pollitem_t items[1];
items[0].q = q;

int t = (timeout != -1) ? timeout : 100;

int n = msgq_poll(items, 1, t);
while (!msgq_do_exit) {
rc = msgq_msg_recv(&msg, q);
if (rc > 0 || non_blocking) break;

// The poll indicated a message was ready, but the receive failed. Try again
if (n == 1 && rc == 0){
continue;
}

if (timeout != -1){
break;
}
int ms = (timeout != -1) ? timeout : 100;
if (!q->shm->waitFor(ms) && timeout != -1) break;
}


if (!non_blocking){
std::signal(SIGINT, prev_handler_sigint);
std::signal(SIGTERM, prev_handler_sigterm);
}

errno = msgq_do_exit ? EINTR : 0;

MSGQMessage *r = nullptr;
if (rc > 0){
if (msgq_do_exit){
msgq_msg_close(&msg); // Free unused message on exit
} else {
r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
}
r = new MSGQMessage;
r->takeOwnership(msg.data, msg.size);
}

return (Message*)r;
}

Expand Down
207 changes: 108 additions & 99 deletions messaging/msgq.cc
Original file line number Diff line number Diff line change
@@ -1,34 +1,112 @@
#include "cereal/messaging/msgq.h"

#include <iostream>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cmath>
#include <cstring>
#include <cstdint>
#include <chrono>
#include <algorithm>
#include <cstdlib>
#include <csignal>
#include <random>
#include <string>
#include <limits>
#include <semaphore.h>

#include <poll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/syscall.h>
#include <fcntl.h>
#include <unistd.h>

#include <stdio.h>
constexpr const char *SHM_INIT_SEM = "/op_shm_init_sem";

#include "cereal/messaging/msgq.h"
// Opens (creating it if necessary) the backing file for the segment `name`
// under /dev/shm, sizes it to hold a SharedMemoryHeader followed by `size`
// payload bytes, maps it read/write as a shared mapping, and then runs
// initMutexCond() on the mapped header.
//
// When the OPENPILOT_PREFIX environment variable is set, the file lives in
// /dev/shm/<prefix>/ instead of directly in /dev/shm/.
//
// Failures of open/ftruncate/mmap are only checked with assert().
SharedMemory::SharedMemory(const std::string &name, size_t size) : shm_name(name) {
  // Build "/dev/shm/[<prefix>/]<name>".
  std::string shm_path("/dev/shm/");
  if (const char *env_prefix = std::getenv("OPENPILOT_PREFIX")) {
    shm_path.append(env_prefix).append("/");
  }
  shm_path.append(shm_name);

  shm_fd = open(shm_path.c_str(), O_RDWR | O_CREAT, 0664);
  assert(shm_fd != -1);

  // The header sits in front of the caller-requested payload.
  shm_size = size + sizeof(SharedMemoryHeader);
  int ret = ftruncate(shm_fd, shm_size);
  assert(ret != -1);

  shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
  assert(shm_ptr != MAP_FAILED);

  initMutexCond();
}

// Releases this process's view of the segment: unmaps the shm_size bytes
// mapped at shm_ptr and closes shm_fd. The backing file is not unlinked and
// the mutex/condition variable in the header are not destroyed here.
SharedMemory::~SharedMemory() {
munmap(shm_ptr, shm_size);
::close(shm_fd);
}

// Initializes the process-shared mutex and condition variable that live in the
// SharedMemoryHeader at the start of the mapping, once per segment.
//
// The named semaphore SHM_INIT_SEM is used as a cross-process lock around the
// check-and-initialize sequence. The first opener runs the pthread_*_init calls
// and sets header->initialized; later openers find the flag set and skip them.
// (Presumably a freshly created segment reads as zeros, so the flag starts
// false - confirm against the ftruncate behavior on the target platform.)
void SharedMemory::initMutexCond() {
  sem_t *sem = sem_open(SHM_INIT_SEM, O_CREAT, 0644, 1);
  assert(sem != SEM_FAILED);

  // sem_wait() returns -1 with errno == EINTR when a signal handler interrupts
  // it. Retry in that case: continuing without owning the semaphore would let
  // two processes initialize concurrently, and the sem_post() below would then
  // push the semaphore count above 1, disabling the lock for everyone after us.
  int wait_rc;
  do {
    wait_rc = sem_wait(sem);
  } while (wait_rc == -1 && errno == EINTR);
  assert(wait_rc == 0);

  // Initialize the header if it hasn't been initialized yet
  header = (SharedMemoryHeader *)shm_ptr;
  if (!header->initialized) {
    pthread_mutexattr_t mutex_attr;
    pthread_mutexattr_init(&mutex_attr);
    pthread_mutexattr_setpshared(&mutex_attr, PTHREAD_PROCESS_SHARED);
#ifndef __APPLE__
    // Set the mutex to be robust, meaning it can recover from a process crash
    pthread_mutexattr_setrobust(&mutex_attr, PTHREAD_MUTEX_ROBUST);
#endif
    pthread_mutex_init(&(header->mutex), &mutex_attr);

    // The condition variable keeps its default clock, so timed waits on it
    // take absolute CLOCK_REALTIME deadlines.
    pthread_condattr_t cond_attr;
    pthread_condattr_init(&cond_attr);
    pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
    pthread_cond_init(&(header->cond), &cond_attr);

    pthread_mutexattr_destroy(&mutex_attr);
    pthread_condattr_destroy(&cond_attr);
    header->initialized = true;
  }

  sem_post(sem);
  sem_close(sem);
}

void sigusr2_handler(int signal) {
assert(signal == SIGUSR2);
// Wakes every waiter currently blocked on this segment's condition variable
// (header->cond) by broadcasting on it.
// NOTE(review): the broadcast is issued without holding header->mutex, and
// waitFor() does not re-check a predicate under that mutex, so a broadcast
// that happens just before a waiter blocks is not seen by it; that waiter
// then only returns when its timeout expires.
void SharedMemory::notifyAll() {
pthread_cond_broadcast(&(header->cond));
}

// Blocks on the segment's condition variable until it is signaled or until
// `timeout_ms` milliseconds have passed.
//
// timeout_ms: relative timeout in milliseconds; negative values are treated
//             as 0 (an immediate check).
// Returns true if pthread_cond_timedwait() reported a wakeup, false on
// timeout or on any error (including failure to acquire the mutex).
//
// Wakeups carry no payload: callers must re-check their own "message ready"
// state after this returns.
bool SharedMemory::waitFor(int timeout_ms) {
  constexpr long kNanosPerMilli = 1000000L;
  constexpr long kNanosPerSecond = 1000000000L;

  // A negative value would produce a negative tv_nsec, which is an invalid
  // timespec for pthread_cond_timedwait().
  if (timeout_ms < 0) {
    timeout_ms = 0;
  }

  // pthread_cond_timedwait() takes an absolute deadline. CLOCK_REALTIME is
  // used because the condition variable is created with default attributes.
  struct timespec deadline;
  clock_gettime(CLOCK_REALTIME, &deadline);
  deadline.tv_sec += timeout_ms / 1000;
  deadline.tv_nsec += (timeout_ms % 1000) * kNanosPerMilli;
  if (deadline.tv_nsec >= kNanosPerSecond) {
    deadline.tv_sec += 1;
    deadline.tv_nsec -= kNanosPerSecond;
  }

  int ret = pthread_mutex_lock(&(header->mutex));
#ifndef __APPLE__
  // Handle case where previous owner of the mutex died: we now hold the lock
  // and must mark it consistent before it can be used normally again.
  if (ret == EOWNERDEAD) {
    pthread_mutex_consistent(&(header->mutex));
    ret = 0;
  }
#endif
  if (ret != 0) {
    // The mutex is not held (e.g. ENOTRECOVERABLE); waiting on the condition
    // variable or unlocking would operate on a mutex we do not own.
    return false;
  }

  ret = pthread_cond_timedwait(&(header->cond), &(header->mutex), &deadline);
#ifndef __APPLE__
  // Re-acquiring the mutex inside the wait can also report a dead owner.
  // Unlocking without marking it consistent would leave the mutex permanently
  // unusable for every process.
  if (ret == EOWNERDEAD) {
    pthread_mutex_consistent(&(header->mutex));
  }
#endif
  pthread_mutex_unlock(&(header->mutex));

  // Return true if condition was signaled, false if timed out
  return ret == 0;
}

// Maps the shared segment named "msgq_poll" with a payload size of 0, so the
// mapping holds only the SharedMemoryHeader, and points `ctx` at the start of
// that mapping.
PollerContext::PollerContext() : shm("msgq_poll", 0) {
ctx = (SharedMemoryHeader *)shm.shm_ptr;
}

// File-level poller context: msgq_poll() waits on poller_context.shm, while
// msgq_msg_send() and msgq_init_subscriber() call notifyAll() on it.
PollerContext poller_context;

uint64_t msgq_get_uid(void){
std::random_device rd("/dev/urandom");
std::uniform_int_distribution<uint64_t> distribution(0, std::numeric_limits<uint32_t>::max());
Expand All @@ -53,7 +131,6 @@ int msgq_msg_init_size(msgq_msg_t * msg, size_t size){

int msgq_msg_init_data(msgq_msg_t * msg, char * data, size_t size) {
int r = msgq_msg_init_size(msg, size);

if (r == 0)
memcpy(msg->data, data, size);

Expand All @@ -79,41 +156,13 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){
while (*q->num_readers == 0){
// wait for subscriber
}

return;
}

int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes
std::signal(SIGUSR2, sigusr2_handler);

std::string full_path = "/dev/shm/";
const char* prefix = std::getenv("OPENPILOT_PREFIX");
if (prefix) {
full_path += std::string(prefix) + "/";
}
full_path += path;

auto fd = open(full_path.c_str(), O_RDWR | O_CREAT, 0664);
if (fd < 0) {
std::cout << "Warning, could not open: " << full_path << std::endl;
return -1;
}

int rc = ftruncate(fd, size + sizeof(msgq_header_t));
if (rc < 0){
close(fd);
return -1;
}
char * mem = (char*)mmap(NULL, size + sizeof(msgq_header_t), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
close(fd);

if (mem == NULL){
return -1;
}
q->mmap_p = mem;

msgq_header_t *header = (msgq_header_t *)mem;
q->shm = std::make_unique<SharedMemory>(path, size + sizeof(msgq_header_t));
msgq_header_t *header = (msgq_header_t*)(q->shm->header + 1);

// Setup pointers to header segment
q->num_readers = reinterpret_cast<std::atomic<uint64_t>*>(&header->num_readers);
Expand All @@ -126,7 +175,7 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
q->read_uids[i] = reinterpret_cast<std::atomic<uint64_t>*>(&header->read_uids[i]);
}

q->data = mem + sizeof(msgq_header_t);
q->data = (char*)(header + 1);
q->size = size;
q->reader_id = -1;

Expand All @@ -136,12 +185,7 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
return 0;
}

void msgq_close_queue(msgq_queue_t *q){
if (q->mmap_p != NULL){
munmap(q->mmap_p, q->size + sizeof(msgq_header_t));
}
}

// Intentionally does nothing with `q`.
// NOTE(review): msgq_new_queue() stores the mapping in q->shm (a
// std::unique_ptr<SharedMemory>), and ~SharedMemory() unmaps and closes it, so
// the memory is presumably released when the queue object itself is destroyed
// - confirm against the msgq_queue_t declaration.
void msgq_close_queue(msgq_queue_t *q) {}

void msgq_init_publisher(msgq_queue_t * q) {
//std::cout << "Starting publisher" << std::endl;
Expand All @@ -158,15 +202,6 @@ void msgq_init_publisher(msgq_queue_t * q) {
q->write_uid_local = uid;
}

// Sends SIGUSR2 to the thread identified by `tid`.
// When the SYS_tkill syscall number is available the signal is delivered with
// the tkill syscall; otherwise it falls back to kill(), which takes `tid` as a
// process id.
static void thread_signal(uint32_t tid) {
#ifndef SYS_tkill
// TODO: this won't work for multithreaded programs
kill(tid, SIGUSR2);
#else
syscall(SYS_tkill, tid, SIGUSR2);
#endif
}

void msgq_init_subscriber(msgq_queue_t * q) {
assert(q != NULL);
assert(q->num_readers != NULL);
Expand All @@ -185,13 +220,11 @@ void msgq_init_subscriber(msgq_queue_t * q) {

for (size_t i = 0; i < NUM_READERS; i++){
*q->read_valids[i] = false;

uint64_t old_uid = *q->read_uids[i];
*q->read_uids[i] = 0;

// Wake up reader in case they are in a poll
thread_signal(old_uid & 0xFFFFFFFF);
q->shm->notifyAll();
}
poller_context.shm.notifyAll();

continue;
}
Expand Down Expand Up @@ -292,12 +325,9 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
uint32_t new_ptr = ALIGN(write_pointer + msg->size + sizeof(int64_t));
PACK64(*q->write_pointer, write_cycles, new_ptr);

// Notify readers
for (uint64_t i = 0; i < num_readers; i++){
uint64_t reader_uid = *q->read_uids[i];
thread_signal(reader_uid & 0xFFFFFFFF);
}

// Notify pollers and readers
poller_context.shm.notifyAll();
q->shm->notifyAll();
return msg->size;
}

Expand Down Expand Up @@ -418,42 +448,21 @@ int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){
return msg->size;
}



int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){
int msgq_poll(msgq_pollitem_t *items, size_t nitems, int timeout) {
int num = 0;

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) num++;
}

int ms = (timeout == -1) ? 100 : timeout;
struct timespec ts;
ts.tv_sec = ms / 1000;
ts.tv_nsec = (ms % 1000) * 1000 * 1000;


while (num == 0) {
int ret;

ret = nanosleep(&ts, &ts);

while (true) {
// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){
num += 1;
items[i].revents = 1;
}
for (size_t i = 0; i < nitems; ++i) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) ++num;
}
if (num > 0 || timeout == 0) break;

// exit if we had a timeout and the sleep finished
if (timeout != -1 && ret == 0){
break;
}
// Wait until messages are ready or timeout occurs
if (!poller_context.shm.waitFor(ms) && timeout != -1) break;
}

return num;
}

Expand Down
Loading

0 comments on commit a095ab3

Please sign in to comment.