diff options
author | lloyd <[email protected]> | 2011-12-23 16:14:45 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2011-12-23 16:14:45 +0000 |
commit | d363602f95f1514b4b595d9912fba2e503edcb21 (patch) | |
tree | 70d036ff58e67629103a4ac6c1374ec90b6bd3db | |
parent | a3d81efbd2c56749d4abf9e6a27cb36cbbb10702 (diff) |
First stab at an event driven TLS client.
-rw-r--r-- | doc/examples/socket.h | 20 | ||||
-rw-r--r-- | doc/examples/tls_client.cpp | 109 | ||||
-rw-r--r-- | src/ssl/info.txt | 1 | ||||
-rw-r--r-- | src/ssl/rec_wri.cpp | 2 | ||||
-rw-r--r-- | src/ssl/tls_client.cpp | 255 | ||||
-rw-r--r-- | src/ssl/tls_client.h | 46 | ||||
-rw-r--r-- | src/ssl/tls_connection.h | 36 | ||||
-rw-r--r-- | src/ssl/tls_record.h | 2 | ||||
-rw-r--r-- | src/ssl/tls_server.h | 5 |
9 files changed, 211 insertions, 265 deletions
diff --git a/doc/examples/socket.h b/doc/examples/socket.h index f7ce98fea..f10ff9f26 100644 --- a/doc/examples/socket.h +++ b/doc/examples/socket.h @@ -48,6 +48,7 @@ #include <netdb.h> #include <unistd.h> #include <errno.h> + #include <fcntl.h> typedef int socket_t; const socket_t invalid_socket = -1; @@ -66,7 +67,7 @@ class Socket { public: - size_t read(unsigned char[], size_t); + size_t read(unsigned char[], size_t, bool dont_block = false); void write(const unsigned char[], size_t); std::string peer_id() const { return peer; } @@ -158,23 +159,28 @@ Socket::Socket(const std::string& host, unsigned short port) : peer(host) throw std::runtime_error("Socket: connect failed"); } + //fcntl(fd, F_SETFL, O_NONBLOCK); + sockfd = fd; } /** * Read from a Unix socket */ -size_t Socket::read(unsigned char buf[], size_t length) +size_t Socket::read(unsigned char buf[], size_t length, bool partial) { if(sockfd == invalid_socket) throw std::runtime_error("Socket::read: Socket not connected"); size_t got = 0; + int flags = MSG_NOSIGNAL; + while(length) { - ssize_t this_time = ::recv(sockfd, (char*)buf + got, - length, MSG_NOSIGNAL); + ssize_t this_time = ::recv(sockfd, (char*)buf + got, length, flags); + + const bool full_ret = (this_time == length); if(this_time == 0) break; @@ -183,13 +189,19 @@ size_t Socket::read(unsigned char buf[], size_t length) { if(socket_error_code == EINTR) this_time = 0; + else if(socket_error_code == EAGAIN) + break; else throw std::runtime_error("Socket::read: Socket read failed"); } got += this_time; length -= this_time; + + if(partial && !full_ret) + break; } + return got; } diff --git a/doc/examples/tls_client.cpp b/doc/examples/tls_client.cpp index cedfe1ca8..a51febfcf 100644 --- a/doc/examples/tls_client.cpp +++ b/doc/examples/tls_client.cpp @@ -25,6 +25,76 @@ class Client_TLS_Policy : public TLS_Policy } }; +class HTTPS_Client + { + public: + HTTPS_Client(const std::string& host, u16bit port, RandomNumberGenerator& r) : + rng(r), + socket(host, port), + client(std::tr1::bind(&HTTPS_Client::socket_write, std::tr1::ref(*this), _1, _2), + std::tr1::bind(&HTTPS_Client::proc_data, std::tr1::ref(*this), _1, _2, _3), + policy, + rng) + { + SecureVector<byte> socket_buf(1024); + size_t desired = 0; + + quit_reading = false; + + while(!client.handshake_complete() || desired) + { + const size_t socket_got = socket.read(&socket_buf[0], socket_buf.size()); + //printf("Got %d bytes from socket\n", socket_got); + desired = client.received_data(&socket_buf[0], socket_got); + socket_buf.resize(desired || 1); + //printf("Going around for another read?\n"); + } + } + + void socket_write(const byte buf[], size_t buf_size) + { + std::cout << "socket_write " << buf_size << "\n"; + socket.write(buf, buf_size); + } + + void proc_data(const byte data[], size_t data_len, u16bit alert_info) + { + printf("Block of data %d bytes alert %04X\n", (int)data_len, alert_info); + for(size_t i = 0; i != data_len; ++i) + printf("%c", data[i]); + + if(data_len == 0 && alert_info == 0) + quit_reading = true; + } + + void write(const std::string& s) + { + client.queue_for_sending((const byte*)s.c_str(), s.length()); + } + + void read_response() + { + while(true) + { + SecureVector<byte> buf(4096); + + size_t got = socket.read(&buf[0], buf.size(), true); + + if(got == 0) + break; + + client.received_data(&buf[0], got); + } + } + + private: + bool quit_reading; + RandomNumberGenerator& rng; + Socket socket; + Client_TLS_Policy policy; + TLS_Client client; + }; + int main(int argc, char* argv[]) { if(argc != 2 && argc != 3) @@ -40,48 +110,21 @@ int main(int argc, char* argv[]) std::string host = argv[1]; u32bit port = argc == 3 ? Botan::to_u32bit(argv[2]) : 443; - printf("Connecting to %s:%d...\n", host.c_str(), port); - - SocketInitializer socket_init; - - Socket sock(argv[1], port); + //SocketInitializer socket_init; AutoSeeded_RNG rng; - Client_TLS_Policy policy; - - TLS_Client tls(std::tr1::bind(&Socket::read, std::tr1::ref(sock), _1, _2), - std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2), - policy, rng); + printf("Connecting to %s:%d...\n", host.c_str(), port); - printf("Handshake extablished...\n"); + HTTPS_Client https(host, port, rng); -#if 0 - std::string http_command = "GET / HTTP/1.1\r\n" - "Server: " + host + ':' + to_string(port) + "\r\n\r\n"; -#else std::string http_command = "GET / HTTP/1.0\r\n\r\n"; -#endif - tls.write((const Botan::byte*)http_command.c_str(), - http_command.length()); + printf("Sending request\n"); + https.write(http_command); - size_t total_got = 0; - - while(true) - { - if(tls.is_closed()) - break; - - Botan::byte buf[128+1] = { 0 }; - size_t got = tls.read(buf, sizeof(buf)-1); - printf("%s", buf); - fflush(0); - - total_got += got; - } + https.read_response(); - printf("\nRetrieved %d bytes total\n", total_got); } catch(std::exception& e) { diff --git a/src/ssl/info.txt b/src/ssl/info.txt index f920a733d..1170fef45 100644 --- a/src/ssl/info.txt +++ b/src/ssl/info.txt @@ -9,7 +9,6 @@ uses_tr1 yes <header:public> tls_client.h -tls_connection.h tls_exceptn.h tls_magic.h tls_policy.h diff --git a/src/ssl/rec_wri.cpp b/src/ssl/rec_wri.cpp index 59dead3cd..d3a5c13f7 100644 --- a/src/ssl/rec_wri.cpp +++ b/src/ssl/rec_wri.cpp @@ -223,7 +223,7 @@ void Record_Writer::send_record(byte type, const byte buf[], size_t length) if(block_size) { - size_t pad_val = + const size_t pad_val = (block_size - (1 + length + buf_mac.size())) % block_size; for(size_t i = 0; i != pad_val + 1; ++i) diff --git a/src/ssl/tls_client.cpp b/src/ssl/tls_client.cpp index a136752fd..cfa86881c 100644 --- a/src/ssl/tls_client.cpp +++ b/src/ssl/tls_client.cpp @@ -1,6 +1,6 @@ /* * TLS Client -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2011 Jack Lloyd * * Released under the terms of the Botan license */ @@ -81,16 +81,21 @@ void client_check_state(Handshake_Type new_msg, Handshake_State* state) /** * TLS Client Constructor */ -TLS_Client::TLS_Client(std::tr1::function<size_t (byte[], size_t)> input_fn, - std::tr1::function<void (const byte[], size_t)> output_fn, +TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, u16bit)> process_fn, const TLS_Policy& policy, RandomNumberGenerator& rng) : - input_fn(input_fn), policy(policy), rng(rng), - writer(output_fn) + proc_fn(process_fn), + writer(socket_output_fn), + state(0), + active(false) { - initialize(); + writer.set_version(policy.pref_version()); + + state = new Handshake_State; + state->client_hello = new Client_Hello(rng, writer, policy, state->hash); } void TLS_Client::add_client_cert(const X509_Certificate& cert, @@ -110,92 +115,87 @@ TLS_Client::~TLS_Client() delete state; } -/** -* Initialize a TLS client connection -*/ -void TLS_Client::initialize() +size_t TLS_Client::received_data(const byte buf[], size_t buf_size) { - std::string error_str; - Alert_Type error_type = NO_ALERT_TYPE; - - try { - state = 0; - active = false; - writer.set_version(policy.pref_version()); - do_handshake(); - } - catch(TLS_Exception& e) + try { - error_str = e.what(); - error_type = e.type(); - } - catch(std::exception& e) - { - error_str = e.what(); - error_type = HANDSHAKE_FAILURE; - } + reader.add_input(buf, buf_size); - if(error_type != NO_ALERT_TYPE) - { - if(active) + byte rec_type = CONNECTION_CLOSED; + SecureVector<byte> record; + + while(!reader.currently_empty()) { - active = false; - reader.reset(); + const size_t bytes_needed = reader.get_record(rec_type, record); - writer.alert(FATAL, error_type); - writer.reset(); - } + if(bytes_needed > 0) + return bytes_needed; - if(state) - { - delete state; - state = 0; - } + if(rec_type == APPLICATION_DATA) + { + if(active) + { + proc_fn(&record[0], record.size(), NO_ALERT_TYPE); + } + else + { + throw Unexpected_Message("Application data before handshake done"); + } + } + else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) + { + read_handshake(rec_type, record); + } + else if(rec_type == ALERT) + { + Alert alert(record); - throw Stream_IO_Error("TLS_Client: Handshake failed: " + error_str); - } - } + proc_fn(0, 0, alert.type()); -/** -* Return the peer's certificate chain -*/ -std::vector<X509_Certificate> TLS_Client::peer_cert_chain() const - { - return peer_certs; - } + if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY) + { + if(alert.type() == CLOSE_NOTIFY) + { + writer.alert(WARNING, CLOSE_NOTIFY); + } -/** -* Write to a TLS connection -*/ -void TLS_Client::write(const byte buf[], size_t length) - { - if(!active) - throw TLS_Exception(INTERNAL_ERROR, - "TLS_Client::write called while closed"); + close(FATAL, NO_ALERT_TYPE); + } + } + else + throw Unexpected_Message("Unknown message type received"); + } - writer.send(APPLICATION_DATA, buf, length); + return 0; // on a record boundary + } + catch(TLS_Exception& e) + { + close(FATAL, e.type()); + throw; + } + catch(std::exception& e) + { + close(FATAL, INTERNAL_ERROR); + throw; + } } -/** -* Read from a TLS connection -*/ -size_t TLS_Client::read(byte out[], size_t length) +void TLS_Client::queue_for_sending(const byte buf[], size_t buf_size) { - if(!active) - return 0; - - writer.flush(); - - while(read_buf.size() == 0) + if(active) { - state_machine(); - if(active == false) - break; - } + while(!pre_handshake_write_queue.end_of_data()) + { + SecureVector<byte> q_buf(1024); + const size_t got = pre_handshake_write_queue.read(&q_buf[0], q_buf.size()); + writer.send(APPLICATION_DATA, &q_buf[0], got); + } - size_t got = std::min<size_t>(read_buf.size(), length); - read_buf.read(out, got); - return got; + writer.send(APPLICATION_DATA, buf, buf_size); + writer.flush(); + } + else + pre_handshake_write_queue.write(buf, buf_size); } /** @@ -207,94 +207,27 @@ void TLS_Client::close() } /** -* Check connection status -*/ -bool TLS_Client::is_closed() const - { - if(!active) - return true; - return false; - } - -/** * Close a TLS connection */ void TLS_Client::close(Alert_Level level, Alert_Type alert_code) { if(active) { - try { - writer.alert(level, alert_code); - writer.flush(); - } - catch(...) {} - active = false; - } - } -/** -* Iterate the TLS state machine -*/ -void TLS_Client::state_machine() - { - byte rec_type = CONNECTION_CLOSED; - SecureVector<byte> record(1024); - - size_t bytes_needed = reader.get_record(rec_type, record); - - while(bytes_needed) - { - size_t to_get = std::min<size_t>(record.size(), bytes_needed); - size_t got = input_fn(&record[0], to_get); - - if(got == 0) + if(alert_code != NO_ALERT_TYPE) { - rec_type = CONNECTION_CLOSED; - break; + try + { + writer.alert(level, alert_code); + writer.flush(); + } + catch(...) { /* swallow it */ } } - reader.add_input(&record[0], got); - - bytes_needed = reader.get_record(rec_type, record); - } - - if(rec_type == CONNECTION_CLOSED) - { - active = false; reader.reset(); writer.reset(); } - else if(rec_type == APPLICATION_DATA) - { - if(active) - read_buf.write(&record[0], record.size()); - else - throw Unexpected_Message("Application data before handshake done"); - } - else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) - read_handshake(rec_type, record); - else if(rec_type == ALERT) - { - Alert alert(record); - - if(alert.is_fatal() || alert.type() == CLOSE_NOTIFY) - { - if(alert.type() == CLOSE_NOTIFY) - writer.alert(WARNING, CLOSE_NOTIFY); - - reader.reset(); - writer.reset(); - active = false; - if(state) - { - delete state; - state = 0; - } - } - } - else - throw Unexpected_Message("Unknown message type received"); } /** @@ -563,24 +496,4 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, throw Unexpected_Message("Unknown handshake message received"); } -/** -* Perform a client-side TLS handshake -*/ -void TLS_Client::do_handshake() - { - state = new Handshake_State; - - state->client_hello = new Client_Hello(rng, writer, policy, state->hash); - - while(true) - { - if(active && !state) - break; - if(!active && !state) - throw TLS_Exception(HANDSHAKE_FAILURE, "TLS_Client: Handshake failed (do_handshake)"); - - state_machine(); - } - } - } diff --git a/src/ssl/tls_client.h b/src/ssl/tls_client.h index 7d2ce9cda..6d613be33 100644 --- a/src/ssl/tls_client.h +++ b/src/ssl/tls_client.h @@ -1,6 +1,6 @@ /* * TLS Client -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2011 Jack Lloyd * * Released under the terms of the Botan license */ @@ -8,7 +8,6 @@ #ifndef BOTAN_TLS_CLIENT_H__ #define BOTAN_TLS_CLIENT_H__ -#include <botan/tls_connection.h> #include <botan/tls_policy.h> #include <botan/tls_record.h> #include <vector> @@ -19,25 +18,39 @@ namespace Botan { /** * SSL/TLS Client */ -class BOTAN_DLL TLS_Client : public TLS_Connection +class BOTAN_DLL TLS_Client { public: - size_t read(byte buf[], size_t buf_len); - void write(const byte buf[], size_t buf_len); + /** + * Set up a new TLS client session + */ + TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, + const TLS_Policy& policy, + RandomNumberGenerator& rng); + + /** + * Inject TLS traffic received from counterparty + + * @return a hint as the how many more bytes we need to process the + current record (this may be 0 if on a record boundary) + */ + size_t received_data(const byte buf[], size_t buf_size); + + /** + * Inject plaintext intended for counterparty + */ + void queue_for_sending(const byte buf[], size_t buf_size); void close(); - bool is_closed() const; - std::vector<X509_Certificate> peer_cert_chain() const; + bool handshake_complete() const { return active; } + + std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; } void add_client_cert(const X509_Certificate& cert, Private_Key* cert_key); - TLS_Client(std::tr1::function<size_t (byte[], size_t)> input_fn, - std::tr1::function<void (const byte[], size_t)> output_fn, - const TLS_Policy& policy, - RandomNumberGenerator& rng); - ~TLS_Client(); private: void close(Alert_Level, Alert_Type); @@ -51,20 +64,21 @@ class BOTAN_DLL TLS_Client : public TLS_Connection void read_handshake(byte, const MemoryRegion<byte>&); void process_handshake_msg(Handshake_Type, const MemoryRegion<byte>&); - std::tr1::function<size_t (byte[], size_t)> input_fn; - const TLS_Policy& policy; RandomNumberGenerator& rng; + std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn; + Record_Writer writer; Record_Reader reader; + SecureQueue pre_handshake_write_queue; + std::vector<X509_Certificate> peer_certs; std::vector<std::pair<X509_Certificate, Private_Key*> > certs; class Handshake_State* state; - SecureVector<byte> session_id; - SecureQueue read_buf; + //SecureVector<byte> session_id; bool active; }; diff --git a/src/ssl/tls_connection.h b/src/ssl/tls_connection.h deleted file mode 100644 index bbefa2114..000000000 --- a/src/ssl/tls_connection.h +++ /dev/null @@ -1,36 +0,0 @@ -/* -* TLS Connection -* (C) 2004-2006 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#ifndef BOTAN_TLS_CONNECTION_H__ -#define BOTAN_TLS_CONNECTION_H__ - -#include <botan/x509cert.h> -#include <vector> - -namespace Botan { - -/** -* TLS Connection -*/ -class BOTAN_DLL TLS_Connection - { - public: - virtual size_t read(byte[], size_t) = 0; - virtual void write(const byte[], size_t) = 0; - size_t read(byte& in) { return read(&in, 1); } - void write(byte out) { write(&out, 1); } - - virtual std::vector<X509_Certificate> peer_cert_chain() const = 0; - - virtual void close() = 0; - - virtual ~TLS_Connection() {} - }; - -} - -#endif diff --git a/src/ssl/tls_record.h b/src/ssl/tls_record.h index 09fd921c6..6d5dd057d 100644 --- a/src/ssl/tls_record.h +++ b/src/ssl/tls_record.h @@ -99,6 +99,8 @@ class BOTAN_DLL Record_Reader void reset(); + bool currently_empty() const { return input_queue.size() == 0; } + Record_Reader() { mac = 0; reset(); } ~Record_Reader() { delete mac; } diff --git a/src/ssl/tls_server.h b/src/ssl/tls_server.h index a6b0f9cb4..510ad15a7 100644 --- a/src/ssl/tls_server.h +++ b/src/ssl/tls_server.h @@ -1,6 +1,6 @@ /* * TLS Server -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2011 Jack Lloyd * * Released under the terms of the Botan license */ @@ -8,7 +8,6 @@ #ifndef BOTAN_TLS_SERVER_H__ #define BOTAN_TLS_SERVER_H__ -#include <botan/tls_connection.h> #include <botan/tls_record.h> #include <botan/tls_policy.h> #include <vector> @@ -18,7 +17,7 @@ namespace Botan { /** * TLS Server */ -class BOTAN_DLL TLS_Server : public TLS_Connection +class BOTAN_DLL TLS_Server { public: size_t read(byte buf[], size_t buf_len); |