aboutsummaryrefslogtreecommitdiffstats
path: root/src/apps/tls_client.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/apps/tls_client.cpp')
-rw-r--r--src/apps/tls_client.cpp263
1 files changed, 263 insertions, 0 deletions
diff --git a/src/apps/tls_client.cpp b/src/apps/tls_client.cpp
new file mode 100644
index 000000000..24c8197f6
--- /dev/null
+++ b/src/apps/tls_client.cpp
@@ -0,0 +1,263 @@
+#include "apps.h"
+#include <botan/tls_client.h>
+#include <botan/pkcs8.h>
+#include <botan/hex.h>
+#include <stdio.h>
+#include <string>
+#include <iostream>
+#include <memory>
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+
+#if !defined(MSG_NOSIGNAL)
+ #define MSG_NOSIGNAL 0
+#endif
+
+#if defined(BOTAN_HAS_TLS_SQLITE3_SESSION_MANAGER)
+ #include <botan/tls_session_manager_sqlite.h>
+#endif
+
+#include "credentials.h"
+
+using namespace Botan;
+
+using namespace std::placeholders;
+
+namespace {
+
+int connect_to_host(const std::string& host, u16bit port, const std::string& transport)
+ {
+ hostent* host_addr = ::gethostbyname(host.c_str());
+
+ if(host_addr == 0)
+ throw std::runtime_error("gethostbyname failed for " + host);
+
+ if(host_addr->h_addrtype != AF_INET) // FIXME
+ throw std::runtime_error(host + " has IPv6 address");
+
+ int type = (transport == "tcp") ? SOCK_STREAM : SOCK_DGRAM;
+
+ int fd = ::socket(PF_INET, type, 0);
+ if(fd == -1)
+ throw std::runtime_error("Unable to acquire socket");
+
+ sockaddr_in socket_info;
+ ::memset(&socket_info, 0, sizeof(socket_info));
+ socket_info.sin_family = AF_INET;
+ socket_info.sin_port = htons(port);
+
+ ::memcpy(&socket_info.sin_addr,
+ host_addr->h_addr,
+ host_addr->h_length);
+
+ socket_info.sin_addr = *(struct in_addr*)host_addr->h_addr; // FIXME
+
+ if(::connect(fd, (sockaddr*)&socket_info, sizeof(struct sockaddr)) != 0)
+ {
+ ::close(fd);
+ throw std::runtime_error("connect failed");
+ }
+
+ return fd;
+ }
+
+bool handshake_complete(const TLS::Session& session)
+ {
+ std::cout << "Handshake complete, " << session.version().to_string()
+ << " using " << session.ciphersuite().to_string() << "\n";
+
+ if(!session.session_id().empty())
+ std::cout << "Session ID " << hex_encode(session.session_id()) << "\n";
+
+ if(!session.session_ticket().empty())
+ std::cout << "Session ticket " << hex_encode(session.session_ticket()) << "\n";
+
+ return true;
+ }
+
+void dgram_socket_write(int sockfd, const byte buf[], size_t length)
+ {
+ send(sockfd, buf, length, MSG_NOSIGNAL);
+ }
+
+void stream_socket_write(int sockfd, const byte buf[], size_t length)
+ {
+ size_t offset = 0;
+
+ while(length)
+ {
+ ssize_t sent = ::send(sockfd, (const char*)buf + offset,
+ length, MSG_NOSIGNAL);
+
+ if(sent == -1)
+ {
+ if(errno == EINTR)
+ sent = 0;
+ else
+ throw std::runtime_error("Socket::write: Socket write failed");
+ }
+
+ offset += sent;
+ length -= sent;
+ }
+ }
+
+bool got_alert = false;
+
+void alert_received(TLS::Alert alert, const byte [], size_t )
+ {
+ std::cout << "Alert: " << alert.type_string() << "\n";
+ got_alert = true;
+ }
+
+void process_data(const byte buf[], size_t buf_size)
+ {
+ for(size_t i = 0; i != buf_size; ++i)
+ std::cout << buf[i];
+ }
+
+std::string protocol_chooser(const std::vector<std::string>& protocols)
+ {
+ for(size_t i = 0; i != protocols.size(); ++i)
+ std::cout << "Protocol " << i << " = " << protocols[i] << "\n";
+ return "http/1.1";
+ }
+
+}
+
+int tls_client(int argc, char* argv[])
+ {
+ if(argc != 2 && argc != 3 && argc != 4)
+ {
+ std::cout << "Usage " << argv[0] << " host [port] [udp|tcp]\n";
+ return 1;
+ }
+
+ try
+ {
+ AutoSeeded_RNG rng;
+ TLS::Policy policy;
+
+#if defined(BOTAN_HAS_TLS_SQLITE3_SESSION_MANAGER)
+ TLS::Session_Manager_SQLite session_manager("my secret passphrase",
+ rng,
+ "sessions.db");
+#else
+ TLS::Session_Manager_In_Memory session_manager(rng);
+#endif
+
+ Credentials_Manager_Simple creds(rng);
+
+ std::string host = argv[1];
+ u32bit port = argc >= 3 ? Botan::to_u32bit(argv[2]) : 443;
+ std::string transport = argc >= 4 ? argv[3] : "tcp";
+
+ int sockfd = connect_to_host(host, port, transport);
+
+ auto socket_write =
+ (transport == "tcp") ?
+ std::bind(stream_socket_write, sockfd, _1, _2) :
+ std::bind(dgram_socket_write, sockfd, _1, _2);
+
+ auto version =
+ (transport == "tcp") ?
+ TLS::Protocol_Version::latest_tls_version() :
+ TLS::Protocol_Version::latest_dtls_version();
+
+ TLS::Client client(socket_write,
+ process_data,
+ alert_received,
+ handshake_complete,
+ session_manager,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information(host, port),
+ version,
+ protocol_chooser);
+
+ while(!client.is_closed())
+ {
+ fd_set readfds;
+ FD_ZERO(&readfds);
+ FD_SET(sockfd, &readfds);
+ FD_SET(STDIN_FILENO, &readfds);
+
+ ::select(sockfd + 1, &readfds, NULL, NULL, NULL);
+
+ if(FD_ISSET(sockfd, &readfds))
+ {
+ byte buf[4*1024] = { 0 };
+
+ ssize_t got = ::read(sockfd, buf, sizeof(buf));
+
+ if(got == 0)
+ {
+ std::cout << "EOF on socket\n";
+ break;
+ }
+ else if(got == -1)
+ {
+ std::cout << "Socket error: " << errno << " " << strerror(errno) << "\n";
+ continue;
+ }
+
+ //std::cout << "Socket - got " << got << " bytes\n";
+ client.received_data(buf, got);
+ }
+ else if(FD_ISSET(STDIN_FILENO, &readfds))
+ {
+ byte buf[1024] = { 0 };
+ ssize_t got = read(STDIN_FILENO, buf, sizeof(buf));
+
+ if(got == 0)
+ {
+ std::cout << "EOF on stdin\n";
+ client.close();
+ break;
+ }
+ else if(got == -1)
+ {
+ std::cout << "Stdin error: " << errno << " " << strerror(errno) << "\n";
+ continue;
+ }
+
+ if(got == 2 && buf[1] == '\n')
+ {
+ char cmd = buf[0];
+
+ if(cmd == 'R' || cmd == 'r')
+ {
+ std::cout << "Client initiated renegotiation\n";
+ client.renegotiate(cmd == 'R');
+ }
+ else if(cmd == 'Q')
+ {
+ std::cout << "Client initiated close\n";
+ client.close();
+ }
+ }
+ else if(buf[0] == 'H')
+ client.heartbeat(&buf[1], got-1);
+ else
+ client.send(buf, got);
+ }
+ }
+
+ ::close(sockfd);
+
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Exception: " << e.what() << "\n";
+ return 1;
+ }
+ return 0;
+ }