diff --git a/ext/redis_client/hiredis/hiredis_connection.c b/ext/redis_client/hiredis/hiredis_connection.c index b490ce7..eaf2dd5 100644 --- a/ext/redis_client/hiredis/hiredis_connection.c +++ b/ext/redis_client/hiredis/hiredis_connection.c @@ -380,10 +380,9 @@ static int hiredis_wait_writable(int fd, const struct timeval *timeout, int *iss struct timeval to; struct timeval *toptr = NULL; - rb_fdset_t fds; - /* Be cautious: a call to rb_fd_init to initialize the rb_fdset_t structure * must be paired with a call to rb_fd_term to free it. */ + rb_fdset_t fds; rb_fd_init(&fds); rb_fd_set(fd, &fds); @@ -406,6 +405,42 @@ static int hiredis_wait_writable(int fd, const struct timeval *timeout, int *iss return 0; } +static int hiredis_wait(hiredis_connection_t *connection, const struct timeval *timeout, bool *readable, bool *writable) { + struct timeval to; + struct timeval *toptr = NULL; + + /* Be cautious: a call to rb_fd_init to initialize the rb_fdset_t structure + * must be paired with a call to rb_fd_term to free it. */ + rb_fdset_t read_fds; + if (readable) { + *readable = false; + rb_fd_init(&read_fds); + rb_fd_set(connection->context->fd, &read_fds); + } + + rb_fdset_t write_fds; + if (writable) { + *writable = false; + rb_fd_init(&write_fds); + rb_fd_set(connection->context->fd, &write_fds); + } + + int error = rb_thread_fd_select(connection->context->fd + 1, readable ? &read_fds : NULL, writable ? &write_fds : NULL, NULL, toptr); + if (error > 0) { + error = 0; + if (readable) *readable = (bool)rb_fd_isset(connection->context->fd, &read_fds); + if (writable) *writable = (bool)rb_fd_isset(connection->context->fd, &write_fds); + } else if (!errno) { + errno = ETIMEDOUT; + error = -2; + } + + if (readable) rb_fd_term(&read_fds); + if (writable) rb_fd_term(&write_fds); + + return error; +} + static VALUE hiredis_connect_finish(hiredis_connection_t *connection, redisContext *context) { if (context->err) { redis_raise_error_and_disconnect(context, rb_eRedisClientConnectTimeoutError); @@ -466,6 +501,24 @@ static VALUE hiredis_init_ssl(VALUE self, VALUE ssl_param) { if (redisInitiateSSLWithContext(connection->context, ssl_context->context) != REDIS_OK) { hiredis_raise_error_and_disconnect(connection, rb_eRedisClientConnectTimeoutError); } + + redisSSL *redis_ssl = redisGetSSLSocket(connection->context); + + if (redis_ssl->wantRead) { + int readable = 0; + if (hiredis_wait_readable(connection->context->fd, &connection->connect_timeout, &readable) < 0) { + hiredis_raise_error_and_disconnect(connection, rb_eRedisClientConnectTimeoutError); + } + if (!readable) { + errno = EAGAIN; + hiredis_raise_error_and_disconnect(connection, rb_eRedisClientConnectTimeoutError); + } + + if (redisInitiateSSLContinue(connection->context) != REDIS_OK) { + hiredis_raise_error_and_disconnect(connection, rb_eRedisClientConnectTimeoutError); + }; + } + return Qtrue; } diff --git a/ext/redis_client/hiredis/vendor/hiredis_ssl.h b/ext/redis_client/hiredis/vendor/hiredis_ssl.h index 604efe0..bdbf299 100644 --- a/ext/redis_client/hiredis/vendor/hiredis_ssl.h +++ b/ext/redis_client/hiredis/vendor/hiredis_ssl.h @@ -32,6 +32,8 @@ #ifndef __HIREDIS_SSL_H #define __HIREDIS_SSL_H +#include + #ifdef __cplusplus extern "C" { #endif @@ -46,6 +48,29 @@ struct ssl_st; */ typedef struct redisSSLContext redisSSLContext; +/* The SSL connection context is attached to SSL/TLS connections as a privdata. */ +typedef struct redisSSL { + /** + * OpenSSL SSL object. + */ + SSL *ssl; + + /** + * SSL_write() requires to be called again with the same arguments it was + * previously called with in the event of an SSL_read/SSL_write situation + */ + size_t lastLen; + + /** Whether the SSL layer requires read (possibly before a write) */ + int wantRead; + + /** + * Whether a write was requested prior to a read. If set, the write() + * should resume whenever a read takes place, if possible + */ + int pendingWrite; +} redisSSL; + /** * Initialization errors that redisCreateSSLContext() may return. */ @@ -114,12 +139,17 @@ void redisFreeSSLContext(redisSSLContext *redis_ssl_ctx); int redisInitiateSSLWithContext(redisContext *c, redisSSLContext *redis_ssl_ctx); +int redisInitiateSSLContinue(redisContext *c); + /** * Initiate SSL/TLS negotiation on a provided OpenSSL SSL object. */ int redisInitiateSSL(redisContext *c, struct ssl_st *ssl); + +redisSSL *redisGetSSLSocket(redisContext *c); + #ifdef __cplusplus } #endif diff --git a/ext/redis_client/hiredis/vendor/ssl.c b/ext/redis_client/hiredis/vendor/ssl.c index 7df58fb..d6b183c 100644 --- a/ext/redis_client/hiredis/vendor/ssl.c +++ b/ext/redis_client/hiredis/vendor/ssl.c @@ -59,29 +59,6 @@ struct redisSSLContext { char *server_name; }; -/* The SSL connection context is attached to SSL/TLS connections as a privdata. */ -typedef struct redisSSL { - /** - * OpenSSL SSL object. - */ - SSL *ssl; - - /** - * SSL_write() requires to be called again with the same arguments it was - * previously called with in the event of an SSL_read/SSL_write situation - */ - size_t lastLen; - - /** Whether the SSL layer requires read (possibly before a write) */ - int wantRead; - - /** - * Whether a write was requested prior to a read. If set, the write() - * should resume whenever a read takes place, if possible - */ - int pendingWrite; -} redisSSL; - /* Forward declaration */ redisContextFuncs redisContextSSLFuncs; @@ -163,6 +140,22 @@ int redisInitOpenSSL(void) return REDIS_OK; } +static int maybeCheckWant(redisSSL *rssl, int rv) { + /** + * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set + * and true is returned. False is returned otherwise + */ + if (rv == SSL_ERROR_WANT_READ) { + rssl->wantRead = 1; + return 1; + } else if (rv == SSL_ERROR_WANT_WRITE) { + rssl->pendingWrite = 1; + return 1; + } else { + return 0; + } +} + /** * redisSSLContext helper context destruction. */ @@ -261,6 +254,42 @@ redisSSLContext *redisCreateSSLContext(const char *cacert_filename, const char * return NULL; } +int redisInitiateSSLContinue(redisContext *c) { + if (!c->privctx) { + __redisSetError(c, REDIS_ERR_OTHER, "redisContext is not associated"); + return REDIS_ERR; + } + + redisSSL *rssl = (redisSSL *)c->privctx; + ERR_clear_error(); + int rv = SSL_connect(rssl->ssl); + if (rv == 1) { + c->privctx = rssl; + return REDIS_OK; + } + + rv = SSL_get_error(rssl->ssl, rv); + if (((c->flags & REDIS_BLOCK) == 0) && + (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) { + maybeCheckWant(rssl, rv); + c->privctx = rssl; + return REDIS_OK; + } + + if (c->err == 0) { + char err[512]; + if (rv == SSL_ERROR_SYSCALL) + snprintf(err,sizeof(err)-1,"SSL_connect failed: %s",strerror(errno)); + else { + unsigned long e = ERR_peek_last_error(); + snprintf(err,sizeof(err)-1,"SSL_connect failed: %s", + ERR_reason_error_string(e)); + } + __redisSetError(c, REDIS_ERR_IO, err); + } + return REDIS_ERR; +} + /** * SSL Connection initialization. */ @@ -295,6 +324,7 @@ static int redisSSLConnect(redisContext *c, SSL *ssl) { rv = SSL_get_error(rssl->ssl, rv); if (((c->flags & REDIS_BLOCK) == 0) && (rv == SSL_ERROR_WANT_READ || rv == SSL_ERROR_WANT_WRITE)) { + maybeCheckWant(rssl, rv); c->privctx = rssl; return REDIS_OK; } @@ -315,6 +345,10 @@ static int redisSSLConnect(redisContext *c, SSL *ssl) { return REDIS_ERR; } +redisSSL *redisGetSSLSocket(redisContext *c) { + return c->privctx; +} + /** * A wrapper around redisSSLConnect() for users who manage their own context and * create their own SSL object. @@ -361,22 +395,6 @@ int redisInitiateSSLWithContext(redisContext *c, redisSSLContext *redis_ssl_ctx) return REDIS_ERR; } -static int maybeCheckWant(redisSSL *rssl, int rv) { - /** - * If the error is WANT_READ or WANT_WRITE, the appropriate flags are set - * and true is returned. False is returned otherwise - */ - if (rv == SSL_ERROR_WANT_READ) { - rssl->wantRead = 1; - return 1; - } else if (rv == SSL_ERROR_WANT_WRITE) { - rssl->pendingWrite = 1; - return 1; - } else { - return 0; - } -} - /** * Implementation of redisContextFuncs for SSL connections. */ diff --git a/test/redis_client/connection_test.rb b/test/redis_client/connection_test.rb index 3744ae6..adf2904 100644 --- a/test/redis_client/connection_test.rb +++ b/test/redis_client/connection_test.rb @@ -218,16 +218,6 @@ class SSLConnectionTest < Minitest::Test include ClientTestHelper include ConnectionTests - if ENV["DRIVER"] == "hiredis" - def test_tcp_connect_downstream_timeout - skip "TODO: Find the proper way to timeout SSL connections with hiredis" - end - - def test_tcp_connect_upstream_timeout - skip "TODO: Find the proper way to timeout SSL connections with hiredis" - end - end - private def new_client(**overrides)