aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2011-12-23 16:14:45 +0000
committerlloyd <[email protected]>2011-12-23 16:14:45 +0000
commitd363602f95f1514b4b595d9912fba2e503edcb21 (patch)
tree70d036ff58e67629103a4ac6c1374ec90b6bd3db
parenta3d81efbd2c56749d4abf9e6a27cb36cbbb10702 (diff)
First stab at an event driven TLS client.
-rw-r--r--doc/examples/socket.h20
-rw-r--r--doc/examples/tls_client.cpp109
-rw-r--r--src/ssl/info.txt1
-rw-r--r--src/ssl/rec_wri.cpp2
-rw-r--r--src/ssl/tls_client.cpp255
-rw-r--r--src/ssl/tls_client.h46
-rw-r--r--src/ssl/tls_connection.h36
-rw-r--r--src/ssl/tls_record.h2
-rw-r--r--src/ssl/tls_server.h5
9 files changed, 211 insertions, 265 deletions
diff --git a/doc/examples/socket.h b/doc/examples/socket.h
index f7ce98fea..f10ff9f26 100644
--- a/doc/examples/socket.h
+++ b/doc/examples/socket.h
@@ -48,6 +48,7 @@
#include <netdb.h>
#include <unistd.h>
#include <errno.h>
+ #include <fcntl.h>
typedef int socket_t;
const socket_t invalid_socket = -1;
@@ -66,7 +67,7 @@
class Socket
{
public:
- size_t read(unsigned char[], size_t);
+ size_t read(unsigned char[], size_t, bool dont_block = false);
void write(const unsigned char[], size_t);
std::string peer_id() const { return peer; }
@@ -158,23 +159,28 @@ Socket::Socket(const std::string& host, unsigned short port) : peer(host)
throw std::runtime_error("Socket: connect failed");
}
+ //fcntl(fd, F_SETFL, O_NONBLOCK);
+
sockfd = fd;
}
/**
* Read from a Unix socket
*/
-size_t Socket::read(unsigned char buf[], size_t length)
+size_t Socket::read(unsigned char buf[], size_t length, bool partial)
{
if(sockfd == invalid_socket)
throw std::runtime_error("Socket::read: Socket not connected");
size_t got = 0;
+ int flags = MSG_NOSIGNAL;
+
while(length)
{
- ssize_t this_time = ::recv(sockfd, (char*)buf + got,
- length, MSG_NOSIGNAL);
+ ssize_t this_time = ::recv(sockfd, (char*)buf + got, length, flags);
+
+ const bool full_ret = (this_time == length);
if(this_time == 0)
break;
@@ -183,13 +189,19 @@ size_t Socket::read(unsigned char buf[], size_t length)
{
if(socket_error_code == EINTR)
this_time = 0;
+ else if(socket_error_code == EAGAIN)
+ break;
else
throw std::runtime_error("Socket::read: Socket read failed");
}
got += this_time;
length -= this_time;
+
+ if(partial && !full_ret)
+ break;
}
+
return got;
}
diff --git a/doc/examples/tls_client.cpp b/doc/examples/tls_client.cpp
index cedfe1ca8..a51febfcf 100644
--- a/doc/examples/tls_client.cpp
+++ b/doc/examples/tls_client.cpp
@@ -25,6 +25,76 @@ class Client_TLS_Policy : public TLS_Policy
}
};
+class HTTPS_Client
+ {
+ public:
+ HTTPS_Client(const std::string& host, u16bit port, RandomNumberGenerator& r) :
+ rng(r),
+ socket(host, port),
+ client(std::tr1::bind(&HTTPS_Client::socket_write, std::tr1::ref(*this), _1, _2),
+ std::tr1::bind(&HTTPS_Client::proc_data, std::tr1::ref(*this), _1, _2, _3),
+ policy,
+ rng)
+ {
+ SecureVector<byte> socket_buf(1024);
+ size_t desired = 0;
+
+ quit_reading = false;
+
+ while(!client.handshake_complete() || desired)
+ {
+ const size_t socket_got = socket.read(&socket_buf[0], socket_buf.size());
+ //printf("Got %d bytes from socket\n", socket_got);
+ desired = client.received_data(&socket_buf[0], socket_got);
+ socket_buf.resize(desired || 1);
+ //printf("Going around for another read?\n");
+ }
+ }
+
+ void socket_write(const byte buf[], size_t buf_size)
+ {
+ std::cout << "socket_write " << buf_size << "\n";
+ socket.write(buf, buf_size);
+ }
+
+ void proc_data(const byte data[], size_t data_len, u16bit alert_info)
+ {
+ printf("Block of data %d bytes alert %04X\n", (int)data_len, alert_info);
+ for(size_t i = 0; i != data_len; ++i)
+ printf("%c", data[i]);
+
+ if(data_len == 0 && alert_info == 0)
+ quit_reading = true;
+ }
+
+ void write(const std::string& s)
+ {
+ client.queue_for_sending((const byte*)s.c_str(), s.length());
+ }
+
+ void read_response()
+ {
+ while(true)
+ {
+ SecureVector<byte> buf(4096);
+
+ size_t got = socket.read(&buf[0], buf.size(), true);
+
+ if(got == 0)
+ break;
+
+ client.received_data(&buf[0], got);
+ }
+ }
+
+ private:
+ bool quit_reading;
+ RandomNumberGenerator& rng;
+ Socket socket;
+ Client_TLS_Policy policy;
+ TLS_Client client;
+ };
+
int main(int argc, char* argv[])
{
if(argc != 2 && argc != 3)
@@ -40,48 +110,21 @@ int main(int argc, char* argv[])
std::string host = argv[1];
u32bit port = argc == 3 ? Botan::to_u32bit(argv[2]) : 443;
- printf("Connecting to %s:%d...\n", host.c_str(), port);
-
- SocketInitializer socket_init;
-
- Socket sock(argv[1], port);
+ //SocketInitializer socket_init;
AutoSeeded_RNG rng;
- Client_TLS_Policy policy;
-
- TLS_Client tls(std::tr1::bind(&Socket::read, std::tr1::ref(sock), _1, _2),
- std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2),
- policy, rng);
+ printf("Connecting to %s:%d...\n", host.c_str(), port);
- printf("Handshake extablished...\n");
+ HTTPS_Client https(host, port, rng);
-#if 0
- std::string http_command = "GET / HTTP/1.1\r\n"
- "Server: " + host + ':' + to_string(port) + "\r\n\r\n";
-#else
std::string http_command = "GET / HTTP/1.0\r\n\r\n";
-#endif
- tls.write((const Botan::byte*)http_command.c_str(),
- http_command.length());
+ printf("Sending request\n");
+ https.write(http_command);
- size_t total_got = 0;
-
- while(true)
- {
- if(tls.is_closed())
- break;
-
- Botan::byte buf[128+1] = { 0 };
- size_t got = tls.read(buf, sizeof(buf)-1);
- printf("%s", buf);
- fflush(0);
-
- total_got += got;
- }
+ https.read_response();
- printf("\nRetrieved %d bytes total\n", total_got);
}
catch(std::exception& e)
{
diff --git a/src/ssl/info.txt b/src/ssl/info.txt
index f920a733d..1170fef45 100644
--- a/src/ssl/info.txt
+++ b/src/ssl/info.txt
@@ -9,7 +9,6 @@ uses_tr1 yes
<header:public>
tls_client.h
-tls_connection.h
tls_exceptn.h
tls_magic.h
tls_policy.h
diff --git a/src/ssl/rec_wri.cpp b/src/ssl/rec_wri.cpp
index 59dead3cd..d3a5c13f7 100644
--- a/src/ssl/rec_wri.cpp
+++ b/src/ssl/rec_wri.cpp
@@ -223,7 +223,7 @@ void Record_Writer::send_record(byte type, const byte buf[], size_t length)
if(block_size)
{
- size_t pad_val =
+ const size_t pad_val =
(block_size - (1 + length + buf_mac.size())) % block_size;
for(size_t i = 0; i != pad_val + 1; ++i)
diff --git a/src/ssl/tls_client.cpp b/src/ssl/tls_client.cpp
index a136752fd..cfa86881c 100644
--- a/src/ssl/tls_client.cpp
+++ b/src/ssl/tls_client.cpp
@@ -1,6 +1,6 @@
/*
* TLS Client
-* (C) 2004-2010 Jack Lloyd
+* (C) 2004-2011 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -81,16 +81,21 @@ void client_check_state(Handshake_Type new_msg, Handshake_State* state)
/**
* TLS Client Constructor
*/
-TLS_Client::TLS_Client(std::tr1::function<size_t (byte[], size_t)> input_fn,
- std::tr1::function<void (const byte[], size_t)> output_fn,
+TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> process_fn,
const TLS_Policy& policy,
RandomNumberGenerator& rng) :
- input_fn(input_fn),
policy(policy),
rng(rng),
- writer(output_fn)
+ proc_fn(process_fn),
+ writer(socket_output_fn),
+ state(0),
+ active(false)
{
- initialize();
+ writer.set_version(policy.pref_version());
+
+ state = new Handshake_State;
+ state->client_hello = new Client_Hello(rng, writer, policy, state->hash);
}
void TLS_Client::add_client_cert(const X509_Certificate& cert,
@@ -110,92 +115,87 @@ TLS_Client::~TLS_Client()
delete state;
}
-/**
-* Initialize a TLS client connection
-*/
-void TLS_Client::initialize()
+size_t TLS_Client::received_data(const byte buf[], size_t buf_size)
{
- std::string error_str;
- Alert_Type error_type = NO_ALERT_TYPE;
-
- try {
- state = 0;
- active = false;
- writer.set_version(policy.pref_version());
- do_handshake();
- }
- catch(TLS_Exception& e)
+ try
{
- error_str = e.what();
- error_type = e.type();
- }
- catch(std::exception& e)
- {
- error_str = e.what();
- error_type = HANDSHAKE_FAILURE;
- }
+ reader.add_input(buf, buf_size);
- if(error_type != NO_ALERT_TYPE)
- {
- if(active)
+ byte rec_type = CONNECTION_CLOSED;
+ SecureVector<byte> record;
+
+ while(!reader.currently_empty())
{
- active = false;
- reader.reset();
+ const size_t bytes_needed = reader.get_record(rec_type, record);
- writer.alert(FATAL, error_type);
- writer.reset();
- }
+ if(bytes_needed > 0)
+ return bytes_needed;
- if(state)
- {
- delete state;
- state = 0;
- }
+ if(rec_type == APPLICATION_DATA)
+ {
+ if(active)
+ {
+ proc_fn(&record[0], record.size(), NO_ALERT_TYPE);
+ }
+ else
+ {
+ throw Unexpected_Message("Application data before handshake done");
+ }
+ }
+ else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
+ {
+ read_handshake(rec_type, record);
+ }
+ else if(rec_type == ALERT)
+ {
+ Alert alert(record);
- throw Stream_IO_Error("TLS_Client: Handshake failed: " + error_str);
- }
- }
+ proc_fn(0, 0, alert.type());
-/**
-* Return the peer's certificate chain
-*/
-std::vector<X509_Certificate> TLS_Client::peer_cert_chain() const
- {
- return peer_certs;
- }
+ if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY)
+ {
+ if(alert.type() == CLOSE_NOTIFY)
+ {
+ writer.alert(WARNING, CLOSE_NOTIFY);
+ }
-/**
-* Write to a TLS connection
-*/
-void TLS_Client::write(const byte buf[], size_t length)
- {
- if(!active)
- throw TLS_Exception(INTERNAL_ERROR,
- "TLS_Client::write called while closed");
+ close(FATAL, NO_ALERT_TYPE);
+ }
+ }
+ else
+ throw Unexpected_Message("Unknown message type received");
+ }
- writer.send(APPLICATION_DATA, buf, length);
+ return 0; // on a record boundary
+ }
+ catch(TLS_Exception& e)
+ {
+ close(FATAL, e.type());
+ throw;
+ }
+ catch(std::exception& e)
+ {
+ close(FATAL, INTERNAL_ERROR);
+ throw;
+ }
}
-/**
-* Read from a TLS connection
-*/
-size_t TLS_Client::read(byte out[], size_t length)
+void TLS_Client::queue_for_sending(const byte buf[], size_t buf_size)
{
- if(!active)
- return 0;
-
- writer.flush();
-
- while(read_buf.size() == 0)
+ if(active)
{
- state_machine();
- if(active == false)
- break;
- }
+ while(!pre_handshake_write_queue.end_of_data())
+ {
+ SecureVector<byte> q_buf(1024);
+ const size_t got = pre_handshake_write_queue.read(&q_buf[0], q_buf.size());
+ writer.send(APPLICATION_DATA, &q_buf[0], got);
+ }
- size_t got = std::min<size_t>(read_buf.size(), length);
- read_buf.read(out, got);
- return got;
+ writer.send(APPLICATION_DATA, buf, buf_size);
+ writer.flush();
+ }
+ else
+ pre_handshake_write_queue.write(buf, buf_size);
}
/**
@@ -207,94 +207,27 @@ void TLS_Client::close()
}
/**
-* Check connection status
-*/
-bool TLS_Client::is_closed() const
- {
- if(!active)
- return true;
- return false;
- }
-
-/**
* Close a TLS connection
*/
void TLS_Client::close(Alert_Level level, Alert_Type alert_code)
{
if(active)
{
- try {
- writer.alert(level, alert_code);
- writer.flush();
- }
- catch(...) {}
-
active = false;
- }
- }
-/**
-* Iterate the TLS state machine
-*/
-void TLS_Client::state_machine()
- {
- byte rec_type = CONNECTION_CLOSED;
- SecureVector<byte> record(1024);
-
- size_t bytes_needed = reader.get_record(rec_type, record);
-
- while(bytes_needed)
- {
- size_t to_get = std::min<size_t>(record.size(), bytes_needed);
- size_t got = input_fn(&record[0], to_get);
-
- if(got == 0)
+ if(alert_code != NO_ALERT_TYPE)
{
- rec_type = CONNECTION_CLOSED;
- break;
+ try
+ {
+ writer.alert(level, alert_code);
+ writer.flush();
+ }
+ catch(...) { /* swallow it */ }
}
- reader.add_input(&record[0], got);
-
- bytes_needed = reader.get_record(rec_type, record);
- }
-
- if(rec_type == CONNECTION_CLOSED)
- {
- active = false;
reader.reset();
writer.reset();
}
- else if(rec_type == APPLICATION_DATA)
- {
- if(active)
- read_buf.write(&record[0], record.size());
- else
- throw Unexpected_Message("Application data before handshake done");
- }
- else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
- read_handshake(rec_type, record);
- else if(rec_type == ALERT)
- {
- Alert alert(record);
-
- if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY)
- {
- if(alert.type() == CLOSE_NOTIFY)
- writer.alert(WARNING, CLOSE_NOTIFY);
-
- reader.reset();
- writer.reset();
- active = false;
- if(state)
- {
- delete state;
- state = 0;
- }
- }
- }
- else
- throw Unexpected_Message("Unknown message type received");
}
/**
@@ -563,24 +496,4 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
throw Unexpected_Message("Unknown handshake message received");
}
-/**
-* Perform a client-side TLS handshake
-*/
-void TLS_Client::do_handshake()
- {
- state = new Handshake_State;
-
- state->client_hello = new Client_Hello(rng, writer, policy, state->hash);
-
- while(true)
- {
- if(active && !state)
- break;
- if(!active && !state)
- throw TLS_Exception(HANDSHAKE_FAILURE, "TLS_Client: Handshake failed (do_handshake)");
-
- state_machine();
- }
- }
-
}
diff --git a/src/ssl/tls_client.h b/src/ssl/tls_client.h
index 7d2ce9cda..6d613be33 100644
--- a/src/ssl/tls_client.h
+++ b/src/ssl/tls_client.h
@@ -1,6 +1,6 @@
/*
* TLS Client
-* (C) 2004-2010 Jack Lloyd
+* (C) 2004-2011 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -8,7 +8,6 @@
#ifndef BOTAN_TLS_CLIENT_H__
#define BOTAN_TLS_CLIENT_H__
-#include <botan/tls_connection.h>
#include <botan/tls_policy.h>
#include <botan/tls_record.h>
#include <vector>
@@ -19,25 +18,39 @@ namespace Botan {
/**
* SSL/TLS Client
*/
-class BOTAN_DLL TLS_Client : public TLS_Connection
+class BOTAN_DLL TLS_Client
{
public:
- size_t read(byte buf[], size_t buf_len);
- void write(const byte buf[], size_t buf_len);
+ /**
+ * Set up a new TLS client session
+ */
+ TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
+ const TLS_Policy& policy,
+ RandomNumberGenerator& rng);
+
+ /**
+ * Inject TLS traffic received from counterparty
+
+ * @return a hint as the how many more bytes we need to process the
+ current record (this may be 0 if on a record boundary)
+ */
+ size_t received_data(const byte buf[], size_t buf_size);
+
+ /**
+ * Inject plaintext intended for counterparty
+ */
+ void queue_for_sending(const byte buf[], size_t buf_size);
void close();
- bool is_closed() const;
- std::vector<X509_Certificate> peer_cert_chain() const;
+ bool handshake_complete() const { return active; }
+
+ std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; }
void add_client_cert(const X509_Certificate& cert,
Private_Key* cert_key);
- TLS_Client(std::tr1::function<size_t (byte[], size_t)> input_fn,
- std::tr1::function<void (const byte[], size_t)> output_fn,
- const TLS_Policy& policy,
- RandomNumberGenerator& rng);
-
~TLS_Client();
private:
void close(Alert_Level, Alert_Type);
@@ -51,20 +64,21 @@ class BOTAN_DLL TLS_Client : public TLS_Connection
void read_handshake(byte, const MemoryRegion<byte>&);
void process_handshake_msg(Handshake_Type, const MemoryRegion<byte>&);
- std::tr1::function<size_t (byte[], size_t)> input_fn;
-
const TLS_Policy& policy;
RandomNumberGenerator& rng;
+ std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn;
+
Record_Writer writer;
Record_Reader reader;
+ SecureQueue pre_handshake_write_queue;
+
std::vector<X509_Certificate> peer_certs;
std::vector<std::pair<X509_Certificate, Private_Key*> > certs;
class Handshake_State* state;
- SecureVector<byte> session_id;
- SecureQueue read_buf;
+ //SecureVector<byte> session_id;
bool active;
};
diff --git a/src/ssl/tls_connection.h b/src/ssl/tls_connection.h
deleted file mode 100644
index bbefa2114..000000000
--- a/src/ssl/tls_connection.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
-* TLS Connection
-* (C) 2004-2006 Jack Lloyd
-*
-* Released under the terms of the Botan license
-*/
-
-#ifndef BOTAN_TLS_CONNECTION_H__
-#define BOTAN_TLS_CONNECTION_H__
-
-#include <botan/x509cert.h>
-#include <vector>
-
-namespace Botan {
-
-/**
-* TLS Connection
-*/
-class BOTAN_DLL TLS_Connection
- {
- public:
- virtual size_t read(byte[], size_t) = 0;
- virtual void write(const byte[], size_t) = 0;
- size_t read(byte& in) { return read(&in, 1); }
- void write(byte out) { write(&out, 1); }
-
- virtual std::vector<X509_Certificate> peer_cert_chain() const = 0;
-
- virtual void close() = 0;
-
- virtual ~TLS_Connection() {}
- };
-
-}
-
-#endif
diff --git a/src/ssl/tls_record.h b/src/ssl/tls_record.h
index 09fd921c6..6d5dd057d 100644
--- a/src/ssl/tls_record.h
+++ b/src/ssl/tls_record.h
@@ -99,6 +99,8 @@ class BOTAN_DLL Record_Reader
void reset();
+ bool currently_empty() const { return input_queue.size() == 0; }
+
Record_Reader() { mac = 0; reset(); }
~Record_Reader() { delete mac; }
diff --git a/src/ssl/tls_server.h b/src/ssl/tls_server.h
index a6b0f9cb4..510ad15a7 100644
--- a/src/ssl/tls_server.h
+++ b/src/ssl/tls_server.h
@@ -1,6 +1,6 @@
/*
* TLS Server
-* (C) 2004-2010 Jack Lloyd
+* (C) 2004-2011 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -8,7 +8,6 @@
#ifndef BOTAN_TLS_SERVER_H__
#define BOTAN_TLS_SERVER_H__
-#include <botan/tls_connection.h>
#include <botan/tls_record.h>
#include <botan/tls_policy.h>
#include <vector>
@@ -18,7 +17,7 @@ namespace Botan {
/**
* TLS Server
*/
-class BOTAN_DLL TLS_Server : public TLS_Connection
+class BOTAN_DLL TLS_Server
{
public:
size_t read(byte buf[], size_t buf_len);