aboutsummaryrefslogtreecommitdiffstats
path: root/doc/examples
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-09 22:24:21 +0000
committerlloyd <[email protected]>2012-09-09 22:24:21 +0000
commitfeb008e1e98f111828060e5b8f199e451a513815 (patch)
tree15ce9782be4ea79b17d05faf30ca2b0c43bfd661 /doc/examples
parent8a95a6a4bf428ac18077fb6b914aa527428e2ee2 (diff)
Support both TLS and DTLS in the tls_server example.
Drop the fairly bogus socket wrapper layer.
Diffstat (limited to 'doc/examples')
-rw-r--r--doc/examples/socket.h269
-rw-r--r--doc/examples/tls_server.cpp300
2 files changed, 159 insertions, 410 deletions
diff --git a/doc/examples/socket.h b/doc/examples/socket.h
deleted file mode 100644
index 9e16ab36a..000000000
--- a/doc/examples/socket.h
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
-* Unix Socket
-* (C) 2004-2010 Jack Lloyd
-*
-* Released under the terms of the Botan license
-*/
-
-#ifndef SOCKET_WRAPPER_H__
-#define SOCKET_WRAPPER_H__
-
-#include <stdexcept>
-
-#if defined(_MSC_VER)
- #define SOCKET_IS_WINSOCK 1
-#endif
-
-#if !defined(SOCKET_IS_WINSOCK)
- #define SOCKET_IS_WINSOCK 0
-#endif
-
-#if SOCKET_IS_WINSOCK
- #include <winsock.h>
-
- typedef SOCKET socket_t;
- const socket_t invalid_socket = INVALID_SOCKET;
- #define socket_error_code WSAGetLastError()
- typedef int ssize_t;
-
- class SocketInitializer
- {
- public:
- SocketInitializer()
- {
- WSADATA wsadata;
- WSAStartup(MAKEWORD(2, 2), &wsadata);
- }
-
- ~SocketInitializer()
- {
- WSACleanup();
- }
- };
-#else
- #include <sys/types.h>
- #include <sys/socket.h>
- #include <sys/time.h>
- #include <netinet/in.h>
- #include <netdb.h>
- #include <unistd.h>
- #include <errno.h>
- #include <fcntl.h>
-
- typedef int socket_t;
- const socket_t invalid_socket = -1;
- #define socket_error_code errno
- #define closesocket close
-
- class SocketInitializer {};
-#endif
-
-#if !defined(MSG_NOSIGNAL)
- #define MSG_NOSIGNAL 0
-#endif
-
-#include <string.h>
-
-class Socket
- {
- public:
- size_t read(unsigned char[], size_t, bool dont_block = false);
- void write(const unsigned char[], size_t);
-
- std::string peer_id() const { return peer; }
-
- void close()
- {
- if(sockfd != invalid_socket)
- {
- if(::closesocket(sockfd) != 0)
- throw std::runtime_error("Socket::close failed");
- sockfd = invalid_socket;
- }
- }
-
- Socket(socket_t fd, const std::string& peer_id = "") :
- peer(peer_id), sockfd(fd)
- {
- }
-
- Socket(const std::string&, unsigned short);
- ~Socket() { close(); }
- private:
- std::string peer;
- socket_t sockfd;
- };
-
-class Server_Socket
- {
- public:
- /**
- * Accept a new connection
- */
- Socket* accept()
- {
- socket_t retval = ::accept(sockfd, 0, 0);
- if(retval == invalid_socket)
- throw std::runtime_error("Server_Socket: accept failed");
- return new Socket(retval);
- }
-
- void close()
- {
- if(sockfd != invalid_socket)
- {
- if(::closesocket(sockfd) != 0)
- throw std::runtime_error("Server_Socket::close failed");
- sockfd = invalid_socket;
- }
- }
-
- Server_Socket(unsigned short);
- ~Server_Socket() { close(); }
- private:
- socket_t sockfd;
- };
-
-/**
-* Unix Socket Constructor
-*/
-Socket::Socket(const std::string& host, unsigned short port) : peer(host)
- {
- sockfd = invalid_socket;
-
- hostent* host_addr = ::gethostbyname(host.c_str());
-
- if(host_addr == 0)
- throw std::runtime_error("Socket: gethostbyname failed for " + host);
- if(host_addr->h_addrtype != AF_INET) // FIXME
- throw std::runtime_error("Socket: " + host + " has IPv6 address");
-
- socket_t fd = ::socket(PF_INET, SOCK_STREAM, 0);
- if(fd == invalid_socket)
- throw std::runtime_error("Socket: Unable to acquire socket");
-
- sockaddr_in socket_info;
- ::memset(&socket_info, 0, sizeof(socket_info));
- socket_info.sin_family = AF_INET;
- socket_info.sin_port = htons(port);
-
- ::memcpy(&socket_info.sin_addr,
- host_addr->h_addr,
- host_addr->h_length);
-
- socket_info.sin_addr = *(struct in_addr*)host_addr->h_addr; // FIXME
-
- if(::connect(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0)
- {
- ::closesocket(fd);
- throw std::runtime_error("Socket: connect failed");
- }
-
- //fcntl(fd, F_SETFL, O_NONBLOCK);
-
- sockfd = fd;
- }
-
-/**
-* Read from a Unix socket
-*/
-size_t Socket::read(unsigned char buf[], size_t length, bool partial)
- {
- if(sockfd == invalid_socket)
- throw std::runtime_error("Socket::read: Socket not connected");
-
- size_t got = 0;
-
- int flags = MSG_NOSIGNAL;
-
- while(length)
- {
- ssize_t this_time = ::recv(sockfd, (char*)buf + got, length, flags);
-
- const bool full_ret = (this_time == (ssize_t)length);
-
- if(this_time == 0)
- break;
-
- if(this_time == -1)
- {
- if(socket_error_code == EINTR)
- this_time = 0;
- else if(socket_error_code == EAGAIN)
- break;
- else
- throw std::runtime_error("Socket::read: Socket read failed");
- }
-
- got += this_time;
- length -= this_time;
-
- if(partial && !full_ret)
- break;
- }
-
- return got;
- }
-
-/**
-* Write to a Unix socket
-*/
-void Socket::write(const unsigned char buf[], size_t length)
- {
- if(sockfd == invalid_socket)
- throw std::runtime_error("Socket::write: Socket not connected");
-
- size_t offset = 0;
- while(length)
- {
- ssize_t sent = ::send(sockfd, (const char*)buf + offset,
- length, MSG_NOSIGNAL);
-
- if(sent == -1)
- {
- if(socket_error_code == EINTR)
- sent = 0;
- else
- throw std::runtime_error("Socket::write: Socket write failed");
- }
-
- offset += sent;
- length -= sent;
- }
- }
-
-/**
-* Unix Server Socket Constructor
-*/
-Server_Socket::Server_Socket(unsigned short port)
- {
- sockfd = invalid_socket;
-
- socket_t fd = ::socket(PF_INET, SOCK_STREAM, 0);
- if(fd == invalid_socket)
- throw std::runtime_error("Server_Socket: Unable to acquire socket");
-
- sockaddr_in socket_info;
- ::memset(&socket_info, 0, sizeof(socket_info));
- socket_info.sin_family = AF_INET;
- socket_info.sin_port = htons(port);
-
- // FIXME: support limiting listeners
- socket_info.sin_addr.s_addr = INADDR_ANY;
-
- if(::bind(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0)
- {
- ::closesocket(fd);
- throw std::runtime_error("Server_Socket: bind failed");
- }
-
- if(::listen(fd, 100) != 0) // FIXME: totally arbitrary
- {
- ::closesocket(fd);
- throw std::runtime_error("Server_Socket: listen failed");
- }
-
- sockfd = fd;
- }
-
-#endif
diff --git a/doc/examples/tls_server.cpp b/doc/examples/tls_server.cpp
index 727f4c333..980230d58 100644
--- a/doc/examples/tls_server.cpp
+++ b/doc/examples/tls_server.cpp
@@ -7,7 +7,6 @@
#include <botan/x509self.h>
#include <botan/secqueue.h>
-#include "socket.h"
#include "credentials.h"
using namespace Botan;
@@ -18,209 +17,228 @@ using namespace std::placeholders;
#include <string>
#include <iostream>
#include <memory>
-
-class Blocking_TLS_Server
+#include <list>
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+
+int make_server_socket(const std::string& transport, u16bit port)
{
- public:
- Blocking_TLS_Server(std::function<void (const byte[], size_t)> output_fn,
- std::function<size_t (byte[], size_t)> input_fn,
- std::vector<std::string>& protocols,
- TLS::Session_Manager& sessions,
- Credentials_Manager& creds,
- TLS::Policy& policy,
- RandomNumberGenerator& rng) :
- input_fn(input_fn),
- server(
- output_fn,
- std::bind(&Blocking_TLS_Server::reader_fn, std::ref(*this), _1, _2, _3),
- std::bind(&Blocking_TLS_Server::handshake_complete, std::ref(*this), _1),
- sessions,
- creds,
- policy,
- rng,
- protocols),
- exit(false)
- {
- read_loop();
- }
-
- bool handshake_complete(const TLS::Session& session)
- {
- std::cout << "Handshake complete: "
- << session.version().to_string() << " "
- << session.ciphersuite().to_string() << " "
- << "SessionID: " << hex_encode(session.session_id()) << "\n";
+ int type = (transport == "tcp") ? SOCK_STREAM : SOCK_DGRAM;
- if(session.srp_identifier() != "")
- std::cout << "SRP identifier: " << session.srp_identifier() << "\n";
+ int fd = ::socket(PF_INET, type, 0);
+ if(fd == -1)
+ throw std::runtime_error("Unable to acquire socket");
- if(server.next_protocol() != "")
- std::cout << "Next protocol: " << server.next_protocol() << "\n";
+ sockaddr_in socket_info;
+ ::memset(&socket_info, 0, sizeof(socket_info));
+ socket_info.sin_family = AF_INET;
+ socket_info.sin_port = htons(port);
- /*
- std::vector<X509_Certificate> peer_certs = session.peer_certs();
- if(peer_certs.size())
- std::cout << peer_certs[0].to_string();
- */
+ // FIXME: support limiting listeners
+ socket_info.sin_addr.s_addr = INADDR_ANY;
- return true;
- }
+ if(::bind(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0)
+ {
+ ::close(fd);
+ throw std::runtime_error("server bind failed");
+ }
- size_t read(byte buf[], size_t buf_len)
+ if(transport != "udp")
+ {
+ if(::listen(fd, 100) != 0)
{
- size_t got = read_queue.read(buf, buf_len);
-
- while(!exit && !got)
- {
- read_loop(TLS::TLS_HEADER_SIZE);
- got = read_queue.read(buf, buf_len);
- }
-
- return got;
+ ::close(fd);
+ throw std::runtime_error("listen failed");
}
+ }
- void write(const byte buf[], size_t buf_len)
- {
- server.send(buf, buf_len);
- }
+ return fd;
+ }
- void close() { server.close(); }
+bool handshake_complete(const TLS::Session& session)
+ {
+ std::cout << "Handshake complete, " << session.version().to_string()
+ << " using " << session.ciphersuite().to_string() << "\n";
- bool is_active() const { return server.is_active(); }
+ if(!session.session_id().empty())
+ std::cout << "Session ID " << hex_encode(session.session_id()) << "\n";
- TLS::Server& underlying() { return server; }
- private:
- void read_loop(size_t init_desired = 0)
- {
- size_t desired = init_desired;
+ if(!session.session_ticket().empty())
+ std::cout << "Session ticket " << hex_encode(session.session_ticket()) << "\n";
- byte buf[4096];
- while(!exit && (!server.is_active() || desired))
- {
- const size_t asking = std::max(sizeof(buf), std::min(desired, static_cast<size_t>(1)));
+ std::cout << "Secure renegotiation is"
+ << (session.secure_renegotiation() ? "" : " NOT")
+ << " supported\n";
- const size_t socket_got = input_fn(&buf[0], asking);
+ return true;
+ }
- if(socket_got == 0) // eof?
- {
- close();
- printf("got eof on socket\n");
- exit = true;
- }
+void dgram_socket_write(int sockfd, const byte buf[], size_t length)
+ {
+ ssize_t sent = ::send(sockfd, buf, length, MSG_NOSIGNAL);
- desired = server.received_data(&buf[0], socket_got);
- }
- }
+ if(sent == -1)
+ printf("Error writing to socket %s\n", strerror(errno));
+ else if(sent != length)
+ printf("Note: packet of length %d truncated to %d\n", length, sent);
+ }
- void reader_fn(const byte buf[], size_t buf_len, TLS::Alert alert)
- {
- if(alert.is_valid())
- {
- printf("Alert %s\n", alert.type_string().c_str());
- //exit = true;
- }
+void stream_socket_write(int sockfd, const byte buf[], size_t length)
+ {
+ size_t offset = 0;
- printf("Got %d bytes: ", (int)buf_len);
- for(size_t i = 0; i != buf_len; ++i)
- {
- if(isprint(buf[i]))
- printf("%c", buf[i]);
- }
- printf("\n");
+ while(length)
+ {
+ ssize_t sent = ::send(sockfd, (const char*)buf + offset,
+ length, MSG_NOSIGNAL);
- read_queue.write(buf, buf_len);
+ if(sent == -1)
+ {
+ if(errno == EINTR)
+ sent = 0;
+ else
+ throw std::runtime_error("Socket::write: Socket write failed");
}
- std::function<size_t (byte[], size_t)> input_fn;
- TLS::Server server;
- SecureQueue read_queue;
- bool exit;
- };
+ offset += sent;
+ length -= sent;
+ }
+ }
int main(int argc, char* argv[])
{
int port = 4433;
+ std::string transport = "tcp";
- if(argc == 2)
+ if(argc >= 2)
port = to_u32bit(argv[1]);
+ if(argc >= 3)
+ transport = argv[2];
try
{
LibraryInitializer botan_init;
- //SocketInitializer socket_init;
AutoSeeded_RNG rng;
- Server_Socket listener(port);
-
TLS::Policy policy;
- TLS::Session_Manager_In_Memory sessions;
+ TLS::Session_Manager_In_Memory session_manager;
Credentials_Manager_Simple creds(rng);
- std::vector<std::string> protocols;
-
/*
* These are the protocols we advertise to the client, but the
* client will send back whatever it actually plans on talking,
* which may or may not take into account what we advertise.
*/
- protocols.push_back("echo/1.0");
- protocols.push_back("echo/1.1");
+ const std::vector<std::string> protocols = { "echo/1.0", "echo/1.1" };
+
+ printf("Listening for new connection on %s port %d\n", transport.c_str(), port);
+
+ int server_fd = make_server_socket(transport, port);
while(true)
{
- try {
- printf("Listening for new connection on port %d\n", port);
+ try
+ {
+ int fd;
- std::auto_ptr<Socket> sock(listener.accept());
+ if(transport == "tcp")
+ fd = ::accept(server_fd, NULL, NULL);
+ else
+ {
+ struct sockaddr_in from;
+ socklen_t from_len = sizeof(sockaddr_in);
- printf("Got new connection\n");
+ if(::recvfrom(server_fd, NULL, 0, MSG_PEEK,
+ (struct sockaddr*)&from, &from_len) != 0)
+ throw std::runtime_error("Could not peek next packet");
- Blocking_TLS_Server tls(
- std::bind(&Socket::write, std::ref(sock), _1, _2),
- std::bind(&Socket::read, std::ref(sock), _1, _2, true),
- protocols,
- sessions,
- creds,
- policy,
- rng);
+ if(::connect(server_fd, (struct sockaddr*)&from, from_len) != 0)
+ throw std::runtime_error("Could not connect UDP socket");
- const char* msg = "Welcome to the best echo server evar\n";
- tls.write((const Botan::byte*)msg, strlen(msg));
+ fd = server_fd;
+ }
- std::string line;
+ printf("New connection received\n");
- while(tls.is_active())
- {
- byte b;
- size_t got = tls.read(&b, 1);
+ auto socket_write =
+ (transport == "tcp") ?
+ std::bind(stream_socket_write, fd, _1, _2) :
+ std::bind(dgram_socket_write, fd, _1, _2);
- if(got == 0)
- break;
+ std::string s;
+ std::list<std::string> pending_output;
- line += (char)b;
- if(b == '\n')
- {
- //std::cout << line;
+ pending_output.push_back("Welcome to the best echo server evar\n");
- tls.write(reinterpret_cast<const byte*>(line.data()), line.size());
+ auto proc_fn = [&](const byte input[], size_t input_len, TLS::Alert alert)
+ {
+ if(alert.is_valid())
+ std::cout << "Alert: " << alert.type_string() << "\n";
- if(line == "quit\n")
+ for(size_t i = 0; i != input_len; ++i)
+ {
+ char c = (char)input[i];
+ s += c;
+ if(c == '\n')
{
- tls.close();
- break;
+ pending_output.push_back(s);
+ s.clear();
}
+ }
+ };
+
+ TLS::Server server(socket_write,
+ proc_fn,
+ handshake_complete,
+ session_manager,
+ creds,
+ policy,
+ rng,
+ protocols);
+
+ while(!server.is_closed())
+ {
+ byte buf[4*1024] = { 0 };
+ size_t got = ::read(fd, buf, sizeof(buf));
+
+ if(got == -1)
+ {
+ printf("Error in socket read %s\n", strerror(errno));
+ break;
+ }
+
+ if(got == 0)
+ {
+ printf("EOF on socket\n");
+ break;
+ }
- if(line == "reneg\n")
- tls.underlying().renegotiate(false);
- else if(line == "RENEG\n")
- tls.underlying().renegotiate(true);
+ server.received_data(buf, got);
- line.clear();
+ while(server.is_active() && !pending_output.empty())
+ {
+ std::string s = pending_output.front();
+ pending_output.pop_front();
+ server.send(s);
+
+ if(s == "quit\n")
+ server.close();
}
}
+
+ if(transport == "tcp")
+ ::close(fd);
+
}
catch(std::exception& e) { printf("Connection problem: %s\n", e.what()); }
}