diff options
-rw-r--r-- | src/tls/c_hello.cpp | 8 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 4 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 6 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 4 | ||||
-rw-r--r-- | src/tls/finished.cpp | 6 | ||||
-rw-r--r-- | src/tls/info.txt | 2 | ||||
-rw-r--r-- | src/tls/next_protocol.cpp | 4 | ||||
-rw-r--r-- | src/tls/rec_wri.cpp | 19 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 6 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 4 | ||||
-rw-r--r-- | src/tls/session_ticket.cpp | 6 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 12 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 2 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 30 | ||||
-rw-r--r-- | src/tls/tls_client.h | 2 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 46 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 49 | ||||
-rw-r--r-- | src/tls/tls_handshake_writer.cpp | 38 | ||||
-rw-r--r-- | src/tls/tls_handshake_writer.h | 52 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 31 | ||||
-rw-r--r-- | src/tls/tls_policy.cpp | 10 | ||||
-rw-r--r-- | src/tls/tls_record.h | 2 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 69 | ||||
-rw-r--r-- | src/tls/tls_server.h | 2 |
24 files changed, 248 insertions, 166 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index c9249ab9a..465e6714a 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -9,7 +9,7 @@ #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/stl_util.h> #include <chrono> @@ -36,7 +36,7 @@ std::vector<byte> make_hello_random(RandomNumberGenerator& rng) /* * Create a new Hello Request message */ -Hello_Request::Hello_Request(Record_Writer& writer) +Hello_Request::Hello_Request(Handshake_Writer& writer) { writer.send(*this); } @@ -61,7 +61,7 @@ std::vector<byte> Hello_Request::serialize() const /* * Create a new Client Hello message */ -Client_Hello::Client_Hello(Record_Writer& writer, +Client_Hello::Client_Hello(Handshake_Writer& writer, Handshake_Hash& hash, Protocol_Version version, const Policy& policy, @@ -98,7 +98,7 @@ Client_Hello::Client_Hello(Record_Writer& writer, /* * Create a new Client Hello message (session resumption case) */ -Client_Hello::Client_Hello(Record_Writer& writer, +Client_Hello::Client_Hello(Handshake_Writer& writer, Handshake_Hash& hash, const Policy& policy, RandomNumberGenerator& rng, diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index 28449e614..f1b2306f1 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/pubkey.h> @@ -47,7 +47,7 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input) /* * Create a new Client Key Exchange message */ -Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, +Client_Key_Exchange::Client_Key_Exchange(Handshake_Writer& writer, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 4578148f5..0806f5f66 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/loadstor.h> @@ -51,7 +51,7 @@ byte cert_type_name_to_code(const std::string& name) /** * Create a new Certificate Request message */ -Certificate_Req::Certificate_Req(Record_Writer& writer, +Certificate_Req::Certificate_Req(Handshake_Writer& writer, Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& ca_certs, @@ -166,7 +166,7 @@ std::vector<byte> Certificate_Req::serialize() const /** * Create a new Certificate message */ -Certificate::Certificate(Record_Writer& writer, +Certificate::Certificate(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<X509_Certificate>& cert_list) : m_certs(cert_list) diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 870d70951..4dbae9da3 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/assert.h> #include <memory> @@ -19,7 +19,7 @@ namespace TLS { /* * Create a new Certificate Verify message */ -Certificate_Verify::Certificate_Verify(Record_Writer& writer, +Certificate_Verify::Certificate_Verify(Handshake_Writer& writer, Handshake_State* state, const Policy& policy, RandomNumberGenerator& rng, diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index c8ae4a343..4dcc9e1ae 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -6,7 +6,7 @@ */ #include <botan/internal/tls_messages.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <memory> namespace Botan { @@ -19,7 +19,7 @@ namespace { * Compute the verify_data */ std::vector<byte> finished_compute_verify(Handshake_State* state, - Connection_Side side) + Connection_Side side) { if(state->version() == Protocol_Version::SSL_V3) { @@ -66,7 +66,7 @@ std::vector<byte> finished_compute_verify(Handshake_State* state, /* * Create a new Finished message */ -Finished::Finished(Record_Writer& writer, +Finished::Finished(Handshake_Writer& writer, Handshake_State* state, Connection_Side side) { diff --git a/src/tls/info.txt b/src/tls/info.txt index 1863be577..212562373 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -27,6 +27,7 @@ tls_extensions.h tls_handshake_hash.h tls_handshake_reader.h tls_handshake_state.h +tls_handshake_writer.h tls_heartbeats.h tls_messages.h tls_reader.h @@ -54,6 +55,7 @@ tls_extensions.cpp tls_handshake_hash.cpp tls_handshake_reader.cpp tls_handshake_state.cpp +tls_handshake_writer.cpp tls_heartbeats.cpp tls_policy.cpp tls_server.cpp diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index adf9acbe9..a8989c5a9 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -8,13 +8,13 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> namespace Botan { namespace TLS { -Next_Protocol::Next_Protocol(Record_Writer& writer, +Next_Protocol::Next_Protocol(Handshake_Writer& writer, Handshake_Hash& hash, const std::string& protocol) : m_protocol(protocol) diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp index b5b9e826c..2523f8229 100644 --- a/src/tls/rec_wri.cpp +++ b/src/tls/rec_wri.cpp @@ -148,25 +148,6 @@ void Record_Writer::activate(Connection_Side side, throw Invalid_Argument("Record_Writer: Unknown hash " + mac_algo); } -std::vector<byte> Record_Writer::send(Handshake_Message& msg) - { - const std::vector<byte> buf = msg.serialize(); - std::vector<byte> send_buf(4); - - const size_t buf_size = buf.size(); - - send_buf[0] = msg.type(); - - for(size_t i = 1; i != 4; ++i) - send_buf[i] = get_byte<u32bit>(i, buf_size); - - send_buf += buf; - - send(HANDSHAKE, &send_buf[0], send_buf.size()); - - return send_buf; - } - /* * Send one or more records to the other side */ diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 3b65b39f1..d34fa5e70 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -9,7 +9,7 @@ #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/stl_util.h> namespace Botan { @@ -19,7 +19,7 @@ namespace TLS { /* * Create a new Server Hello message */ -Server_Hello::Server_Hello(Record_Writer& writer, +Server_Hello::Server_Hello(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<byte>& session_id, Protocol_Version ver, @@ -149,7 +149,7 @@ std::vector<byte> Server_Hello::serialize() const /* * Create a new Server Hello Done message */ -Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, +Server_Hello_Done::Server_Hello_Done(Handshake_Writer& writer, Handshake_Hash& hash) { hash.update(writer.send(*this)); diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 834dff979..423497976 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -8,7 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/loadstor.h> @@ -27,7 +27,7 @@ namespace TLS { /** * Create a new Server Key Exchange message */ -Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, +Server_Key_Exchange::Server_Key_Exchange(Handshake_Writer& writer, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp index 8cee2a454..3affe8fcf 100644 --- a/src/tls/session_ticket.cpp +++ b/src/tls/session_ticket.cpp @@ -8,14 +8,14 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> -#include <botan/tls_record.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/loadstor.h> namespace Botan { namespace TLS { -New_Session_Ticket::New_Session_Ticket(Record_Writer& writer, +New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<byte>& ticket, u32bit lifetime) : @@ -25,7 +25,7 @@ New_Session_Ticket::New_Session_Ticket(Record_Writer& writer, hash.update(writer.send(*this)); } -New_Session_Ticket::New_Session_Ticket(Record_Writer& writer, +New_Session_Ticket::New_Session_Ticket(Handshake_Writer& writer, Handshake_Hash& hash) : m_ticket_lifetime_hint(0) { diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 84ee69e04..d77f6dbcf 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -177,8 +177,8 @@ void Channel::read_handshake(byte rec_type, if(rec_type == HANDSHAKE) { if(!m_state) - m_state = new Handshake_State(this->new_handshake_reader()); - m_state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size()); + m_state = new_handshake_state(); + m_state->handshake_reader().add_input(&rec_buf[0], rec_buf.size()); } BOTAN_ASSERT_NONNULL(m_state); @@ -189,10 +189,10 @@ void Channel::read_handshake(byte rec_type, if(rec_type == HANDSHAKE) { - if(m_state->handshake_reader()->have_full_record()) + if(m_state->handshake_reader().have_full_record()) { std::pair<Handshake_Type, std::vector<byte> > msg = - m_state->handshake_reader()->get_next_record(); + m_state->handshake_reader().get_next_record(); process_handshake_msg(msg.first, msg.second); } else @@ -200,7 +200,7 @@ void Channel::read_handshake(byte rec_type, } else if(rec_type == CHANGE_CIPHER_SPEC) { - if(m_state->handshake_reader()->empty() && rec_buf.size() == 1 && rec_buf[0] == 1) + if(m_state->handshake_reader().empty() && rec_buf.size() == 1 && rec_buf[0] == 1) process_handshake_msg(HANDSHAKE_CCS, std::vector<byte>()); else throw Decoding_Error("Malformed ChangeCipherSpec message"); @@ -208,7 +208,7 @@ void Channel::read_handshake(byte rec_type, else throw Decoding_Error("Unknown message type in handshake processing"); - if(type == HANDSHAKE_CCS || !m_state || !m_state->handshake_reader()->have_full_record()) + if(type == HANDSHAKE_CCS || !m_state || !m_state->handshake_reader().have_full_record()) break; } } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index ae4108e84..bd81a1745 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -111,7 +111,7 @@ class BOTAN_DLL Channel virtual void alert_notify(const Alert& alert) = 0; - virtual class Handshake_Reader* new_handshake_reader() const = 0; + virtual class Handshake_State* new_handshake_state() = 0; class Secure_Renegotiation_State { diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 471cbefed..a62bcbba5 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -37,7 +37,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, { m_writer.set_version(Protocol_Version::SSL_V3); - m_state = new Handshake_State(this->new_handshake_reader()); + m_state = new_handshake_state(); m_state->set_expected_next(SERVER_HELLO); m_state->client_npn_cb = next_protocol; @@ -54,7 +54,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, if(session_info.srp_identifier() == srp_identifier) { m_state->client_hello = new Client_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, m_policy, m_rng, @@ -70,7 +70,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, if(!m_state->client_hello) // not resuming { m_state->client_hello = new Client_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, m_policy.pref_version(), m_policy, @@ -84,9 +84,10 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, m_secure_renegotiation.update(m_state->client_hello); } -Handshake_Reader* Client::new_handshake_reader() const +Handshake_State* Client::new_handshake_state() { - return new Stream_Handshake_Reader; + return new Handshake_State(new Stream_Handshake_Reader, + new Stream_Handshake_Writer(m_writer)); } /* @@ -98,7 +99,7 @@ void Client::renegotiate(bool force_full_renegotiation) return; // currently in active handshake delete m_state; - m_state = new Handshake_State(this->new_handshake_reader()); + m_state = new_handshake_state(); m_state->set_expected_next(SERVER_HELLO); @@ -108,7 +109,7 @@ void Client::renegotiate(bool force_full_renegotiation) if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info)) { m_state->client_hello = new Client_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, m_policy, m_rng, @@ -122,7 +123,7 @@ void Client::renegotiate(bool force_full_renegotiation) if(!m_state->client_hello) { m_state->client_hello = new Client_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, m_reader.get_version(), m_policy, @@ -367,13 +368,13 @@ void Client::process_handshake_msg(Handshake_Type type, "tls-client", m_hostname); - m_state->client_certs = new Certificate(m_writer, + m_state->client_certs = new Certificate(m_state->handshake_writer(), m_state->hash, client_certs); } m_state->client_kex = - new Client_Key_Exchange(m_writer, + new Client_Key_Exchange(m_state->handshake_writer(), m_state, m_policy, m_creds, @@ -393,7 +394,7 @@ void Client::process_handshake_msg(Handshake_Type type, "tls-client", m_hostname); - m_state->client_verify = new Certificate_Verify(m_writer, + m_state->client_verify = new Certificate_Verify(m_state->handshake_writer(), m_state, m_policy, m_rng, @@ -410,10 +411,10 @@ void Client::process_handshake_msg(Handshake_Type type, const std::string protocol = m_state->client_npn_cb(m_state->server_hello->next_protocols()); - m_state->next_protocol = new Next_Protocol(m_writer, m_state->hash, protocol); + m_state->next_protocol = new Next_Protocol(m_state->handshake_writer(), m_state->hash, protocol); } - m_state->client_finished = new Finished(m_writer, m_state, CLIENT); + m_state->client_finished = new Finished(m_state->handshake_writer(), m_state, CLIENT); if(m_state->server_hello->supports_session_ticket()) m_state->set_expected_next(NEW_SESSION_TICKET); @@ -452,7 +453,8 @@ void Client::process_handshake_msg(Handshake_Type type, m_writer.activate(CLIENT, m_state->suite, m_state->keys, m_state->server_hello->compression_method()); - m_state->client_finished = new Finished(m_writer, m_state, CLIENT); + m_state->client_finished = new Finished(m_state->handshake_writer(), + m_state, CLIENT); } m_secure_renegotiation.update(m_state->client_finished, m_state->server_finished); diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index b62e4aadf..ad13a94dc 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -73,7 +73,7 @@ class BOTAN_DLL Client : public Channel void alert_notify(const Alert& alert) override; - class Handshake_Reader* new_handshake_reader() const override; + class Handshake_State* new_handshake_state() override; const Policy& m_policy; RandomNumberGenerator& m_rng; diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 8bb251b73..304366719 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -85,33 +85,12 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_Reader* reader) +Handshake_State::Handshake_State(Handshake_Reader* reader, + Handshake_Writer* writer) : + m_handshake_reader(reader), + m_handshake_writer(writer), + m_version(Protocol_Version::SSL_V3) { - client_hello = nullptr; - server_hello = nullptr; - server_certs = nullptr; - server_kex = nullptr; - cert_req = nullptr; - server_hello_done = nullptr; - next_protocol = nullptr; - new_session_ticket = nullptr; - - client_certs = nullptr; - client_kex = nullptr; - client_verify = nullptr; - client_finished = nullptr; - server_finished = nullptr; - - m_handshake_reader = reader; - - server_rsa_kex_key = nullptr; - - m_version = Protocol_Version::SSL_V3; - - hand_expecting_mask = 0; - hand_received_mask = 0; - - allow_session_resumption = true; } void Handshake_State::set_version(const Protocol_Version& version) @@ -123,33 +102,33 @@ void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) { const u32bit mask = bitmask_for_handshake_type(handshake_msg); - hand_received_mask |= mask; + m_hand_received_mask |= mask; - const bool ok = (hand_expecting_mask & mask); // overlap? + const bool ok = (m_hand_expecting_mask & mask); // overlap? if(!ok) throw Unexpected_Message("Unexpected state transition in handshake, got " + std::to_string(handshake_msg) + - " expected " + std::to_string(hand_expecting_mask) + - " received " + std::to_string(hand_received_mask)); + " expected " + std::to_string(m_hand_expecting_mask) + + " received " + std::to_string(m_hand_received_mask)); /* We don't know what to expect next, so force a call to set_expected_next; if it doesn't happen, the next transition check will always fail which is what we want. */ - hand_expecting_mask = 0; + m_hand_expecting_mask = 0; } void Handshake_State::set_expected_next(Handshake_Type handshake_msg) { - hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg); + m_hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg); } bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const { const u32bit mask = bitmask_for_handshake_type(handshake_msg); - return (hand_received_mask & mask); + return (m_hand_received_mask & mask); } std::string Handshake_State::srp_identifier() const @@ -370,6 +349,7 @@ Handshake_State::~Handshake_State() delete server_finished; delete m_handshake_reader; + delete m_handshake_writer; } } diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 521da0205..0f48c976b 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -10,6 +10,7 @@ #include <botan/internal/tls_handshake_hash.h> #include <botan/internal/tls_handshake_reader.h> +#include <botan/internal/tls_handshake_writer.h> #include <botan/internal/tls_session_key.h> #include <botan/pk_keys.h> #include <botan/pubkey.h> @@ -31,7 +32,9 @@ class Policy; class Handshake_State { public: - Handshake_State(Handshake_Reader* reader); + Handshake_State(Handshake_Reader* reader, + Handshake_Writer* writer); + ~Handshake_State(); Handshake_State(const Handshake_State&) = delete; @@ -65,25 +68,25 @@ class Handshake_State void set_version(const Protocol_Version& version); - class Client_Hello* client_hello; - class Server_Hello* server_hello; - class Certificate* server_certs; - class Server_Key_Exchange* server_kex; - class Certificate_Req* cert_req; - class Server_Hello_Done* server_hello_done; + class Client_Hello* client_hello = nullptr; + class Server_Hello* server_hello = nullptr; + class Certificate* server_certs = nullptr; + class Server_Key_Exchange* server_kex = nullptr; + class Certificate_Req* cert_req = nullptr; + class Server_Hello_Done* server_hello_done = nullptr; - class Certificate* client_certs; - class Client_Key_Exchange* client_kex; - class Certificate_Verify* client_verify; + class Certificate* client_certs = nullptr; + class Client_Key_Exchange* client_kex = nullptr; + class Certificate_Verify* client_verify = nullptr; - class Next_Protocol* next_protocol; - class New_Session_Ticket* new_session_ticket; + class Next_Protocol* next_protocol = nullptr; + class New_Session_Ticket* new_session_ticket = nullptr; - class Finished* client_finished; - class Finished* server_finished; + class Finished* client_finished = nullptr; + class Finished* server_finished = nullptr; // Used by the server only, in case of RSA key exchange - Private_Key* server_rsa_kex_key; + Private_Key* server_rsa_kex_key = nullptr; Ciphersuite suite; Session_Keys keys; @@ -95,19 +98,25 @@ class Handshake_State secure_vector<byte> resume_master_secret; /* - * + * Used by the server to know if resumption should be allowed on + * a server-initiated renegotiation */ - bool allow_session_resumption; + bool allow_session_resumption = true; /** * Used by client using NPN */ std::function<std::string (std::vector<std::string>)> client_npn_cb; - Handshake_Reader* handshake_reader() { return m_handshake_reader; } + Handshake_Reader& handshake_reader() { return *m_handshake_reader; } + + Handshake_Writer& handshake_writer() { return *m_handshake_writer; } private: - Handshake_Reader* m_handshake_reader; - u32bit hand_expecting_mask, hand_received_mask; + Handshake_Reader* m_handshake_reader = nullptr; + Handshake_Writer* m_handshake_writer = nullptr; + + u32bit m_hand_expecting_mask = 0; + u32bit m_hand_received_mask = 0; Protocol_Version m_version; }; diff --git a/src/tls/tls_handshake_writer.cpp b/src/tls/tls_handshake_writer.cpp new file mode 100644 index 000000000..b237e8f3a --- /dev/null +++ b/src/tls/tls_handshake_writer.cpp @@ -0,0 +1,38 @@ +/* +* Handshake Message Writer +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_handshake_writer.h> +#include <botan/internal/tls_messages.h> +#include <botan/tls_record.h> +#include <botan/exceptn.h> + +namespace Botan { + +namespace TLS { + +std::vector<byte> Stream_Handshake_Writer::send(Handshake_Message& msg) + { + const std::vector<byte> buf = msg.serialize(); + std::vector<byte> send_buf(4); + + const size_t buf_size = buf.size(); + + send_buf[0] = msg.type(); + + for(size_t i = 1; i != 4; ++i) + send_buf[i] = get_byte<u32bit>(i, buf_size); + + send_buf += buf; + + m_writer.send(HANDSHAKE, &send_buf[0], send_buf.size()); + + return send_buf; + } + +} + +} diff --git a/src/tls/tls_handshake_writer.h b/src/tls/tls_handshake_writer.h new file mode 100644 index 000000000..0d6ddb0a0 --- /dev/null +++ b/src/tls/tls_handshake_writer.h @@ -0,0 +1,52 @@ +/* +* TLS Handshake Writer +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_HANDSHAKE_WRITER_H__ +#define BOTAN_TLS_HANDSHAKE_WRITER_H__ + +#include <botan/tls_magic.h> +#include <botan/loadstor.h> +#include <vector> +#include <deque> +#include <utility> + +namespace Botan { + +namespace TLS { + +class Record_Writer; +class Handshake_Message; + +/** +* Handshake Writer +*/ +class Handshake_Writer + { + public: + virtual std::vector<byte> send(Handshake_Message& msg) = 0; + + virtual ~Handshake_Writer() {} + }; + +/** +* Stream Handshake Writer +*/ +class Stream_Handshake_Writer : public Handshake_Writer + { + public: + Stream_Handshake_Writer(Record_Writer& writer) : m_writer(writer) {} + + std::vector<byte> send(Handshake_Message& msg) override; + private: + Record_Writer& m_writer; + }; + +} + +} + +#endif diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 2e8bf9ba3..a0e7d8630 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -25,8 +25,7 @@ class SRP6_Server_Session; namespace TLS { -class Record_Writer; -class Record_Reader; +class Handshake_Writer; /** * TLS Handshake Message Base Class @@ -113,7 +112,7 @@ class Client_Hello : public Handshake_Message bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; } - Client_Hello(Record_Writer& writer, + Client_Hello(Handshake_Writer& writer, Handshake_Hash& hash, Protocol_Version version, const Policy& policy, @@ -123,7 +122,7 @@ class Client_Hello : public Handshake_Message const std::string& hostname = "", const std::string& srp_identifier = ""); - Client_Hello(Record_Writer& writer, + Client_Hello(Handshake_Writer& writer, Handshake_Hash& hash, const Policy& policy, RandomNumberGenerator& rng, @@ -197,7 +196,7 @@ class Server_Hello : public Handshake_Message bool peer_can_send_heartbeats() const { return m_peer_can_send_heartbeats; } - Server_Hello(Record_Writer& writer, + Server_Hello(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<byte>& session_id, Protocol_Version ver, @@ -244,7 +243,7 @@ class Client_Key_Exchange : public Handshake_Message const secure_vector<byte>& pre_master_secret() const { return pre_master; } - Client_Key_Exchange(Record_Writer& output, + Client_Key_Exchange(Handshake_Writer& output, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -277,7 +276,7 @@ class Certificate : public Handshake_Message size_t count() const { return m_certs.size(); } bool empty() const { return m_certs.empty(); } - Certificate(Record_Writer& writer, + Certificate(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); @@ -304,7 +303,7 @@ class Certificate_Req : public Handshake_Message std::vector<std::pair<std::string, std::string> > supported_algos() const { return m_supported_algos; } - Certificate_Req(Record_Writer& writer, + Certificate_Req(Handshake_Writer& writer, Handshake_Hash& hash, const Policy& policy, const std::vector<X509_Certificate>& allowed_cas, @@ -337,7 +336,7 @@ class Certificate_Verify : public Handshake_Message bool verify(const X509_Certificate& cert, Handshake_State* state); - Certificate_Verify(Record_Writer& writer, + Certificate_Verify(Handshake_Writer& writer, Handshake_State* state, const Policy& policy, RandomNumberGenerator& rng, @@ -367,7 +366,7 @@ class Finished : public Handshake_Message bool verify(Handshake_State* state, Connection_Side side); - Finished(Record_Writer& writer, + Finished(Handshake_Writer& writer, Handshake_State* state, Connection_Side side); @@ -387,7 +386,7 @@ class Hello_Request : public Handshake_Message public: Handshake_Type type() const { return HELLO_REQUEST; } - Hello_Request(Record_Writer& writer); + Hello_Request(Handshake_Writer& writer); Hello_Request(const std::vector<byte>& buf); private: std::vector<byte> serialize() const; @@ -412,7 +411,7 @@ class Server_Key_Exchange : public Handshake_Message // Only valid for SRP negotiation SRP6_Server_Session& server_srp_params(); - Server_Key_Exchange(Record_Writer& writer, + Server_Key_Exchange(Handshake_Writer& writer, Handshake_State* state, const Policy& policy, Credentials_Manager& creds, @@ -446,7 +445,7 @@ class Server_Hello_Done : public Handshake_Message public: Handshake_Type type() const { return SERVER_HELLO_DONE; } - Server_Hello_Done(Record_Writer& writer, Handshake_Hash& hash); + Server_Hello_Done(Handshake_Writer& writer, Handshake_Hash& hash); Server_Hello_Done(const std::vector<byte>& buf); private: std::vector<byte> serialize() const; @@ -462,7 +461,7 @@ class Next_Protocol : public Handshake_Message std::string protocol() const { return m_protocol; } - Next_Protocol(Record_Writer& writer, + Next_Protocol(Handshake_Writer& writer, Handshake_Hash& hash, const std::string& protocol); @@ -481,12 +480,12 @@ class New_Session_Ticket : public Handshake_Message u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; } const std::vector<byte>& ticket() const { return m_ticket; } - New_Session_Ticket(Record_Writer& writer, + New_Session_Ticket(Handshake_Writer& writer, Handshake_Hash& hash, const std::vector<byte>& ticket, u32bit lifetime); - New_Session_Ticket(Record_Writer& writer, + New_Session_Ticket(Handshake_Writer& writer, Handshake_Hash& hash); New_Session_Ticket(const std::vector<byte>& buf); diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp index 99ac66369..76492a668 100644 --- a/src/tls/tls_policy.cpp +++ b/src/tls/tls_policy.cpp @@ -83,11 +83,11 @@ std::vector<std::string> Policy::allowed_ecc_curves() const "secp256k1", "secp224r1", "secp224k1", - //"secp192r1", - //"secp192k1", - //"secp160r2", - //"secp160r1", - //"secp160k1", + "secp192r1", + "secp192k1", + "secp160r2", + "secp160r1", + "secp160k1", }); } diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 0b67f9a63..924d25f80 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -35,8 +35,6 @@ class BOTAN_DLL Record_Writer void send(byte type, const std::vector<byte>& input) { send(type, &input[0], input.size()); } - std::vector<byte> send(class Handshake_Message& msg); - void send_alert(const Alert& alert); void activate(Connection_Side side, diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 97db6934e..d6d408db5 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -200,9 +200,10 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, { } -Handshake_Reader* Server::new_handshake_reader() const +Handshake_State* Server::new_handshake_state() { - return new Stream_Handshake_Reader; + return new Handshake_State(new Stream_Handshake_Reader, + new Stream_Handshake_Writer(m_writer)); } /* @@ -213,11 +214,11 @@ void Server::renegotiate(bool force_full_renegotiation) if(m_state) return; // currently in handshake - m_state = new Handshake_State(this->new_handshake_reader()); + m_state = new_handshake_state(); m_state->allow_session_resumption = !force_full_renegotiation; m_state->set_expected_next(CLIENT_HELLO); - Hello_Request hello_req(m_writer); + Hello_Request hello_req(m_state->handshake_writer()); } void Server::alert_notify(const Alert& alert) @@ -240,7 +241,7 @@ void Server::read_handshake(byte rec_type, { if(rec_type == HANDSHAKE && !m_state) { - m_state = new Handshake_State(this->new_handshake_reader()); + m_state = new_handshake_state(); m_state->set_expected_next(CLIENT_HELLO); } @@ -368,7 +369,7 @@ void Server::process_handshake_msg(Handshake_Type type, // resume session m_state->server_hello = new Server_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, m_state->client_hello->session_id(), Protocol_Version(session_info.version()), @@ -402,7 +403,11 @@ void Server::process_handshake_msg(Handshake_Type type, m_session_manager.remove_entry(session_info.session_id()); if(m_state->server_hello->supports_session_ticket()) // send an empty ticket - m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash); + { + m_state->new_session_ticket = + new New_Session_Ticket(m_state->handshake_writer(), + m_state->hash); + } } if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket) @@ -412,14 +417,19 @@ void Server::process_handshake_msg(Handshake_Type type, const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = - new New_Session_Ticket(m_writer, m_state->hash, + new New_Session_Ticket(m_state->handshake_writer(), + m_state->hash, session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()); } catch(...) {} if(!m_state->new_session_ticket) - m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash); + { + m_state->new_session_ticket = + new New_Session_Ticket(m_state->handshake_writer(), + m_state->hash); + } } m_writer.send(CHANGE_CIPHER_SPEC, 1); @@ -427,7 +437,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_writer.activate(SERVER, m_state->suite, m_state->keys, m_state->server_hello->compression_method()); - m_state->server_finished = new Finished(m_writer, m_state, SERVER); + m_state->server_finished = new Finished(m_state->handshake_writer(), m_state, SERVER); m_state->set_expected_next(HANDSHAKE_CCS); } @@ -453,7 +463,7 @@ void Server::process_handshake_msg(Handshake_Type type, } m_state->server_hello = new Server_Hello( - m_writer, + m_state->handshake_writer(), m_state->hash, make_hello_random(m_rng), // new session ID m_state->version(), @@ -486,9 +496,9 @@ void Server::process_handshake_msg(Handshake_Type type, BOTAN_ASSERT(!cert_chains[sig_algo].empty(), "Attempting to send empty certificate chain"); - m_state->server_certs = new Certificate(m_writer, - m_state->hash, - cert_chains[sig_algo]); + m_state->server_certs = new Certificate(m_state->handshake_writer(), + m_state->hash, + cert_chains[sig_algo]); } Private_Key* private_key = nullptr; @@ -511,7 +521,12 @@ void Server::process_handshake_msg(Handshake_Type type, else { m_state->server_kex = - new Server_Key_Exchange(m_writer, m_state, m_policy, m_creds, m_rng, private_key); + new Server_Key_Exchange(m_state->handshake_writer(), + m_state, + m_policy, + m_creds, + m_rng, + private_key); } std::vector<X509_Certificate> client_auth_CAs = @@ -519,11 +534,11 @@ void Server::process_handshake_msg(Handshake_Type type, if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "") { - m_state->cert_req = new Certificate_Req(m_writer, - m_state->hash, - m_policy, - client_auth_CAs, - m_state->version()); + m_state->cert_req = new Certificate_Req(m_state->handshake_writer(), + m_state->hash, + m_policy, + client_auth_CAs, + m_state->version()); m_state->set_expected_next(CERTIFICATE); } @@ -535,7 +550,8 @@ void Server::process_handshake_msg(Handshake_Type type, */ m_state->set_expected_next(CLIENT_KEX); - m_state->server_hello_done = new Server_Hello_Done(m_writer, m_state->hash); + m_state->server_hello_done = new Server_Hello_Done(m_state->handshake_writer(), + m_state->hash); } } else if(type == CERTIFICATE) @@ -643,7 +659,8 @@ void Server::process_handshake_msg(Handshake_Type type, const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = - new New_Session_Ticket(m_writer, m_state->hash, + new New_Session_Ticket(m_state->handshake_writer(), + m_state->hash, session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()); } @@ -654,14 +671,18 @@ void Server::process_handshake_msg(Handshake_Type type, } if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket) - m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash); + { + m_state->new_session_ticket = new New_Session_Ticket(m_state->handshake_writer(), + m_state->hash); + + } m_writer.send(CHANGE_CIPHER_SPEC, 1); m_writer.activate(SERVER, m_state->suite, m_state->keys, m_state->server_hello->compression_method()); - m_state->server_finished = new Finished(m_writer, m_state, SERVER); + m_state->server_finished = new Finished(m_state->handshake_writer(), m_state, SERVER); } m_secure_renegotiation.update(m_state->client_finished, diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 2d1502e1f..c0e687604 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -56,7 +56,7 @@ class BOTAN_DLL Server : public Channel void alert_notify(const Alert& alert) override; - class Handshake_Reader* new_handshake_reader() const override; + class Handshake_State* new_handshake_state() override; const Policy& m_policy; RandomNumberGenerator& m_rng; |