diff options
author | lloyd <[email protected]> | 2012-08-03 14:40:08 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-08-03 14:40:08 +0000 |
commit | db2a5f10716f69a58f8c554c8e65d21e198ffbc5 (patch) | |
tree | 1c7a302a3f34fb46f201bf6b658884421609a559 /src | |
parent | ba0e7cc86e7fa6606a04c3ae34be354d8ed801b3 (diff) |
Combine Handshake_Writer and Handshake_Reader into Handshake_IO.
This is mostly just a minor code savings for TLS, but it actually
seems important for DTLS because getting a handshake message can be a
trigger for retransmitting previously sent handshake messages in some
circumstances. Having the reading and writing all in one layer makes
it a bit easier to accomplish that.
Diffstat (limited to 'src')
-rw-r--r-- | src/tls/c_hello.cpp | 14 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 6 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 10 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 6 | ||||
-rw-r--r-- | src/tls/finished.cpp | 6 | ||||
-rw-r--r-- | src/tls/info.txt | 6 | ||||
-rw-r--r-- | src/tls/next_protocol.cpp | 6 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 10 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 6 | ||||
-rw-r--r-- | src/tls/session_ticket.cpp | 10 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 6 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 23 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp (renamed from src/tls/tls_handshake_reader.cpp) | 55 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h (renamed from src/tls/tls_handshake_reader.h) | 40 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 13 | ||||
-rw-r--r-- | src/tls/tls_handshake_writer.cpp | 56 | ||||
-rw-r--r-- | src/tls/tls_handshake_writer.h | 66 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 30 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 37 |
20 files changed, 170 insertions, 245 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index eacbecc6c..2d2e03752 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -9,7 +9,7 @@ #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/stl_util.h> #include <chrono> @@ -36,9 +36,9 @@ std::vector<byte> make_hello_random(RandomNumberGenerator& rng) /* * Create a new Hello Request message */ -Hello_Request::Hello_Request(Handshake_Writer& writer) +Hello_Request::Hello_Request(Handshake_IO& io) { - writer.send(*this); + io.send(*this); } /* @@ -61,7 +61,7 @@ std::vector<byte> Hello_Request::serialize() const /* * Create a new Client Hello message */ -Client_Hello::Client_Hello(Handshake_Writer& writer, +Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, Protocol_Version version, const Policy& policy, @@ -92,13 +92,13 @@ Client_Hello::Client_Hello(Handshake_Writer& writer, for(size_t j = 0; j != sigs.size(); ++j) m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /* * Create a new Client Hello message (session resumption case) */ -Client_Hello::Client_Hello(Handshake_Writer& writer, +Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, RandomNumberGenerator& rng, @@ -135,7 +135,7 @@ Client_Hello::Client_Hello(Handshake_Writer& writer, for(size_t j = 0; j != sigs.size(); ++j) m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /* diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index f1b2306f1..1bf41cff2 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/pubkey.h> @@ -47,7 +47,7 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input) /* * Create a new Client Key Exchange message */ -Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer, +Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -259,7 +259,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer, pub_key->algo_name()); } - state->hash.update(writer.send(*this)); + state->hash.update(io.send(*this)); } /* diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 0806f5f66..5087865d4 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/loadstor.h> @@ -51,7 +51,7 @@ byte cert_type_name_to_code(const std::string& name) /** * Create a new Certificate Request message */ -Certificate_Req::Certificate_Req(Handshake_Writer& writer, +Certificate_Req::Certificate_Req(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& ca_certs, @@ -74,7 +74,7 @@ Certificate_Req::Certificate_Req(Handshake_Writer& writer, m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); } - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /** @@ -166,12 +166,12 @@ std::vector<byte> Certificate_Req::serialize() const /** * Create a new Certificate message */ -Certificate::Certificate(Handshake_Writer& writer, +Certificate::Certificate(Handshake_IO& io, Handshake_Hash& hash, const std::vector<X509_Certificate>& cert_list) : m_certs(cert_list) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /** diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 4dbae9da3..7a58ea28a 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/assert.h> #include <memory> @@ -19,7 +19,7 @@ namespace TLS { /* * Create a new Certificate Verify message */ -Certificate_Verify::Certificate_Verify(Handshake_Writer& writer, +Certificate_Verify::Certificate_Verify(Handshake_IO& io, Handshake_State* state, const Policy& policy, RandomNumberGenerator& rng, @@ -47,7 +47,7 @@ Certificate_Verify::Certificate_Verify(Handshake_Writer& writer, signature = signer.sign_message(state->hash.get_contents(), rng); } - state->hash.update(writer.send(*this)); + state->hash.update(io.send(*this)); } /* diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index 4dcc9e1ae..9205331de 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -6,7 +6,7 @@ */ #include <botan/internal/tls_messages.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <memory> namespace Botan { @@ -66,12 +66,12 @@ std::vector<byte> finished_compute_verify(Handshake_State* state, /* * Create a new Finished message */ -Finished::Finished(Handshake_Writer& writer, +Finished::Finished(Handshake_IO& io, Handshake_State* state, Connection_Side side) { verification_data = finished_compute_verify(state, side); - state->hash.update(writer.send(*this)); + state->hash.update(io.send(*this)); } /* diff --git a/src/tls/info.txt b/src/tls/info.txt index 212562373..bc2bc41c3 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -25,9 +25,8 @@ tls_version.h <header:internal> tls_extensions.h tls_handshake_hash.h -tls_handshake_reader.h +tls_handshake_io.h tls_handshake_state.h -tls_handshake_writer.h tls_heartbeats.h tls_messages.h tls_reader.h @@ -53,9 +52,8 @@ tls_ciphersuite.cpp tls_client.cpp tls_extensions.cpp tls_handshake_hash.cpp -tls_handshake_reader.cpp +tls_handshake_io.cpp tls_handshake_state.cpp -tls_handshake_writer.cpp tls_heartbeats.cpp tls_policy.cpp tls_server.cpp diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index a8989c5a9..71bb0eb9e 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -8,18 +8,18 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> namespace Botan { namespace TLS { -Next_Protocol::Next_Protocol(Handshake_Writer& writer, +Next_Protocol::Next_Protocol(Handshake_IO& io, Handshake_Hash& hash, const std::string& protocol) : m_protocol(protocol) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } Next_Protocol::Next_Protocol(const std::vector<byte>& buf) diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index d34fa5e70..8d151b2b0 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -9,7 +9,7 @@ #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/stl_util.h> namespace Botan { @@ -19,7 +19,7 @@ namespace TLS { /* * Create a new Server Hello message */ -Server_Hello::Server_Hello(Handshake_Writer& writer, +Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const std::vector<byte>& session_id, Protocol_Version ver, @@ -47,7 +47,7 @@ Server_Hello::Server_Hello(Handshake_Writer& writer, m_supports_heartbeats(client_has_heartbeat), m_peer_can_send_heartbeats(true) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /* @@ -149,10 +149,10 @@ std::vector<byte> Server_Hello::serialize() const /* * Create a new Server Hello Done message */ -Server_Hello_Done::Server_Hello_Done(Handshake_Writer& writer, +Server_Hello_Done::Server_Hello_Done(Handshake_IO& io, Handshake_Hash& hash) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } /* diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 423497976..2d27c6f0d 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/loadstor.h> @@ -27,7 +27,7 @@ namespace TLS { /** * Create a new Server Key Exchange message */ -Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer, +Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -136,7 +136,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer, m_signature = signer.signature(rng); } - state->hash.update(writer.send(*this)); + state->hash.update(io.send(*this)); } /** diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp index 3affe8fcf..2bb9987a9 100644 --- a/src/tls/session_ticket.cpp +++ b/src/tls/session_ticket.cpp @@ -8,28 +8,28 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/loadstor.h> namespace Botan { namespace TLS { -New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer, +New_Session_Ticket::New_Session_Ticket(Handshake_IO& io, Handshake_Hash& hash, const std::vector<byte>& ticket, u32bit lifetime) : m_ticket_lifetime_hint(lifetime), m_ticket(ticket) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } -New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer, +New_Session_Ticket::New_Session_Ticket(Handshake_IO& io, Handshake_Hash& hash) : m_ticket_lifetime_hint(0) { - hash.update(writer.send(*this)); + hash.update(io.send(*this)); } New_Session_Ticket::New_Session_Ticket(const std::vector<byte>& buf) : diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 4c9c12d92..0c1f9fd09 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -174,12 +174,12 @@ void Channel::read_handshake(byte rec_type, if(!m_state) m_state.reset(new_handshake_state()); - m_state->handshake_reader().add_input(rec_type, &rec_buf[0], rec_buf.size()); + m_state->handshake_io().add_input(rec_type, &rec_buf[0], rec_buf.size()); - while(m_state && m_state->handshake_reader().have_full_record()) + while(m_state && m_state->handshake_io().have_full_record()) { std::pair<Handshake_Type, std::vector<byte> > msg = - m_state->handshake_reader().get_next_record(); + m_state->handshake_io().get_next_record(); process_handshake_msg(msg.first, msg.second); } } diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 17a7879d6..77ff010f3 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -45,8 +45,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, Handshake_State* Client::new_handshake_state() { - return new Handshake_State(new Stream_Handshake_Reader, - new Stream_Handshake_Writer(m_writer)); + return new Handshake_State(new Stream_Handshake_IO(m_writer)); } /* @@ -84,7 +83,7 @@ void Client::initiate_handshake(bool force_full_renegotiation, if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier) { m_state->client_hello = new Client_Hello( - m_state->handshake_writer(), + m_state->handshake_io(), m_state->hash, m_policy, m_rng, @@ -100,7 +99,7 @@ void Client::initiate_handshake(bool force_full_renegotiation, if(!m_state->client_hello) // not resuming { m_state->client_hello = new Client_Hello( - m_state->handshake_writer(), + m_state->handshake_io(), m_state->hash, version, m_policy, @@ -157,7 +156,7 @@ void Client::process_handshake_msg(Handshake_Type type, m_state->confirm_transition_to(type); if(type != HANDSHAKE_CCS && type != FINISHED) - m_state->hash.update(m_state->handshake_writer().format(contents, type)); + m_state->hash.update(m_state->handshake_io().format(contents, type)); if(type == SERVER_HELLO) { @@ -344,13 +343,13 @@ void Client::process_handshake_msg(Handshake_Type type, "tls-client", m_hostname); - m_state->client_certs = new Certificate(m_state->handshake_writer(), + m_state->client_certs = new Certificate(m_state->handshake_io(), m_state->hash, client_certs); } m_state->client_kex = - new Client_Key_Exchange(m_state->handshake_writer(), + new Client_Key_Exchange(m_state->handshake_io(), m_state.get(), m_policy, m_creds, @@ -370,7 +369,7 @@ void Client::process_handshake_msg(Handshake_Type type, "tls-client", m_hostname); - m_state->client_verify = new Certificate_Verify(m_state->handshake_writer(), + m_state->client_verify = new Certificate_Verify(m_state->handshake_io(), m_state.get(), m_policy, m_rng, @@ -389,10 +388,10 @@ void Client::process_handshake_msg(Handshake_Type type, const std::string protocol = m_state->client_npn_cb(m_state->server_hello->next_protocols()); - m_state->next_protocol = new Next_Protocol(m_state->handshake_writer(), m_state->hash, protocol); + m_state->next_protocol = new Next_Protocol(m_state->handshake_io(), m_state->hash, protocol); } - m_state->client_finished = new Finished(m_state->handshake_writer(), + m_state->client_finished = new Finished(m_state->handshake_io(), m_state.get(), CLIENT); if(m_state->server_hello->supports_session_ticket()) @@ -425,7 +424,7 @@ void Client::process_handshake_msg(Handshake_Type type, throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); - m_state->hash.update(m_state->handshake_writer().format(contents, type)); + m_state->hash.update(m_state->handshake_io().format(contents, type)); if(!m_state->client_finished) // session resume case { @@ -436,7 +435,7 @@ void Client::process_handshake_msg(Handshake_Type type, m_state->keys, m_state->server_hello->compression_method()); - m_state->client_finished = new Finished(m_state->handshake_writer(), + m_state->client_finished = new Finished(m_state->handshake_io(), m_state.get(), CLIENT); } diff --git a/src/tls/tls_handshake_reader.cpp b/src/tls/tls_handshake_io.cpp index 3721ec5b5..fe1b9c790 100644 --- a/src/tls/tls_handshake_reader.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -1,11 +1,13 @@ /* -* TLS Handshake Reader +* TLS Handshake IO * (C) 2012 Jack Lloyd * * Released under the terms of the Botan license */ -#include <botan/internal/tls_handshake_reader.h> +#include <botan/internal/tls_handshake_io.h> +#include <botan/internal/tls_messages.h> +#include <botan/tls_record.h> #include <botan/exceptn.h> namespace Botan { @@ -22,12 +24,18 @@ inline size_t load_be24(const byte q[3]) q[2]); } -} +void store_be24(byte out[3], size_t val) + { + out[0] = get_byte<u32bit>(1, val); + out[1] = get_byte<u32bit>(2, val); + out[2] = get_byte<u32bit>(3, val); + } +} -void Stream_Handshake_Reader::add_input(const byte rec_type, - const byte record[], - size_t record_size) +void Stream_Handshake_IO::add_input(const byte rec_type, + const byte record[], + size_t record_size) { if(rec_type == HANDSHAKE) { @@ -45,12 +53,12 @@ void Stream_Handshake_Reader::add_input(const byte rec_type, throw Decoding_Error("Unknown message type in handshake processing"); } -bool Stream_Handshake_Reader::empty() const +bool Stream_Handshake_IO::empty() const { return m_queue.empty(); } -bool Stream_Handshake_Reader::have_full_record() const +bool Stream_Handshake_IO::have_full_record() const { if(m_queue.size() >= 4) { @@ -62,7 +70,8 @@ bool Stream_Handshake_Reader::have_full_record() const return false; } -std::pair<Handshake_Type, std::vector<byte> > Stream_Handshake_Reader::get_next_record() +std::pair<Handshake_Type, std::vector<byte> > +Stream_Handshake_IO::get_next_record() { if(m_queue.size() >= 4) { @@ -81,7 +90,33 @@ std::pair<Handshake_Type, std::vector<byte> > Stream_Handshake_Reader::get_next_ } } - throw Internal_Error("Stream_Handshake_Reader::get_next_record called without a full record"); + throw Internal_Error("Stream_Handshake_IO::get_next_record called without a full record"); + } + +std::vector<byte> +Stream_Handshake_IO::format(const std::vector<byte>& msg, + Handshake_Type type) + { + std::vector<byte> send_buf(4 + msg.size()); + + const size_t buf_size = msg.size(); + + send_buf[0] = type; + + store_be24(&send_buf[1], buf_size); + + copy_mem(&send_buf[4], &msg[0], msg.size()); + + return send_buf; + } + +std::vector<byte> Stream_Handshake_IO::send(Handshake_Message& msg) + { + const std::vector<byte> buf = format(msg.serialize(), msg.type()); + + m_writer.send(HANDSHAKE, &buf[0], buf.size()); + + return buf; } } diff --git a/src/tls/tls_handshake_reader.h b/src/tls/tls_handshake_io.h index 791a2628a..f71b2c034 100644 --- a/src/tls/tls_handshake_reader.h +++ b/src/tls/tls_handshake_io.h @@ -1,12 +1,12 @@ /* -* TLS Handshake Reader +* TLS Handshake Serialization * (C) 2012 Jack Lloyd * * Released under the terms of the Botan license */ -#ifndef BOTAN_TLS_HANDSHAKE_READER_H__ -#define BOTAN_TLS_HANDSHAKE_READER_H__ +#ifndef BOTAN_TLS_HANDSHAKE_IO_H__ +#define BOTAN_TLS_HANDSHAKE_IO_H__ #include <botan/tls_magic.h> #include <botan/loadstor.h> @@ -18,12 +18,21 @@ namespace Botan { namespace TLS { +class Record_Writer; +class Handshake_Message; + /** -* Handshake Reader Interface +* Handshake IO Interface */ -class Handshake_Reader +class Handshake_IO { public: + virtual std::vector<byte> send(Handshake_Message& msg) = 0; + + virtual std::vector<byte> format( + const std::vector<byte>& handshake_msg, + Handshake_Type handshake_type) = 0; + virtual void add_input(byte record_type, const byte record[], size_t record_size) = 0; @@ -34,15 +43,29 @@ class Handshake_Reader virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0; - virtual ~Handshake_Reader() {} + Handshake_IO() {} + + Handshake_IO(const Handshake_IO&) = delete; + + Handshake_IO& operator=(const Handshake_IO&) = delete; + + virtual ~Handshake_IO() {} }; /** -* Reader of TLS handshake messages +* Handshake IO for stream-based handshakes */ -class Stream_Handshake_Reader : public Handshake_Reader +class Stream_Handshake_IO : public Handshake_IO { public: + Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {} + + std::vector<byte> send(Handshake_Message& msg) override; + + std::vector<byte> format( + const std::vector<byte>& handshake_msg, + Handshake_Type handshake_type) override; + void add_input(byte record_type, const byte record[], size_t record_size) override; @@ -54,6 +77,7 @@ class Stream_Handshake_Reader : public Handshake_Reader std::pair<Handshake_Type, std::vector<byte> > get_next_record() override; private: std::deque<byte> m_queue; + Record_Writer& m_writer; }; } diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 023b1816a..77a1b52fc 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -85,10 +85,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_Reader* reader, - Handshake_Writer* writer) : - m_handshake_reader(reader), - m_handshake_writer(writer), +Handshake_State::Handshake_State(Handshake_IO* io) : + m_handshake_io(io), m_version(Protocol_Version::SSL_V3) { } @@ -345,8 +343,7 @@ Handshake_State::~Handshake_State() delete client_finished; delete server_finished; - delete m_handshake_reader; - delete m_handshake_writer; + delete m_handshake_io; } } diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 0f48c976b..49470fecb 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -9,8 +9,7 @@ #define BOTAN_TLS_HANDSHAKE_STATE_H__ #include <botan/internal/tls_handshake_hash.h> -#include <botan/internal/tls_handshake_reader.h> -#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_handshake_io.h> #include <botan/internal/tls_session_key.h> #include <botan/pk_keys.h> #include <botan/pubkey.h> @@ -32,8 +31,7 @@ class Policy; class Handshake_State { public: - Handshake_State(Handshake_Reader* reader, - Handshake_Writer* writer); + Handshake_State(Handshake_IO* io); ~Handshake_State(); @@ -108,12 +106,9 @@ class Handshake_State */ std::function<std::string (std::vector<std::string>)> client_npn_cb; - Handshake_Reader& handshake_reader() { return *m_handshake_reader; } - - Handshake_Writer& handshake_writer() { return *m_handshake_writer; } + Handshake_IO& handshake_io() { return *m_handshake_io; } private: - Handshake_Reader* m_handshake_reader = nullptr; - Handshake_Writer* m_handshake_writer = nullptr; + Handshake_IO* m_handshake_io = nullptr; u32bit m_hand_expecting_mask = 0; u32bit m_hand_received_mask = 0; diff --git a/src/tls/tls_handshake_writer.cpp b/src/tls/tls_handshake_writer.cpp deleted file mode 100644 index 7af9a3f52..000000000 --- a/src/tls/tls_handshake_writer.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/* -* Handshake Message Writer -* (C) 2012 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#include <botan/internal/tls_handshake_writer.h> -#include <botan/internal/tls_messages.h> -#include <botan/tls_record.h> -#include <botan/exceptn.h> - -namespace Botan { - -namespace TLS { - -namespace { - -void store_be24(byte* out, size_t val) - { - out[0] = get_byte<u32bit>(1, val); - out[1] = get_byte<u32bit>(2, val); - out[2] = get_byte<u32bit>(3, val); - } - -} - -std::vector<byte> -Stream_Handshake_Writer::format(const std::vector<byte>& msg, - Handshake_Type type) - { - std::vector<byte> send_buf(4 + msg.size()); - - const size_t buf_size = msg.size(); - - send_buf[0] = type; - - store_be24(&send_buf[1], buf_size); - - copy_mem(&send_buf[4], &msg[0], msg.size()); - - return send_buf; - } - -std::vector<byte> Stream_Handshake_Writer::send(Handshake_Message& msg) - { - const std::vector<byte> buf = format(msg.serialize(), msg.type()); - - m_writer.send(HANDSHAKE, &buf[0], buf.size()); - - return buf; - } - -} - -} diff --git a/src/tls/tls_handshake_writer.h b/src/tls/tls_handshake_writer.h deleted file mode 100644 index 3bbb1c93e..000000000 --- a/src/tls/tls_handshake_writer.h +++ /dev/null @@ -1,66 +0,0 @@ -/* -* TLS Handshake Writer -* (C) 2012 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#ifndef BOTAN_TLS_HANDSHAKE_WRITER_H__ -#define BOTAN_TLS_HANDSHAKE_WRITER_H__ - -#include <botan/tls_magic.h> -#include <botan/loadstor.h> -#include <vector> -#include <deque> -#include <utility> - -namespace Botan { - -namespace TLS { - -class Record_Writer; -class Handshake_Message; - -/** -* Handshake Writer -*/ -class Handshake_Writer - { - public: - virtual std::vector<byte> send(Handshake_Message& msg) = 0; - - virtual std::vector<byte> format( - const std::vector<byte>& handshake_msg, - Handshake_Type handshake_type) = 0; - - Handshake_Writer() {} - - Handshake_Writer(const Handshake_Writer&) = delete; - - Handshake_Writer& operator=(const Handshake_Writer&) = delete; - - virtual ~Handshake_Writer() {} - }; - -/** -* Stream Handshake Writer -*/ -class Stream_Handshake_Writer : public Handshake_Writer - { - public: - Stream_Handshake_Writer(Record_Writer& writer) : m_writer(writer) {} - - std::vector<byte> send(Handshake_Message& msg) override; - - std::vector<byte> format( - const std::vector<byte>& handshake_msg, - Handshake_Type handshake_type) override; - private: - Record_Writer& m_writer; - }; - -} - -} - -#endif diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index a0e7d8630..0969aea06 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -25,7 +25,7 @@ class SRP6_Server_Session; namespace TLS { -class Handshake_Writer; +class Handshake_IO; /** * TLS Handshake Message Base Class @@ -112,7 +112,7 @@ class Client_Hello : public Handshake_Message bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; } - Client_Hello(Handshake_Writer& writer, + Client_Hello(Handshake_IO& io, Handshake_Hash& hash, Protocol_Version version, const Policy& policy, @@ -122,7 +122,7 @@ class Client_Hello : public Handshake_Message const std::string& hostname = "", const std::string& srp_identifier = ""); - Client_Hello(Handshake_Writer& writer, + Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, RandomNumberGenerator& rng, @@ -196,7 +196,7 @@ class Server_Hello : public Handshake_Message bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; } - Server_Hello(Handshake_Writer& writer, + Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const std::vector<byte>& session_id, Protocol_Version ver, @@ -243,7 +243,7 @@ class Client_Key_Exchange : public Handshake_Message const secure_vector<byte>& pre_master_secret() const { return pre_master; } - Client_Key_Exchange(Handshake_Writer& output, + Client_Key_Exchange(Handshake_IO& io, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -276,7 +276,7 @@ class Certificate : public Handshake_Message size_t count() const { return m_certs.size(); } bool empty() const { return m_certs.empty(); } - Certificate(Handshake_Writer& writer, + Certificate(Handshake_IO& io, Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); @@ -303,7 +303,7 @@ class Certificate_Req : public Handshake_Message std::vector<std::pair<std::string, std::string> > supported_algos() const { return m_supported_algos; } - Certificate_Req(Handshake_Writer& writer, + Certificate_Req(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& allowed_cas, @@ -336,7 +336,7 @@ class Certificate_Verify : public Handshake_Message bool verify(const X509_Certificate& cert, Handshake_State* state); - Certificate_Verify(Handshake_Writer& writer, + Certificate_Verify(Handshake_IO& io, Handshake_State* state, const Policy& policy, RandomNumberGenerator& rng, @@ -366,7 +366,7 @@ class Finished : public Handshake_Message bool verify(Handshake_State* state, Connection_Side side); - Finished(Handshake_Writer& writer, + Finished(Handshake_IO& io, Handshake_State* state, Connection_Side side); @@ -386,7 +386,7 @@ class Hello_Request : public Handshake_Message public: Handshake_Type type() const { return HELLO_REQUEST; } - Hello_Request(Handshake_Writer& writer); + Hello_Request(Handshake_IO& io); Hello_Request(const std::vector<byte>& buf); private: std::vector<byte> serialize() const; @@ -411,7 +411,7 @@ class Server_Key_Exchange : public Handshake_Message // Only valid for SRP negotiation SRP6_Server_Session& server_srp_params(); - Server_Key_Exchange(Handshake_Writer& writer, + Server_Key_Exchange(Handshake_IO& io, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -445,7 +445,7 @@ class Server_Hello_Done : public Handshake_Message public: Handshake_Type type() const { return SERVER_HELLO_DONE; } - Server_Hello_Done(Handshake_Writer& writer, Handshake_Hash& hash); + Server_Hello_Done(Handshake_IO& io, Handshake_Hash& hash); Server_Hello_Done(const std::vector<byte>& buf); private: std::vector<byte> serialize() const; @@ -461,7 +461,7 @@ class Next_Protocol : public Handshake_Message std::string protocol() const { return m_protocol; } - Next_Protocol(Handshake_Writer& writer, + Next_Protocol(Handshake_IO& io, Handshake_Hash& hash, const std::string& protocol); @@ -480,12 +480,12 @@ class New_Session_Ticket : public Handshake_Message u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; } const std::vector<byte>& ticket() const { return m_ticket; } - New_Session_Ticket(Handshake_Writer& writer, + New_Session_Ticket(Handshake_IO& io, Handshake_Hash& hash, const std::vector<byte>& ticket, u32bit lifetime); - New_Session_Ticket(Handshake_Writer& writer, + New_Session_Ticket(Handshake_IO& io, Handshake_Hash& hash); New_Session_Ticket(const std::vector<byte>& buf); diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 0f1b24045..9c6250273 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -207,8 +207,7 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, Handshake_State* Server::new_handshake_state() { - return new Handshake_State(new Stream_Handshake_Reader, - new Stream_Handshake_Writer(m_writer)); + return new Handshake_State(new Stream_Handshake_IO(m_writer)); } /* @@ -223,7 +222,7 @@ void Server::renegotiate(bool force_full_renegotiation) m_state->allow_session_resumption = !force_full_renegotiation; m_state->set_expected_next(CLIENT_HELLO); - Hello_Request hello_req(m_state->handshake_writer()); + Hello_Request hello_req(m_state->handshake_io()); } void Server::alert_notify(const Alert& alert) @@ -273,7 +272,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(type == CLIENT_HELLO_SSLV2) m_state->hash.update(contents); else - m_state->hash.update(m_state->handshake_writer().format(contents, type)); + m_state->hash.update(m_state->handshake_io().format(contents, type)); } if(type == CLIENT_HELLO || type == CLIENT_HELLO_SSLV2) @@ -374,7 +373,7 @@ void Server::process_handshake_msg(Handshake_Type type, // resume session m_state->server_hello = new Server_Hello( - m_state->handshake_writer(), + m_state->handshake_io(), m_state->hash, m_state->client_hello->session_id(), Protocol_Version(session_info.version()), @@ -410,7 +409,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(m_state->server_hello->supports_session_ticket()) // send an empty ticket { m_state->new_session_ticket = - new New_Session_Ticket(m_state->handshake_writer(), + new New_Session_Ticket(m_state->handshake_io(), m_state->hash); } } @@ -422,7 +421,7 @@ void Server::process_handshake_msg(Handshake_Type type, const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = - new New_Session_Ticket(m_state->handshake_writer(), + new New_Session_Ticket(m_state->handshake_io(), m_state->hash, session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()); @@ -432,7 +431,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(!m_state->new_session_ticket) { m_state->new_session_ticket = - new New_Session_Ticket(m_state->handshake_writer(), + new New_Session_Ticket(m_state->handshake_io(), m_state->hash); } } @@ -444,7 +443,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_state->keys, m_state->server_hello->compression_method()); - m_state->server_finished = new Finished(m_state->handshake_writer(), + m_state->server_finished = new Finished(m_state->handshake_io(), m_state.get(), SERVER); m_state->set_expected_next(HANDSHAKE_CCS); @@ -471,7 +470,7 @@ void Server::process_handshake_msg(Handshake_Type type, } m_state->server_hello = new Server_Hello( - m_state->handshake_writer(), + m_state->handshake_io(), m_state->hash, make_hello_random(m_rng), // new session ID m_state->version(), @@ -508,7 +507,7 @@ void Server::process_handshake_msg(Handshake_Type type, BOTAN_ASSERT(!cert_chains[sig_algo].empty(), "Attempting to send empty certificate chain"); - m_state->server_certs = new Certificate(m_state->handshake_writer(), + m_state->server_certs = new Certificate(m_state->handshake_io(), m_state->hash, cert_chains[sig_algo]); } @@ -533,7 +532,7 @@ void Server::process_handshake_msg(Handshake_Type type, else { m_state->server_kex = - new Server_Key_Exchange(m_state->handshake_writer(), + new Server_Key_Exchange(m_state->handshake_io(), m_state.get(), m_policy, m_creds, @@ -546,7 +545,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "") { - m_state->cert_req = new Certificate_Req(m_state->handshake_writer(), + m_state->cert_req = new Certificate_Req(m_state->handshake_io(), m_state->hash, m_policy, client_auth_CAs, @@ -562,7 +561,7 @@ void Server::process_handshake_msg(Handshake_Type type, */ m_state->set_expected_next(CLIENT_KEX); - m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_writer(), + m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_io(), m_state->hash); } } @@ -599,7 +598,7 @@ void Server::process_handshake_msg(Handshake_Type type, const bool sig_valid = m_state->client_verify->verify(m_peer_certs[0], m_state.get()); - m_state->hash.update(m_state->handshake_writer().format(contents, type)); + m_state->hash.update(m_state->handshake_io().format(contents, type)); /* * Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for @@ -654,7 +653,7 @@ void Server::process_handshake_msg(Handshake_Type type, { // already sent finished if resuming, so this is a new session - m_state->hash.update(m_state->handshake_writer().format(contents, type)); + m_state->hash.update(m_state->handshake_io().format(contents, type)); Session session_info( m_state->server_hello->session_id(), @@ -680,7 +679,7 @@ void Server::process_handshake_msg(Handshake_Type type, const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = - new New_Session_Ticket(m_state->handshake_writer(), + new New_Session_Ticket(m_state->handshake_io(), m_state->hash, session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()); @@ -693,7 +692,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket) { - m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_writer(), + m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_io(), m_state->hash); } @@ -705,7 +704,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_state->keys, m_state->server_hello->compression_method()); - m_state->server_finished = new Finished(m_state->handshake_writer(), + m_state->server_finished = new Finished(m_state->handshake_io(), m_state.get(), SERVER); } |