diff options
author | lloyd <[email protected]> | 2012-11-06 21:29:02 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-11-06 21:29:02 +0000 |
commit | 38af5416771e39a1b0c2f628cf5ff01790922bee (patch) | |
tree | ad4407d1cc1e6f43d3a6621a0cbb8acfa74ca3d9 /src/tls | |
parent | dfbe4b328fa29b80f05bf89dd6c20be304312b17 (diff) |
Add Channel::pending_state and Channel::active_state, use where possible
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_channel.cpp | 153 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 4 |
2 files changed, 84 insertions, 73 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 7065064cc..46c5f4c74 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -76,21 +76,21 @@ std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_current() c std::vector<X509_Certificate> Channel::peer_cert_chain() const { - if(!m_active_state) - return std::vector<X509_Certificate>(); - return get_peer_cert_chain(*m_active_state); + if(auto active = active_state()) + return get_peer_cert_chain(*active); + return std::vector<X509_Certificate>(); } Handshake_State& Channel::create_handshake_state(Protocol_Version version) { const size_t dtls_mtu = 1400; // fixme should be settable - if(m_pending_state) + if(pending_state()) throw Internal_Error("create_handshake_state called during handshake"); - if(m_active_state) + if(auto active = active_state()) { - Protocol_Version active_version = m_active_state->version(); + Protocol_Version active_version = active->version(); if(active_version.is_datagram_protocol() != version.is_datagram_protocol()) throw std::runtime_error("Active state using version " + @@ -120,22 +120,22 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) m_pending_state.reset(new_handshake_state(io.release())); - if(m_active_state) - m_pending_state->set_version(m_active_state->version()); + if(auto active = active_state()) + m_pending_state->set_version(active->version()); return *m_pending_state.get(); } void Channel::renegotiate(bool force_full_renegotiation) { - if(m_pending_state) // currently in handshake? + if(pending_state()) // currently in handshake? return; - if(!m_active_state) + if(auto active = active_state()) + initiate_handshake(create_handshake_state(active->version()), + force_full_renegotiation); + else throw std::runtime_error("Cannot renegotiate on inactive connection"); - - initiate_handshake(create_handshake_state(m_active_state->version()), - force_full_renegotiation); } void Channel::set_maximum_fragment_size(size_t max_fragment) @@ -148,10 +148,12 @@ void Channel::set_maximum_fragment_size(size_t max_fragment) void Channel::change_cipher_spec_reader(Connection_Side side) { - BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(), + auto pending = pending_state(); + + BOTAN_ASSERT(pending && pending->server_hello(), "Have received server hello"); - if(m_pending_state->server_hello()->compression_method() != NO_COMPRESSION) + if(pending->server_hello()->compression_method() != NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); sequence_numbers().new_read_cipher_state(); @@ -163,20 +165,22 @@ void Channel::change_cipher_spec_reader(Connection_Side side) // flip side as we are reading std::shared_ptr<Connection_Cipher_State> read_state( - new Connection_Cipher_State(m_pending_state->version(), + new Connection_Cipher_State(pending->version(), (side == CLIENT) ? SERVER : CLIENT, - m_pending_state->ciphersuite(), - m_pending_state->session_keys())); + pending->ciphersuite(), + pending->session_keys())); m_read_cipher_states[epoch] = read_state; } void Channel::change_cipher_spec_writer(Connection_Side side) { - BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(), + auto pending = pending_state(); + + BOTAN_ASSERT(pending && pending->server_hello(), "Have received server hello"); - if(m_pending_state->server_hello()->compression_method() != NO_COMPRESSION) + if(pending->server_hello()->compression_method() != NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); sequence_numbers().new_write_cipher_state(); @@ -187,17 +191,17 @@ void Channel::change_cipher_spec_writer(Connection_Side side) "No write cipher state currently set for next epoch"); std::shared_ptr<Connection_Cipher_State> write_state( - new Connection_Cipher_State(m_pending_state->version(), + new Connection_Cipher_State(pending->version(), side, - m_pending_state->ciphersuite(), - m_pending_state->session_keys())); + pending->ciphersuite(), + pending->session_keys())); m_write_cipher_states[epoch] = write_state; } bool Channel::is_active() const { - return m_active_state.get(); + return (active_state() != nullptr); } bool Channel::is_closed() const @@ -248,15 +252,15 @@ u16bit Channel::get_last_valid_epoch() const bool Channel::peer_supports_heartbeats() const { - if(m_active_state && m_active_state->server_hello()) - return m_active_state->server_hello()->supports_heartbeats(); + if(auto active = active_state()) + return active->server_hello()->supports_heartbeats(); return false; } bool Channel::heartbeat_sending_allowed() const { - if(m_active_state && m_active_state->server_hello()) - return m_active_state->server_hello()->peer_can_send_heartbeats(); + if(auto active = active_state()) + return active->server_hello()->peer_can_send_heartbeats(); return false; } @@ -318,21 +322,20 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) m_pending_state->handshake_io().add_input( rec_type, &record[0], record.size(), record_sequence); - while(m_pending_state) + while(auto pending = m_pending_state.get()) { - auto msg = m_pending_state->get_next_handshake_msg(); + auto msg = pending->get_next_handshake_msg(); if(msg.first == HANDSHAKE_NONE) // no full handshake yet break; - process_handshake_msg(m_active_state.get(), - *m_pending_state.get(), + process_handshake_msg(active_state(), *pending, msg.first, msg.second); } } else if(rec_type == HEARTBEAT && peer_supports_heartbeats()) { - if(!m_active_state) + if(!active_state()) throw Unexpected_Message("Heartbeat sent before handshake done"); Heartbeat_Message heartbeat(record); @@ -341,7 +344,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(heartbeat.is_request()) { - if(!m_pending_state) // no heartbeats during handshake + if(!pending_state()) { Heartbeat_Message response(Heartbeat_Message::RESPONSE, &payload[0], payload.size()); @@ -357,7 +360,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } else if(rec_type == APPLICATION_DATA) { - if(!m_active_state) + if(!active_state()) throw Unexpected_Message("Application data before handshake done"); /* @@ -379,8 +382,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(alert_msg.is_fatal()) { - if(m_active_state && m_active_state->server_hello()) - m_session_manager.remove_entry(m_active_state->server_hello()->session_id()); + if(auto active = active_state()) + m_session_manager.remove_entry(active->server_hello()->session_id()); } if(alert_msg.type() == Alert::CLOSE_NOTIFY) @@ -529,8 +532,9 @@ void Channel::send_alert(const Alert& alert) if(alert.type() == Alert::NO_RENEGOTIATION) m_pending_state.reset(); - if(alert.is_fatal() && m_active_state && m_active_state->server_hello()) - m_session_manager.remove_entry(m_active_state->server_hello()->session_id()); + if(alert.is_fatal()) + if(auto active = active_state()) + m_session_manager.remove_entry(active->server_hello()->session_id()); if(alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal()) { @@ -545,9 +549,9 @@ void Channel::secure_renegotiation_check(const Client_Hello* client_hello) { const bool secure_renegotiation = client_hello->secure_renegotiation(); - if(m_active_state) + if(auto active = active_state()) { - const bool active_sr = m_active_state->client_hello()->secure_renegotiation(); + const bool active_sr = active->client_hello()->secure_renegotiation(); if(active_sr != secure_renegotiation) throw TLS_Exception(Alert::HANDSHAKE_FAILURE, @@ -568,9 +572,9 @@ void Channel::secure_renegotiation_check(const Server_Hello* server_hello) { const bool secure_renegotiation = server_hello->secure_renegotiation(); - if(m_active_state) + if(auto active = active_state()) { - const bool active_sr = m_active_state->client_hello()->secure_renegotiation(); + const bool active_sr = active->client_hello()->secure_renegotiation(); if(active_sr != secure_renegotiation) throw TLS_Exception(Alert::HANDSHAKE_FAILURE, @@ -589,17 +593,17 @@ void Channel::secure_renegotiation_check(const Server_Hello* server_hello) std::vector<byte> Channel::secure_renegotiation_data_for_client_hello() const { - if(m_active_state) - return m_active_state->client_finished()->verify_data(); + if(auto active = active_state()) + return active->client_finished()->verify_data(); return std::vector<byte>(); } std::vector<byte> Channel::secure_renegotiation_data_for_server_hello() const { - if(m_active_state) + if(auto active = active_state()) { - std::vector<byte> buf = m_active_state->client_finished()->verify_data(); - buf += m_active_state->server_finished()->verify_data(); + std::vector<byte> buf = active->client_finished()->verify_data(); + buf += active->server_finished()->verify_data(); return buf; } @@ -608,10 +612,13 @@ std::vector<byte> Channel::secure_renegotiation_data_for_server_hello() const bool Channel::secure_renegotiation_supported() const { - if(m_active_state) - return m_active_state->server_hello()->secure_renegotiation(); - if(m_pending_state && m_pending_state->server_hello()) - return m_pending_state->server_hello()->secure_renegotiation(); + if(auto active = active_state()) + return active->server_hello()->secure_renegotiation(); + + if(auto pending = pending_state()) + if(auto hello = pending->server_hello()) + return hello->secure_renegotiation(); + return false; } @@ -619,32 +626,32 @@ SymmetricKey Channel::key_material_export(const std::string& label, const std::string& context, size_t length) const { - if(!m_active_state) - throw std::runtime_error("Channel::key_material_export connection not active"); - - Handshake_State& state = *m_active_state; + if(auto active = active_state()) + { + std::unique_ptr<KDF> prf(active->protocol_specific_prf()); - std::unique_ptr<KDF> prf(state.protocol_specific_prf()); + const secure_vector<byte>& master_secret = + active->session_keys().master_secret(); - const secure_vector<byte>& master_secret = - state.session_keys().master_secret(); + std::vector<byte> salt; + salt += to_byte_vector(label); + salt += active->client_hello()->random(); + salt += active->server_hello()->random(); - std::vector<byte> salt; - salt += to_byte_vector(label); - salt += state.client_hello()->random(); - salt += state.server_hello()->random(); + if(context != "") + { + size_t context_size = context.length(); + if(context_size > 0xFFFF) + throw std::runtime_error("key_material_export context is too long"); + salt.push_back(get_byte<u16bit>(0, context_size)); + salt.push_back(get_byte<u16bit>(1, context_size)); + salt += to_byte_vector(context); + } - if(context != "") - { - size_t context_size = context.length(); - if(context_size > 0xFFFF) - throw std::runtime_error("key_material_export context is too long"); - salt.push_back(get_byte<u16bit>(0, context_size)); - salt.push_back(get_byte<u16bit>(1, context_size)); - salt += to_byte_vector(context); + return prf->derive_key(length, master_secret, salt); } - - return prf->derive_key(length, master_secret, salt); + else + throw std::runtime_error("Channel::key_material_export connection not active"); } } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 77e7c81f1..3e400ed5d 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -192,6 +192,10 @@ class BOTAN_DLL Channel u16bit get_last_valid_epoch() const; + const Handshake_State* active_state() const { return m_active_state.get(); } + + const Handshake_State* pending_state() const { return m_pending_state.get(); } + /* callbacks */ std::function<bool (const Session&)> m_handshake_fn; std::function<void (const byte[], size_t, Alert)> m_proc_fn; |