diff --git a/src/tls.c b/src/tls.c index 83b34e9f..b95150ea 100644 --- a/src/tls.c +++ b/src/tls.c @@ -376,7 +376,25 @@ static int valkeyTLSConnect(valkeyContext *c, SSL *ssl) { ERR_clear_error(); + /* Apply connect_timeout to the TLS handshake. */ + if ((c->flags & VALKEY_BLOCK) && c->connect_timeout != NULL && + c->funcs && c->funcs->set_timeout) { + c->funcs->set_timeout(c, *c->connect_timeout); + } + int rv = SSL_connect(rssl->ssl); + + /* Restore the command_timeout. */ + if ((c->flags & VALKEY_BLOCK) && c->connect_timeout != NULL && + c->funcs && c->funcs->set_timeout) { + if (c->command_timeout != NULL) { + c->funcs->set_timeout(c, *c->command_timeout); + } else { + struct timeval tv_zero = {0, 0}; + c->funcs->set_timeout(c, tv_zero); + } + } + if (rv == 1) { c->funcs = &valkeyContextTLSFuncs; c->privctx = rssl; diff --git a/tests/client_test.c b/tests/client_test.c index 4d2f477a..adb3d58d 100644 --- a/tests/client_test.c +++ b/tests/client_test.c @@ -2676,6 +2676,36 @@ static void sharded_pubsub_test(struct config cfg) { } #endif /* VALKEY_TEST_ASYNC */ +#ifdef VALKEY_TEST_TLS +/* Test that a TLS handshake to a non-TLS server times out using + * connect_timeout rather than hanging indefinitely. */ +static void test_tls_handshake_timeout(struct config config) { + test("TLS handshake to non-TLS port times out with connect_timeout: "); + +#ifndef _WIN32 + /* Abort the test after 5 seconds. */ + unsigned int old_alarm = alarm(5); +#endif + + struct timeval tv = {0, 100000}; /* 100ms */ + valkeyOptions opt = {0}; + VALKEY_OPTIONS_SET_TCP(&opt, config.tcp.host, config.tcp.port); /* TCP */ + opt.connect_timeout = &tv; + /* No command timeout set. */ + + valkeyContext *c = valkeyConnectWithOptions(&opt); + assert(c != NULL && c->err == 0); + + int rc = valkeyInitiateTLSWithContext(c, _tls_ctx); + test_cond(rc == VALKEY_ERR && c->err != 0); + valkeyFree(c); + +#ifndef _WIN32 + alarm(old_alarm); /* Reset any alarm. */ +#endif +} +#endif /* VALKEY_TEST_TLS */ + int main(int argc, char **argv) { struct config cfg = { .tcp = { @@ -2840,6 +2870,7 @@ int main(int argc, char **argv) { test_blocking_io_errors(cfg); test_invalid_timeout_errors(cfg); test_append_formatted_commands(cfg); + test_tls_handshake_timeout(cfg); if (throughput) test_throughput(cfg);