diff options
author | lloyd <[email protected]> | 2012-09-09 22:24:21 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-09 22:24:21 +0000 |
commit | feb008e1e98f111828060e5b8f199e451a513815 (patch) | |
tree | 15ce9782be4ea79b17d05faf30ca2b0c43bfd661 /doc/examples | |
parent | 8a95a6a4bf428ac18077fb6b914aa527428e2ee2 (diff) |
Support both TLS and DTLS in the tls_server example.
Drop the fairly bogus socket wrapper layer.
Diffstat (limited to 'doc/examples')
-rw-r--r-- | doc/examples/socket.h | 269 | ||||
-rw-r--r-- | doc/examples/tls_server.cpp | 300 |
2 files changed, 159 insertions, 410 deletions
diff --git a/doc/examples/socket.h b/doc/examples/socket.h deleted file mode 100644 index 9e16ab36a..000000000 --- a/doc/examples/socket.h +++ /dev/null @@ -1,269 +0,0 @@ -/* -* Unix Socket -* (C) 2004-2010 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#ifndef SOCKET_WRAPPER_H__ -#define SOCKET_WRAPPER_H__ - -#include <stdexcept> - -#if defined(_MSC_VER) - #define SOCKET_IS_WINSOCK 1 -#endif - -#if !defined(SOCKET_IS_WINSOCK) - #define SOCKET_IS_WINSOCK 0 -#endif - -#if SOCKET_IS_WINSOCK - #include <winsock.h> - - typedef SOCKET socket_t; - const socket_t invalid_socket = INVALID_SOCKET; - #define socket_error_code WSAGetLastError() - typedef int ssize_t; - - class SocketInitializer - { - public: - SocketInitializer() - { - WSADATA wsadata; - WSAStartup(MAKEWORD(2, 2), &wsadata); - } - - ~SocketInitializer() - { - WSACleanup(); - } - }; -#else - #include <sys/types.h> - #include <sys/socket.h> - #include <sys/time.h> - #include <netinet/in.h> - #include <netdb.h> - #include <unistd.h> - #include <errno.h> - #include <fcntl.h> - - typedef int socket_t; - const socket_t invalid_socket = -1; - #define socket_error_code errno - #define closesocket close - - class SocketInitializer {}; -#endif - -#if !defined(MSG_NOSIGNAL) - #define MSG_NOSIGNAL 0 -#endif - -#include <string.h> - -class Socket - { - public: - size_t read(unsigned char[], size_t, bool dont_block = false); - void write(const unsigned char[], size_t); - - std::string peer_id() const { return peer; } - - void close() - { - if(sockfd != invalid_socket) - { - if(::closesocket(sockfd) != 0) - throw std::runtime_error("Socket::close failed"); - sockfd = invalid_socket; - } - } - - Socket(socket_t fd, const std::string& peer_id = "") : - peer(peer_id), sockfd(fd) - { - } - - Socket(const std::string&, unsigned short); - ~Socket() { close(); } - private: - std::string peer; - socket_t sockfd; - }; - -class Server_Socket - { - public: - /** - * Accept a new connection - */ - Socket* accept() - { - socket_t retval = ::accept(sockfd, 0, 0); - if(retval == invalid_socket) - throw std::runtime_error("Server_Socket: accept failed"); - return new Socket(retval); - } - - void close() - { - if(sockfd != invalid_socket) - { - if(::closesocket(sockfd) != 0) - throw std::runtime_error("Server_Socket::close failed"); - sockfd = invalid_socket; - } - } - - Server_Socket(unsigned short); - ~Server_Socket() { close(); } - private: - socket_t sockfd; - }; - -/** -* Unix Socket Constructor -*/ -Socket::Socket(const std::string& host, unsigned short port) : peer(host) - { - sockfd = invalid_socket; - - hostent* host_addr = ::gethostbyname(host.c_str()); - - if(host_addr == 0) - throw std::runtime_error("Socket: gethostbyname failed for " + host); - if(host_addr->h_addrtype != AF_INET) // FIXME - throw std::runtime_error("Socket: " + host + " has IPv6 address"); - - socket_t fd = ::socket(PF_INET, SOCK_STREAM, 0); - if(fd == invalid_socket) - throw std::runtime_error("Socket: Unable to acquire socket"); - - sockaddr_in socket_info; - ::memset(&socket_info, 0, sizeof(socket_info)); - socket_info.sin_family = AF_INET; - socket_info.sin_port = htons(port); - - ::memcpy(&socket_info.sin_addr, - host_addr->h_addr, - host_addr->h_length); - - socket_info.sin_addr = *(struct in_addr*)host_addr->h_addr; // FIXME - - if(::connect(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0) - { - ::closesocket(fd); - throw std::runtime_error("Socket: connect failed"); - } - - //fcntl(fd, F_SETFL, O_NONBLOCK); - - sockfd = fd; - } - -/** -* Read from a Unix socket -*/ -size_t Socket::read(unsigned char buf[], size_t length, bool partial) - { - if(sockfd == invalid_socket) - throw std::runtime_error("Socket::read: Socket not connected"); - - size_t got = 0; - - int flags = MSG_NOSIGNAL; - - while(length) - { - ssize_t this_time = ::recv(sockfd, (char*)buf + got, length, flags); - - const bool full_ret = (this_time == (ssize_t)length); - - if(this_time == 0) - break; - - if(this_time == -1) - { - if(socket_error_code == EINTR) - this_time = 0; - else if(socket_error_code == EAGAIN) - break; - else - throw std::runtime_error("Socket::read: Socket read failed"); - } - - got += this_time; - length -= this_time; - - if(partial && !full_ret) - break; - } - - return got; - } - -/** -* Write to a Unix socket -*/ -void Socket::write(const unsigned char buf[], size_t length) - { - if(sockfd == invalid_socket) - throw std::runtime_error("Socket::write: Socket not connected"); - - size_t offset = 0; - while(length) - { - ssize_t sent = ::send(sockfd, (const char*)buf + offset, - length, MSG_NOSIGNAL); - - if(sent == -1) - { - if(socket_error_code == EINTR) - sent = 0; - else - throw std::runtime_error("Socket::write: Socket write failed"); - } - - offset += sent; - length -= sent; - } - } - -/** -* Unix Server Socket Constructor -*/ -Server_Socket::Server_Socket(unsigned short port) - { - sockfd = invalid_socket; - - socket_t fd = ::socket(PF_INET, SOCK_STREAM, 0); - if(fd == invalid_socket) - throw std::runtime_error("Server_Socket: Unable to acquire socket"); - - sockaddr_in socket_info; - ::memset(&socket_info, 0, sizeof(socket_info)); - socket_info.sin_family = AF_INET; - socket_info.sin_port = htons(port); - - // FIXME: support limiting listeners - socket_info.sin_addr.s_addr = INADDR_ANY; - - if(::bind(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0) - { - ::closesocket(fd); - throw std::runtime_error("Server_Socket: bind failed"); - } - - if(::listen(fd, 100) != 0) // FIXME: totally arbitrary - { - ::closesocket(fd); - throw std::runtime_error("Server_Socket: listen failed"); - } - - sockfd = fd; - } - -#endif diff --git a/doc/examples/tls_server.cpp b/doc/examples/tls_server.cpp index 727f4c333..980230d58 100644 --- a/doc/examples/tls_server.cpp +++ b/doc/examples/tls_server.cpp @@ -7,7 +7,6 @@ #include <botan/x509self.h> #include <botan/secqueue.h> -#include "socket.h" #include "credentials.h" using namespace Botan; @@ -18,209 +17,228 @@ using namespace std::placeholders; #include <string> #include <iostream> #include <memory> - -class Blocking_TLS_Server +#include <list> + +#include <sys/types.h> +#include <sys/time.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netdb.h> +#include <unistd.h> +#include <errno.h> +#include <fcntl.h> + +int make_server_socket(const std::string& transport, u16bit port) { - public: - Blocking_TLS_Server(std::function<void (const byte[], size_t)> output_fn, - std::function<size_t (byte[], size_t)> input_fn, - std::vector<std::string>& protocols, - TLS::Session_Manager& sessions, - Credentials_Manager& creds, - TLS::Policy& policy, - RandomNumberGenerator& rng) : - input_fn(input_fn), - server( - output_fn, - std::bind(&Blocking_TLS_Server::reader_fn, std::ref(*this), _1, _2, _3), - std::bind(&Blocking_TLS_Server::handshake_complete, std::ref(*this), _1), - sessions, - creds, - policy, - rng, - protocols), - exit(false) - { - read_loop(); - } - - bool handshake_complete(const TLS::Session& session) - { - std::cout << "Handshake complete: " - << session.version().to_string() << " " - << session.ciphersuite().to_string() << " " - << "SessionID: " << hex_encode(session.session_id()) << "\n"; + int type = (transport == "tcp") ? SOCK_STREAM : SOCK_DGRAM; - if(session.srp_identifier() != "") - std::cout << "SRP identifier: " << session.srp_identifier() << "\n"; + int fd = ::socket(PF_INET, type, 0); + if(fd == -1) + throw std::runtime_error("Unable to acquire socket"); - if(server.next_protocol() != "") - std::cout << "Next protocol: " << server.next_protocol() << "\n"; + sockaddr_in socket_info; + ::memset(&socket_info, 0, sizeof(socket_info)); + socket_info.sin_family = AF_INET; + socket_info.sin_port = htons(port); - /* - std::vector<X509_Certificate> peer_certs = session.peer_certs(); - if(peer_certs.size()) - std::cout << peer_certs[0].to_string(); - */ + // FIXME: support limiting listeners + socket_info.sin_addr.s_addr = INADDR_ANY; - return true; - } + if(::bind(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0) + { + ::close(fd); + throw std::runtime_error("server bind failed"); + } - size_t read(byte buf[], size_t buf_len) + if(transport != "udp") + { + if(::listen(fd, 100) != 0) { - size_t got = read_queue.read(buf, buf_len); - - while(!exit && !got) - { - read_loop(TLS::TLS_HEADER_SIZE); - got = read_queue.read(buf, buf_len); - } - - return got; + ::close(fd); + throw std::runtime_error("listen failed"); } + } - void write(const byte buf[], size_t buf_len) - { - server.send(buf, buf_len); - } + return fd; + } - void close() { server.close(); } +bool handshake_complete(const TLS::Session& session) + { + std::cout << "Handshake complete, " << session.version().to_string() + << " using " << session.ciphersuite().to_string() << "\n"; - bool is_active() const { return server.is_active(); } + if(!session.session_id().empty()) + std::cout << "Session ID " << hex_encode(session.session_id()) << "\n"; - TLS::Server& underlying() { return server; } - private: - void read_loop(size_t init_desired = 0) - { - size_t desired = init_desired; + if(!session.session_ticket().empty()) + std::cout << "Session ticket " << hex_encode(session.session_ticket()) << "\n"; - byte buf[4096]; - while(!exit && (!server.is_active() || desired)) - { - const size_t asking = std::max(sizeof(buf), std::min(desired, static_cast<size_t>(1))); + std::cout << "Secure renegotiation is" + << (session.secure_renegotiation() ? "" : " NOT") + << " supported\n"; - const size_t socket_got = input_fn(&buf[0], asking); + return true; + } - if(socket_got == 0) // eof? - { - close(); - printf("got eof on socket\n"); - exit = true; - } +void dgram_socket_write(int sockfd, const byte buf[], size_t length) + { + ssize_t sent = ::send(sockfd, buf, length, MSG_NOSIGNAL); - desired = server.received_data(&buf[0], socket_got); - } - } + if(sent == -1) + printf("Error writing to socket %s\n", strerror(errno)); + else if(sent != length) + printf("Note: packet of length %d truncated to %d\n", length, sent); + } - void reader_fn(const byte buf[], size_t buf_len, TLS::Alert alert) - { - if(alert.is_valid()) - { - printf("Alert %s\n", alert.type_string().c_str()); - //exit = true; - } +void stream_socket_write(int sockfd, const byte buf[], size_t length) + { + size_t offset = 0; - printf("Got %d bytes: ", (int)buf_len); - for(size_t i = 0; i != buf_len; ++i) - { - if(isprint(buf[i])) - printf("%c", buf[i]); - } - printf("\n"); + while(length) + { + ssize_t sent = ::send(sockfd, (const char*)buf + offset, + length, MSG_NOSIGNAL); - read_queue.write(buf, buf_len); + if(sent == -1) + { + if(errno == EINTR) + sent = 0; + else + throw std::runtime_error("Socket::write: Socket write failed"); } - std::function<size_t (byte[], size_t)> input_fn; - TLS::Server server; - SecureQueue read_queue; - bool exit; - }; + offset += sent; + length -= sent; + } + } int main(int argc, char* argv[]) { int port = 4433; + std::string transport = "tcp"; - if(argc == 2) + if(argc >= 2) port = to_u32bit(argv[1]); + if(argc >= 3) + transport = argv[2]; try { LibraryInitializer botan_init; - //SocketInitializer socket_init; AutoSeeded_RNG rng; - Server_Socket listener(port); - TLS::Policy policy; - TLS::Session_Manager_In_Memory sessions; + TLS::Session_Manager_In_Memory session_manager; Credentials_Manager_Simple creds(rng); - std::vector<std::string> protocols; - /* * These are the protocols we advertise to the client, but the * client will send back whatever it actually plans on talking, * which may or may not take into account what we advertise. */ - protocols.push_back("echo/1.0"); - protocols.push_back("echo/1.1"); + const std::vector<std::string> protocols = { "echo/1.0", "echo/1.1" }; + + printf("Listening for new connection on %s port %d\n", transport.c_str(), port); + + int server_fd = make_server_socket(transport, port); while(true) { - try { - printf("Listening for new connection on port %d\n", port); + try + { + int fd; - std::auto_ptr<Socket> sock(listener.accept()); + if(transport == "tcp") + fd = ::accept(server_fd, NULL, NULL); + else + { + struct sockaddr_in from; + socklen_t from_len = sizeof(sockaddr_in); - printf("Got new connection\n"); + if(::recvfrom(server_fd, NULL, 0, MSG_PEEK, + (struct sockaddr*)&from, &from_len) != 0) + throw std::runtime_error("Could not peek next packet"); - Blocking_TLS_Server tls( - std::bind(&Socket::write, std::ref(sock), _1, _2), - std::bind(&Socket::read, std::ref(sock), _1, _2, true), - protocols, - sessions, - creds, - policy, - rng); + if(::connect(server_fd, (struct sockaddr*)&from, from_len) != 0) + throw std::runtime_error("Could not connect UDP socket"); - const char* msg = "Welcome to the best echo server evar\n"; - tls.write((const Botan::byte*)msg, strlen(msg)); + fd = server_fd; + } - std::string line; + printf("New connection received\n"); - while(tls.is_active()) - { - byte b; - size_t got = tls.read(&b, 1); + auto socket_write = + (transport == "tcp") ? + std::bind(stream_socket_write, fd, _1, _2) : + std::bind(dgram_socket_write, fd, _1, _2); - if(got == 0) - break; + std::string s; + std::list<std::string> pending_output; - line += (char)b; - if(b == '\n') - { - //std::cout << line; + pending_output.push_back("Welcome to the best echo server evar\n"); - tls.write(reinterpret_cast<const byte*>(line.data()), line.size()); + auto proc_fn = [&](const byte input[], size_t input_len, TLS::Alert alert) + { + if(alert.is_valid()) + std::cout << "Alert: " << alert.type_string() << "\n"; - if(line == "quit\n") + for(size_t i = 0; i != input_len; ++i) + { + char c = (char)input[i]; + s += c; + if(c == '\n') { - tls.close(); - break; + pending_output.push_back(s); + s.clear(); } + } + }; + + TLS::Server server(socket_write, + proc_fn, + handshake_complete, + session_manager, + creds, + policy, + rng, + protocols); + + while(!server.is_closed()) + { + byte buf[4*1024] = { 0 }; + size_t got = ::read(fd, buf, sizeof(buf)); + + if(got == -1) + { + printf("Error in socket read %s\n", strerror(errno)); + break; + } + + if(got == 0) + { + printf("EOF on socket\n"); + break; + } - if(line == "reneg\n") - tls.underlying().renegotiate(false); - else if(line == "RENEG\n") - tls.underlying().renegotiate(true); + server.received_data(buf, got); - line.clear(); + while(server.is_active() && !pending_output.empty()) + { + std::string s = pending_output.front(); + pending_output.pop_front(); + server.send(s); + + if(s == "quit\n") + server.close(); } } + + if(transport == "tcp") + ::close(fd); + } catch(std::exception& e) { printf("Connection problem: %s\n", e.what()); } } |