aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/utils/socket/socket.cpp
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2017-12-17 07:34:48 -0500
committerJack Lloyd <[email protected]>2017-12-17 14:59:21 -0500
commit08740bad13334128a0837364e4473c8a17f12dd4 (patch)
treec8dcdec1310b84372977c6216f97b9c73091da4d /src/lib/utils/socket/socket.cpp
parenteb76255d74f8281f0edd6acc0df6086a29c6d1d4 (diff)
Merge BSD and Winsock variations together
Diffstat (limited to 'src/lib/utils/socket/socket.cpp')
-rw-r--r--src/lib/utils/socket/socket.cpp166
1 files changed, 61 insertions, 105 deletions
diff --git a/src/lib/utils/socket/socket.cpp b/src/lib/utils/socket/socket.cpp
index 0733ae630..55adc1319 100644
--- a/src/lib/utils/socket/socket.cpp
+++ b/src/lib/utils/socket/socket.cpp
@@ -135,12 +135,29 @@ class Asio_Socket final : public OS::Socket
boost::asio::ip::tcp::socket m_tcp;
};
-#elif defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS)
+#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) || defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS)
-class Winsock_Socket final : public OS::Socket
+class BSD_Socket final : public OS::Socket
{
- public:
- Winsock_Socket(const std::string& hostname, const std::string& service)
+ private:
+#if defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS)
+ typedef SOCKET socket_type;
+ static socket_type invalid_socket() { return INVALID_SOCKET; }
+ static void close_socket(socket_type s) { ::closesocket(s); }
+ static std::string get_last_socket_error() { return std::to_string(::WSAGetLastError()); }
+
+ static bool nonblocking_connect_in_progress()
+ {
+ return (::WSAGetLastError() == WSAEWOULDBLOCK);
+ }
+
+ static void set_nonblocking(socket_type s)
+ {
+ u_long nonblocking = 1;
+ ::ioctlsocket(s, FIONBIO, &nonblocking);
+ }
+
+ static void socket_init()
{
WSAData wsa_data;
WORD wsa_version = MAKEWORD(2, 2);
@@ -155,93 +172,37 @@ class Winsock_Socket final : public OS::Socket
::WSACleanup();
throw Exception("Could not find a usable version of Winsock.dll");
}
-
- addrinfo hints;
- ::memset(&hints, 0, sizeof(addrinfo));
- hints.ai_family = AF_UNSPEC;
- hints.ai_socktype = SOCK_STREAM;
- addrinfo* res;
-
- if(::getaddrinfo(hostname.c_str(), service.c_str(), &hints, &res) != 0)
- {
- throw Exception("Name resolution failed for " + hostname);
- }
-
- for(addrinfo* rp = res; (m_socket == INVALID_SOCKET) && (rp != nullptr); rp = rp->ai_next)
- {
- m_socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
-
- // unsupported socket type?
- if(m_socket == INVALID_SOCKET)
- continue;
-
- if(::connect(m_socket, rp->ai_addr, rp->ai_addrlen) != 0)
- {
- ::closesocket(m_socket);
- m_socket = INVALID_SOCKET;
- continue;
- }
- }
-
- ::freeaddrinfo(res);
-
- if(m_socket == INVALID_SOCKET)
- {
- throw Exception("Connecting to " + hostname +
- " for service " + service + " failed");
- }
}
- ~Winsock_Socket()
+ static void socket_fini()
{
- ::closesocket(m_socket);
- m_socket = INVALID_SOCKET;
::WSACleanup();
}
-
- void write(const uint8_t buf[], size_t len) override
- {
- size_t sent_so_far = 0;
- while(sent_so_far != len)
- {
- const size_t left = len - sent_so_far;
- int sent = ::send(m_socket,
- cast_uint8_ptr_to_char(buf + sent_so_far),
- static_cast<int>(left),
- 0);
-
- if(sent == SOCKET_ERROR)
- throw Exception("Socket write failed with error " +
- std::to_string(::WSAGetLastError()));
- else
- sent_so_far += static_cast<size_t>(sent);
- }
- }
-
- size_t read(uint8_t buf[], size_t len) override
+#else
+ typedef int socket_type;
+ static socket_type invalid_socket() { return -1; }
+ static void close_socket(socket_type s) { ::close(s); }
+ static std::string get_last_socket_error() { return ::strerror(errno); }
+ static bool nonblocking_connect_in_progress() { return (errno == EINPROGRESS); }
+ static void set_nonblocking(socket_type s)
{
- int got = ::recv(m_socket,
- cast_uint8_ptr_to_char(buf),
- static_cast<int>(len), 0);
-
- if(got == SOCKET_ERROR)
- throw Exception("Socket read failed with error " +
- std::to_string(::WSAGetLastError()));
- return static_cast<size_t>(got);
+ if(::fcntl(s, F_SETFL, O_NONBLOCK) < 0)
+ throw Exception("Setting socket to non-blocking state failed");
}
- private:
- SOCKET m_socket = INVALID_SOCKET;
- };
+ static void socket_init() {}
+ static void socket_fini() {}
+#endif
-#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX)
-class BSD_Socket final : public OS::Socket
- {
public:
BSD_Socket(const std::string& hostname,
const std::string& service,
std::chrono::microseconds timeout) : m_timeout(timeout)
{
+ socket_init();
+
+ m_socket = invalid_socket();
+
addrinfo hints;
::memset(&hints, 0, sizeof(addrinfo));
hints.ai_family = AF_UNSPEC;
@@ -253,44 +214,41 @@ class BSD_Socket final : public OS::Socket
throw Exception("Name resolution failed for " + hostname);
}
- m_fd = -1;
-
- for(addrinfo* rp = res; (m_fd < 0) && (rp != nullptr); rp = rp->ai_next)
+ for(addrinfo* rp = res; (m_socket == invalid_socket()) && (rp != nullptr); rp = rp->ai_next)
{
if(rp->ai_family != AF_INET && rp->ai_family != AF_INET6)
continue;
- m_fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
+ m_socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
- if(m_fd < 0)
+ if(m_socket == invalid_socket())
{
// unsupported socket type?
continue;
}
- if(::fcntl(m_fd, F_SETFL, O_NONBLOCK) < 0)
- throw Exception("Setting socket to non-blocking state failed");
+ set_nonblocking(m_socket);
- int err = ::connect(m_fd, rp->ai_addr, rp->ai_addrlen);
+ int err = ::connect(m_socket, rp->ai_addr, rp->ai_addrlen);
if(err == -1)
{
int active = 0;
- if(errno == EINPROGRESS)
+ if(nonblocking_connect_in_progress())
{
struct timeval timeout = make_timeout_tv();
fd_set write_set;
FD_ZERO(&write_set);
- FD_SET(m_fd, &write_set);
+ FD_SET(m_socket, &write_set);
- active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout);
+ active = ::select(m_socket + 1, nullptr, &write_set, nullptr, &timeout);
if(active)
{
int socket_error = 0;
socklen_t len = sizeof(socket_error);
- if(::getsockopt(m_fd, SOL_SOCKET, SO_ERROR, &socket_error, &len) < 0)
+ if(::getsockopt(m_socket, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&socket_error), &len) < 0)
throw Exception("Error calling getsockopt");
if(socket_error != 0)
@@ -302,8 +260,8 @@ class BSD_Socket final : public OS::Socket
if(active == 0)
{
- ::close(m_fd);
- m_fd = -1;
+ close_socket(m_socket);
+ m_socket = invalid_socket();
continue;
}
}
@@ -311,7 +269,7 @@ class BSD_Socket final : public OS::Socket
::freeaddrinfo(res);
- if(m_fd < 0)
+ if(m_socket == invalid_socket())
{
throw Exception("Connecting to " + hostname +
" for service " + service + " failed");
@@ -320,27 +278,28 @@ class BSD_Socket final : public OS::Socket
~BSD_Socket()
{
- ::close(m_fd);
- m_fd = -1;
+ close_socket(m_socket);
+ m_socket = invalid_socket();
+ socket_fini();
}
void write(const uint8_t buf[], size_t len) override
{
fd_set write_set;
FD_ZERO(&write_set);
- FD_SET(m_fd, &write_set);
+ FD_SET(m_socket, &write_set);
size_t sent_so_far = 0;
while(sent_so_far != len)
{
struct timeval timeout = make_timeout_tv();
- int active = ::select(m_fd + 1, nullptr, &write_set, nullptr, &timeout);
+ int active = ::select(m_socket + 1, nullptr, &write_set, nullptr, &timeout);
if(active == 0)
throw Exception("Timeout during socket write");
const size_t left = len - sent_so_far;
- ssize_t sent = ::write(m_fd, &buf[sent_so_far], left);
+ ssize_t sent = ::send(m_socket, cast_uint8_ptr_to_char(&buf[sent_so_far]), left, 0);
if(sent < 0)
throw Exception("Socket write failed with error '" +
std::string(::strerror(errno)) + "'");
@@ -353,15 +312,15 @@ class BSD_Socket final : public OS::Socket
{
fd_set read_set;
FD_ZERO(&read_set);
- FD_SET(m_fd, &read_set);
+ FD_SET(m_socket, &read_set);
struct timeval timeout = make_timeout_tv();
- int active = ::select(m_fd + 1, &read_set, nullptr, nullptr, &timeout);
+ int active = ::select(m_socket + 1, &read_set, nullptr, nullptr, &timeout);
if(active == 0)
throw Exception("Timeout during socket read");
- ssize_t got = ::read(m_fd, buf, len);
+ ssize_t got = ::recv(m_socket, cast_uint8_ptr_to_char(buf), len, 0);
if(got < 0)
throw Exception("Socket read failed with error '" +
@@ -379,7 +338,7 @@ class BSD_Socket final : public OS::Socket
}
const std::chrono::microseconds m_timeout;
- int m_fd;
+ socket_type m_socket;
};
#endif
@@ -394,10 +353,7 @@ OS::open_socket(const std::string& hostname,
#if defined(BOTAN_HAS_BOOST_ASIO)
return std::unique_ptr<OS::Socket>(new Asio_Socket(hostname, service, timeout));
-#elif defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS)
- return std::unique_ptr<OS::Socket>(new Winsock_Socket(hostname, service));
-
-#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX)
+#elif defined(BOTAN_TARGET_OS_TYPE_IS_UNIX) || defined(BOTAN_TARGET_OS_TYPE_IS_WINDOWS)
return std::unique_ptr<OS::Socket>(new BSD_Socket(hostname, service, timeout));
#else