diff --git a/lib/common.c b/lib/common.c index 9d3cff02..f76bfd0d 100644 --- a/lib/common.c +++ b/lib/common.c @@ -197,7 +197,7 @@ void sasl_set_mutex(sasl_mutex_alloc_t *n, { /* Disallow mutex function changes once sasl_client_init and/or sasl_server_init is called */ - if (_sasl_server_cleanup_hook || _sasl_client_cleanup_hook) { + if (free_mutex) { return; } @@ -207,6 +207,25 @@ void sasl_set_mutex(sasl_mutex_alloc_t *n, _sasl_mutex_utils.free=d; } +void _sasl_mutex_init(void) +{ + if (_sasl_mutex_utils.alloc != &sasl_mutex_alloc) { + free_mutex = _sasl_mutex_utils.alloc(); + return; + } + +#if defined(HAVE_PTHREAD) + free_mutex = &static_mutex; +#elif defined(HAVE_NT_THREADS) + HANDLE p = CreateMutex(NULL, 0, NULL); + if (InterlockedCompareExchangePointer((PVOID*)&free_mutex, (PVOID)p, NULL) != NULL) + CloseHandle(p); +#else + free_mutex = _sasl_mutex_utils.alloc(); +#endif + return; +} + /* copy a string to malloced memory */ int _sasl_strdup(const char *in, char **out, size_t *outlen) { @@ -831,34 +850,40 @@ int _sasl_common_init(sasl_global_callbacks_t *global_callbacks) { int result; - /* The last specified global callback always wins */ - if (sasl_global_utils != NULL) { - sasl_utils_t * global_utils = (sasl_utils_t *)sasl_global_utils; + if (!free_mutex) { + _sasl_mutex_init(); + if (!free_mutex) { + return SASL_FAIL; + } + } + + result = sasl_MUTEX_LOCK(free_mutex); + if (result != SASL_OK) return SASL_FAIL; + + /* Just update global callback if we are already initialized */ + if (sasl_global_utils) { + sasl_utils_t *global_utils = (sasl_utils_t *)sasl_global_utils; global_utils->getopt = &_sasl_global_getopt; global_utils->getopt_context = global_callbacks; - } - /* Do nothing if we are already initialized */ - if (free_mutex) { + sasl_MUTEX_UNLOCK(free_mutex); return SASL_OK; - } - - /* Setup the global utilities */ - if(!sasl_global_utils) { + } else { sasl_global_utils = _sasl_alloc_utils(NULL, global_callbacks); - if(sasl_global_utils == NULL) return SASL_NOMEM; + if (sasl_global_utils == NULL) { + sasl_MUTEX_UNLOCK(free_mutex); + return SASL_NOMEM; + } } /* Init the canon_user plugin */ result = sasl_canonuser_add_plugin("INTERNAL", internal_canonuser_init); - if(result != SASL_OK) return result; - - if (!free_mutex) { - free_mutex = sasl_MUTEX_ALLOC(); + if (result != SASL_OK) { + result = SASL_FAIL; } - if (!free_mutex) return SASL_FAIL; - return SASL_OK; + sasl_MUTEX_UNLOCK(free_mutex); + return result; } /* dispose connection state, sets it to NULL @@ -874,7 +899,7 @@ void sasl_dispose(sasl_conn_t **pconn) /* serialize disposes. this is necessary because we can't dispose of conn->mutex if someone else is locked on it */ if (!free_mutex) { - free_mutex = sasl_MUTEX_ALLOC(); + _sasl_mutex_init(); if (!free_mutex) return; }