From 6925789077edbeab13540cde3cf84b1d0e6feefc Mon Sep 17 00:00:00 2001 From: lloyd Date: Fri, 7 Sep 2012 16:49:44 +0000 Subject: Keep two handshake states around, swap them when Channel::activate_session is called. --- src/tls/tls_channel.cpp | 117 ++++++++++++++++++++++++------------------------ src/tls/tls_channel.h | 16 +++---- src/tls/tls_client.cpp | 2 +- src/tls/tls_server.cpp | 2 +- 4 files changed, 69 insertions(+), 68 deletions(-) (limited to 'src') diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 4a37f1c7a..62fc47c45 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -40,28 +40,28 @@ Channel::~Channel() Handshake_State& Channel::create_handshake_state() { - if(m_state) + if(m_pending_state) throw Internal_Error("create_handshake_state called during handshake"); - m_state.reset(new_handshake_state()); + m_pending_state.reset(new_handshake_state()); - return *m_state.get(); + return *m_pending_state.get(); } void Channel::renegotiate(bool force_full_renegotiation) { - if(m_state) // currently in handshake? + if(m_pending_state) // currently in handshake? return; - m_state.reset(new_handshake_state()); + m_pending_state.reset(new_handshake_state()); - initiate_handshake(*m_state.get(), force_full_renegotiation); + initiate_handshake(*m_pending_state.get(), force_full_renegotiation); } void Channel::set_protocol_version(Protocol_Version version) { m_current_version = version; - m_state->set_version(version); + m_pending_state->set_version(version); } void Channel::set_maximum_fragment_size(size_t max_fragment) @@ -74,7 +74,10 @@ void Channel::set_maximum_fragment_size(size_t max_fragment) void Channel::change_cipher_spec_reader(Connection_Side side) { - if(m_state->server_hello()->compression_method()!= NO_COMPRESSION) + BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(), + "Have received server hello"); + + if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); m_read_seq_no = 0; @@ -85,14 +88,17 @@ void Channel::change_cipher_spec_reader(Connection_Side side) m_read_cipherstate.reset( new Connection_Cipher_State(current_protocol_version(), side, - m_state->ciphersuite(), - m_state->session_keys()) + m_pending_state->ciphersuite(), + m_pending_state->session_keys()) ); } void Channel::change_cipher_spec_writer(Connection_Side side) { - if(m_state->server_hello()->compression_method()!= NO_COMPRESSION) + BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(), + "Have received server hello"); + + if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); /* @@ -111,25 +117,32 @@ void Channel::change_cipher_spec_writer(Connection_Side side) m_write_cipherstate.reset( new Connection_Cipher_State(current_protocol_version(), side, - m_state->ciphersuite(), - m_state->session_keys()) + m_pending_state->ciphersuite(), + m_pending_state->session_keys()) ); } -void Channel::activate_session(const std::vector& session_id) +void Channel::activate_session() { - m_secure_renegotiation.update(m_state->client_finished(), - m_state->server_finished()); + m_secure_renegotiation.update(m_pending_state->client_finished(), + m_pending_state->server_finished()); - m_state.reset(); - m_handshake_completed = true; - m_active_session = session_id; + std::swap(m_active_state, m_pending_state); + m_pending_state.reset(); } -void Channel::heartbeat_support(bool peer_supports, bool sending_allowed) +bool Channel::peer_supports_heartbeats() const { - m_peer_supports_heartbeats = peer_supports; - m_heartbeat_sending_allowed = sending_allowed; + if(m_active_state && m_active_state->server_hello()) + return m_active_state->server_hello()->supports_heartbeats(); + return false; + } + +bool Channel::heartbeat_sending_allowed() const + { + if(m_active_state && m_active_state->server_hello()) + return m_active_state->server_hello()->peer_can_send_heartbeats(); + return false; } size_t Channel::received_data(const byte buf[], size_t buf_size) @@ -177,25 +190,23 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) { - if(!m_state) - m_state.reset(new_handshake_state()); + if(!m_pending_state) + m_pending_state.reset(new_handshake_state()); - m_state->handshake_io().add_input(rec_type, - &record[0], - record.size(), - record_number); + m_pending_state->handshake_io().add_input( + rec_type, &record[0], record.size(), record_number); - while(m_state) + while(m_pending_state) { - auto msg = m_state->get_next_handshake_msg(); + auto msg = m_pending_state->get_next_handshake_msg(); if(msg.first == HANDSHAKE_NONE) // no full handshake yet break; - process_handshake_msg(*m_state.get(), msg.first, msg.second); + process_handshake_msg(*m_pending_state.get(), msg.first, msg.second); } } - else if(rec_type == HEARTBEAT && m_peer_supports_heartbeats) + else if(rec_type == HEARTBEAT && peer_supports_heartbeats()) { Heartbeat_Message heartbeat(record); @@ -203,7 +214,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(heartbeat.is_request()) { - if(!m_state) // no heartbeats during handshake + if(!m_pending_state) // no heartbeats during handshake { Heartbeat_Message response(Heartbeat_Message::RESPONSE, &payload[0], payload.size()); @@ -219,7 +230,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } else if(rec_type == APPLICATION_DATA) { - if(m_handshake_completed) + if(m_active_state) { /* * OpenSSL among others sends empty records in versions @@ -239,10 +250,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) Alert alert_msg(record); if(alert_msg.type() == Alert::NO_RENEGOTIATION) - { - if(m_handshake_completed && m_state) - m_state.reset(); - } + m_pending_state.reset(); m_proc_fn(nullptr, 0, alert_msg); @@ -257,15 +265,13 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) { // delete state immediately - if(!m_active_session.empty()) - { - m_session_manager.remove_entry(m_active_session); - m_active_session.clear(); - } + if(m_active_state && m_active_state->server_hello()) + m_session_manager.remove_entry(m_active_state->server_hello()->session_id()); m_connection_closed = true; - m_state.reset(); + m_active_state.reset(); + m_pending_state.reset(); m_write_cipherstate.reset(); m_read_cipherstate.reset(); @@ -302,10 +308,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) void Channel::heartbeat(const byte payload[], size_t payload_size) { - if(!is_active()) - throw std::runtime_error("Heartbeat cannot be sent on inactive TLS connection"); - - if(m_heartbeat_sending_allowed) + if(heartbeat_sending_allowed()) { Heartbeat_Message heartbeat(Heartbeat_Message::REQUEST, payload, payload_size); @@ -362,10 +365,10 @@ void Channel::write_record(byte record_type, const byte input[], size_t length) Protocol_Version record_version = current_protocol_version(); if(!record_version.valid()) { - BOTAN_ASSERT(m_state && !m_state->server_hello(), + BOTAN_ASSERT(m_pending_state && !m_pending_state->server_hello(), "In first record of client connection"); - record_version = m_state->handshake_io().initial_record_version(); + record_version = m_pending_state->handshake_io().initial_record_version(); } TLS::write_record(m_writebuf, @@ -406,19 +409,17 @@ void Channel::send_alert(const Alert& alert) } if(alert.type() == Alert::NO_RENEGOTIATION) - m_state.reset(); + m_pending_state.reset(); - if(alert.is_fatal() && !m_active_session.empty()) - { - m_session_manager.remove_entry(m_active_session); - m_active_session.clear(); - } + if(alert.is_fatal() && m_active_state && m_active_state->server_hello()) + m_session_manager.remove_entry(m_active_state->server_hello()->session_id()); - if(!m_connection_closed && (alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal())) + if(alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal()) { m_connection_closed = true; - m_state.reset(); + m_active_state.reset(); + m_pending_state.reset(); m_write_cipherstate.reset(); } } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 29567da4c..109fbc771 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -52,7 +52,7 @@ class BOTAN_DLL Channel /** * @return true iff the connection is active for sending application data */ - bool is_active() const { return m_handshake_completed && !is_closed(); } + bool is_active() const { return m_active_state && !is_closed(); } /** * @return true iff the connection has been definitely closed @@ -114,7 +114,7 @@ class BOTAN_DLL Channel */ void send_alert(const Alert& alert); - void activate_session(const std::vector& session_id); + void activate_session(); void heartbeat_support(bool peer_supports, bool allowed_to_send); @@ -174,6 +174,10 @@ class BOTAN_DLL Channel void write_record(byte type, const byte input[], size_t length); + bool peer_supports_heartbeats() const; + + bool heartbeat_sending_allowed() const; + /* callbacks */ std::function m_proc_fn; std::function m_output_fn; @@ -190,17 +194,13 @@ class BOTAN_DLL Channel u64bit m_read_seq_no = 0; /* connection parameters */ - std::unique_ptr m_state; + std::unique_ptr m_active_state; + std::unique_ptr m_pending_state; Protocol_Version m_current_version; size_t m_max_fragment = MAX_PLAINTEXT_SIZE; - bool m_peer_supports_heartbeats = false; - bool m_heartbeat_sending_allowed = false; - bool m_connection_closed = false; - bool m_handshake_completed = false; - std::vector m_active_session; }; } diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index c88d0319a..0fdc24f59 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -516,7 +516,7 @@ void Client::process_handshake_msg(Handshake_State& state, m_session_manager.remove_entry(session_info.session_id()); } - activate_session(session_info.session_id()); + activate_session(); } else throw Unexpected_Message("Unknown handshake message received"); diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 5cd00c524..313640008 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -703,7 +703,7 @@ void Server::process_handshake_msg(Handshake_State& state, ); } - activate_session(state.server_hello()->session_id()); + activate_session(); } else throw Unexpected_Message("Unknown handshake message received"); -- cgit v1.2.3