diff options
-rw-r--r-- | src/lib/utils/socket/socket.cpp | 66 |
1 files changed, 61 insertions, 5 deletions
diff --git a/src/lib/utils/socket/socket.cpp b/src/lib/utils/socket/socket.cpp index e1f54a238..0733ae630 100644 --- a/src/lib/utils/socket/socket.cpp +++ b/src/lib/utils/socket/socket.cpp @@ -27,6 +27,7 @@ #include <string.h> #include <unistd.h> #include <errno.h> + #include <fcntl.h> #elif defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) #define NOMINMAX 1 @@ -256,6 +257,9 @@ class BSD_Socket final : public OS::Socket for(addrinfo* rp = res; (m_fd < 0) && (rp != nullptr); rp = rp->ai_next) { + if(rp->ai_family != AF_INET && rp->ai_family != AF_INET6) + continue; + m_fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if(m_fd < 0) @@ -264,11 +268,44 @@ class BSD_Socket final : public OS::Socket continue; } - if(::connect(m_fd, rp->ai_addr, rp->ai_addrlen) != 0) + if(::fcntl(m_fd, F_SETFL, O_NONBLOCK) < 0) + throw Exception("Setting socket to non-blocking state failed"); + + int err = ::connect(m_fd, rp->ai_addr, rp->ai_addrlen); + + if(err == -1) { - ::close(m_fd); - m_fd = -1; - continue; + int active = 0; + if(errno == EINPROGRESS) + { + struct timeval timeout = make_timeout_tv(); + fd_set write_set; + FD_ZERO(&write_set); + FD_SET(m_fd, &write_set); + + active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout); + + if(active) + { + int socket_error = 0; + socklen_t len = sizeof(socket_error); + + if(::getsockopt(m_fd, SOL_SOCKET, SO_ERROR, &socket_error, &len) < 0) + throw Exception("Error calling getsockopt"); + + if(socket_error != 0) + { + active = 0; + } + } + } + + if(active == 0) + { + ::close(m_fd); + m_fd = -1; + continue; + } } } @@ -287,12 +324,21 @@ class BSD_Socket final : public OS::Socket m_fd = -1; } - void write(const uint8_t buf[], size_t len) override { + fd_set write_set; + FD_ZERO(&write_set); + FD_SET(m_fd, &write_set); + size_t sent_so_far = 0; while(sent_so_far != len) { + struct timeval timeout = make_timeout_tv(); + int active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout); + + if(active == 0) + throw Exception("Timeout during socket write"); + const size_t left = len - sent_so_far; ssize_t sent = ::write(m_fd, &buf[sent_so_far], left); if(sent < 0) @@ -305,6 +351,16 @@ class BSD_Socket final : public OS::Socket size_t read(uint8_t buf[], size_t len) override { + fd_set read_set; + FD_ZERO(&read_set); + FD_SET(m_fd, &read_set); + + struct timeval timeout = make_timeout_tv(); + int active = ::select(m_fd + 1, &read_set, nullptr, nullptr, &timeout); + + if(active == 0) + throw Exception("Timeout during socket read"); + ssize_t got = ::read(m_fd, buf, len); if(got < 0) |