diff options
author | Jack Lloyd <[email protected]> | 2017-12-17 07:34:48 -0500 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2017-12-17 14:59:21 -0500 |
commit | 08740bad13334128a0837364e4473c8a17f12dd4 (patch) | |
tree | c8dcdec1310b84372977c6216f97b9c73091da4d /src/lib/utils/socket | |
parent | eb76255d74f8281f0edd6acc0df6086a29c6d1d4 (diff) |
Merge BSD and Winsock variations together
Diffstat (limited to 'src/lib/utils/socket')
-rw-r--r-- | src/lib/utils/socket/socket.cpp | 166 |
1 files changed, 61 insertions, 105 deletions
diff --git a/src/lib/utils/socket/socket.cpp b/src/lib/utils/socket/socket.cpp index 0733ae630..55adc1319 100644 --- a/src/lib/utils/socket/socket.cpp +++ b/src/lib/utils/socket/socket.cpp @@ -135,12 +135,29 @@ class Asio_Socket final : public OS::Socket boost::asio::ip::tcp::socket m_tcp; }; -#elif defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) +#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) || defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) -class Winsock_Socket final : public OS::Socket +class BSD_Socket final : public OS::Socket { - public: - Winsock_Socket(const std::string& hostname, const std::string& service) + private: +#if defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) + typedef SOCKET socket_type; + static socket_type invalid_socket() { return INVALID_SOCKET; } + static void close_socket(socket_type s) { ::closesocket(s); } + static std::string get_last_socket_error() { return std::to_string(::WSAGetLastError()); } + + static bool nonblocking_connect_in_progress() + { + return (::WSAGetLastError() == WSAEWOULDBLOCK); + } + + static void set_nonblocking(socket_type s) + { + u_long nonblocking = 1; + ::ioctlsocket(s, FIONBIO, &nonblocking); + } + + static void socket_init() { WSAData wsa_data; WORD wsa_version = MAKEWORD(2, 2); @@ -155,93 +172,37 @@ class Winsock_Socket final : public OS::Socket ::WSACleanup(); throw Exception("Could not find a usable version of Winsock.dll"); } - - addrinfo hints; - ::memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - addrinfo* res; - - if(::getaddrinfo(hostname.c_str(), service.c_str(), &hints, &res) != 0) - { - throw Exception("Name resolution failed for " + hostname); - } - - for(addrinfo* rp = res; (m_socket == INVALID_SOCKET) && (rp != nullptr); rp = rp->ai_next) - { - m_socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - - // unsupported socket type? - if(m_socket == INVALID_SOCKET) - continue; - - if(::connect(m_socket, rp->ai_addr, rp->ai_addrlen) != 0) - { - ::closesocket(m_socket); - m_socket = INVALID_SOCKET; - continue; - } - } - - ::freeaddrinfo(res); - - if(m_socket == INVALID_SOCKET) - { - throw Exception("Connecting to " + hostname + - " for service " + service + " failed"); - } } - ~Winsock_Socket() + static void socket_fini() { - ::closesocket(m_socket); - m_socket = INVALID_SOCKET; ::WSACleanup(); } - - void write(const uint8_t buf[], size_t len) override - { - size_t sent_so_far = 0; - while(sent_so_far != len) - { - const size_t left = len - sent_so_far; - int sent = ::send(m_socket, - cast_uint8_ptr_to_char(buf + sent_so_far), - static_cast<int>(left), - 0); - - if(sent == SOCKET_ERROR) - throw Exception("Socket write failed with error " + - std::to_string(::WSAGetLastError())); - else - sent_so_far += static_cast<size_t>(sent); - } - } - - size_t read(uint8_t buf[], size_t len) override +#else + typedef int socket_type; + static socket_type invalid_socket() { return -1; } + static void close_socket(socket_type s) { ::close(s); } + static std::string get_last_socket_error() { return ::strerror(errno); } + static bool nonblocking_connect_in_progress() { return (errno == EINPROGRESS); } + static void set_nonblocking(socket_type s) { - int got = ::recv(m_socket, - cast_uint8_ptr_to_char(buf), - static_cast<int>(len), 0); - - if(got == SOCKET_ERROR) - throw Exception("Socket read failed with error " + - std::to_string(::WSAGetLastError())); - return static_cast<size_t>(got); + if(::fcntl(s, F_SETFL, O_NONBLOCK) < 0) + throw Exception("Setting socket to non-blocking state failed"); } - private: - SOCKET m_socket = INVALID_SOCKET; - }; + static void socket_init() {} + static void socket_fini() {} +#endif -#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) -class BSD_Socket final : public OS::Socket - { public: BSD_Socket(const std::string& hostname, const std::string& service, std::chrono::microseconds timeout) : m_timeout(timeout) { + socket_init(); + + m_socket = invalid_socket(); + addrinfo hints; ::memset(&hints, 0, sizeof(addrinfo)); hints.ai_family = AF_UNSPEC; @@ -253,44 +214,41 @@ class BSD_Socket final : public OS::Socket throw Exception("Name resolution failed for " + hostname); } - m_fd = -1; - - for(addrinfo* rp = res; (m_fd < 0) && (rp != nullptr); rp = rp->ai_next) + for(addrinfo* rp = res; (m_socket == invalid_socket()) && (rp != nullptr); rp = rp->ai_next) { if(rp->ai_family != AF_INET && rp->ai_family != AF_INET6) continue; - m_fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + m_socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if(m_fd < 0) + if(m_socket == invalid_socket()) { // unsupported socket type? continue; } - if(::fcntl(m_fd, F_SETFL, O_NONBLOCK) < 0) - throw Exception("Setting socket to non-blocking state failed"); + set_nonblocking(m_socket); - int err = ::connect(m_fd, rp->ai_addr, rp->ai_addrlen); + int err = ::connect(m_socket, rp->ai_addr, rp->ai_addrlen); if(err == -1) { int active = 0; - if(errno == EINPROGRESS) + if(nonblocking_connect_in_progress()) { struct timeval timeout = make_timeout_tv(); fd_set write_set; FD_ZERO(&write_set); - FD_SET(m_fd, &write_set); + FD_SET(m_socket, &write_set); - active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout); + active = ::select(m_socket + 1, nullptr, &write_set, nullptr, &timeout); if(active) { int socket_error = 0; socklen_t len = sizeof(socket_error); - if(::getsockopt(m_fd, SOL_SOCKET, SO_ERROR, &socket_error, &len) < 0) + if(::getsockopt(m_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&socket_error), &len) < 0) throw Exception("Error calling getsockopt"); if(socket_error != 0) @@ -302,8 +260,8 @@ class BSD_Socket final : public OS::Socket if(active == 0) { - ::close(m_fd); - m_fd = -1; + close_socket(m_socket); + m_socket = invalid_socket(); continue; } } @@ -311,7 +269,7 @@ class BSD_Socket final : public OS::Socket ::freeaddrinfo(res); - if(m_fd < 0) + if(m_socket == invalid_socket()) { throw Exception("Connecting to " + hostname + " for service " + service + " failed"); @@ -320,27 +278,28 @@ class BSD_Socket final : public OS::Socket ~BSD_Socket() { - ::close(m_fd); - m_fd = -1; + close_socket(m_socket); + m_socket = invalid_socket(); + socket_fini(); } void write(const uint8_t buf[], size_t len) override { fd_set write_set; FD_ZERO(&write_set); - FD_SET(m_fd, &write_set); + FD_SET(m_socket, &write_set); size_t sent_so_far = 0; while(sent_so_far != len) { struct timeval timeout = make_timeout_tv(); - int active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout); + int active = ::select(m_socket + 1, nullptr, &write_set, nullptr, &timeout); if(active == 0) throw Exception("Timeout during socket write"); const size_t left = len - sent_so_far; - ssize_t sent = ::write(m_fd, &buf[sent_so_far], left); + ssize_t sent = ::send(m_socket, cast_uint8_ptr_to_char(&buf[sent_so_far]), left, 0); if(sent < 0) throw Exception("Socket write failed with error '" + std::string(::strerror(errno)) + "'"); @@ -353,15 +312,15 @@ class BSD_Socket final : public OS::Socket { fd_set read_set; FD_ZERO(&read_set); - FD_SET(m_fd, &read_set); + FD_SET(m_socket, &read_set); struct timeval timeout = make_timeout_tv(); - int active = ::select(m_fd + 1, &read_set, nullptr, nullptr, &timeout); + int active = ::select(m_socket + 1, &read_set, nullptr, nullptr, &timeout); if(active == 0) throw Exception("Timeout during socket read"); - ssize_t got = ::read(m_fd, buf, len); + ssize_t got = ::recv(m_socket, cast_uint8_ptr_to_char(buf), len, 0); if(got < 0) throw Exception("Socket read failed with error '" + @@ -379,7 +338,7 @@ class BSD_Socket final : public OS::Socket } const std::chrono::microseconds m_timeout; - int m_fd; + socket_type m_socket; }; #endif @@ -394,10 +353,7 @@ OS::open_socket(const std::string& hostname, #if defined(BOTAN_HAS_BOOST_ASIO) return std::unique_ptr<OS::Socket>(new Asio_Socket(hostname, service, timeout)); -#elif defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) - return std::unique_ptr<OS::Socket>(new Winsock_Socket(hostname, service)); - -#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) +#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) || defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS) return std::unique_ptr<OS::Socket>(new BSD_Socket(hostname, service, timeout)); #else |