diff options
author | lloyd <[email protected]> | 2012-09-13 23:23:42 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-13 23:23:42 +0000 |
commit | bddb0f2075bd12ad88a541b1d9ba04d4a80e0767 (patch) | |
tree | b8c8e715ff62738a463dec762535e0beed27e922 /src/tls | |
parent | 4393f9d9db263510e59424a41b14f7cde7206825 (diff) |
Store the cipher states in the handshake state object as shared_ptrs.
One notable change here is that after we send a close_alert, we ignore
any data that follows. That is somewhat unfortunate actually, but
overall this change is important (for DTLS).
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_channel.cpp | 76 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 19 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 23 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 19 |
4 files changed, 94 insertions, 43 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 6448ca2d4..95b3d1bbb 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -44,6 +44,24 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const return *m_sequence_numbers; } +std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state() const + { + if(m_pending_state) + return m_pending_state->read_cipher_state(); + if(m_active_state) + return m_active_state->read_cipher_state(); + return std::shared_ptr<Connection_Cipher_State>(nullptr); + } + +std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state() const + { + if(m_pending_state) + return m_pending_state->write_cipher_state(); + if(m_active_state) + return m_active_state->write_cipher_state(); + return std::shared_ptr<Connection_Cipher_State>(nullptr); + } + std::vector<X509_Certificate> Channel::peer_cert_chain() const { if(!m_active_state) @@ -91,7 +109,10 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) m_pending_state.reset(new_handshake_state(io.release())); if(m_active_state) + { m_pending_state->set_version(m_active_state->version()); + m_pending_state->copy_cipher_states(*m_active_state); + } return *m_pending_state.get(); } @@ -127,14 +148,7 @@ void Channel::change_cipher_spec_reader(Connection_Side side) sequence_numbers().new_read_cipher_state(); // flip side as we are reading - side = (side == CLIENT) ? SERVER : CLIENT; - - m_read_cipherstate.reset( - new Connection_Cipher_State(m_pending_state->version(), - side, - m_pending_state->ciphersuite(), - m_pending_state->session_keys()) - ); + m_pending_state->new_read_cipher_state((side == CLIENT) ? SERVER : CLIENT); } void Channel::change_cipher_spec_writer(Connection_Side side) @@ -147,12 +161,7 @@ void Channel::change_cipher_spec_writer(Connection_Side side) sequence_numbers().new_write_cipher_state(); - m_write_cipherstate.reset( - new Connection_Cipher_State(m_pending_state->version(), - side, - m_pending_state->ciphersuite(), - m_pending_state->session_keys()) - ); + m_pending_state->new_write_cipher_state(side); } void Channel::activate_session() @@ -179,7 +188,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) { try { - while(buf_size) + while(!is_closed() && buf_size) { byte rec_type = NO_RECORD; std::vector<byte> record; @@ -188,6 +197,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) size_t consumed = 0; + std::shared_ptr<Connection_Cipher_State> cipher_state = read_cipher_state(); + const size_t needed = read_record(m_readbuf, buf, @@ -198,7 +209,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) record_version, record_sequence, m_sequence_numbers.get(), - m_read_cipherstate.get()); + cipher_state.get()); BOTAN_ASSERT(consumed <= buf_size, "Record reader consumed sane amount"); @@ -289,27 +300,22 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) m_proc_fn(nullptr, 0, alert_msg); - if(alert_msg.type() == Alert::CLOSE_NOTIFY) - { - if(!m_connection_closed) - send_alert(Alert(Alert::CLOSE_NOTIFY)); // reply in kind - m_read_cipherstate.reset(); - } - else if(alert_msg.is_fatal()) + if(alert_msg.is_fatal()) { - // delete state immediately - if(m_active_state && m_active_state->server_hello()) m_session_manager.remove_entry(m_active_state->server_hello()->session_id()); + } + if(alert_msg.type() == Alert::CLOSE_NOTIFY) + send_alert(Alert(Alert::CLOSE_NOTIFY)); // reply in kind + + if(alert_msg.type() == Alert::CLOSE_NOTIFY || alert_msg.is_fatal()) + { m_connection_closed = true; m_active_state.reset(); m_pending_state.reset(); - m_write_cipherstate.reset(); - m_read_cipherstate.reset(); - return 0; } } @@ -370,9 +376,11 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) * * See http://www.openssl.org/~bodo/tls-cbc.txt for background. */ - if(type == APPLICATION_DATA && m_write_cipherstate->cbc_without_explicit_iv()) + std::shared_ptr<Connection_Cipher_State> cipher_state = write_cipher_state(); + + if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv()) { - write_record(type, &input[0], 1); + write_record(cipher_state.get(), type, &input[0], 1); input += 1; length -= 1; } @@ -380,7 +388,7 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) while(length) { const size_t sending = std::min(length, m_max_fragment); - write_record(type, &input[0], sending); + write_record(cipher_state.get(), type, &input[0], sending); input += sending; length -= sending; @@ -392,7 +400,8 @@ void Channel::send_record(byte record_type, const std::vector<byte>& record) send_record_array(record_type, &record[0], record.size()); } -void Channel::write_record(byte record_type, const byte input[], size_t length) +void Channel::write_record(Connection_Cipher_State* cipher_state, + byte record_type, const byte input[], size_t length) { if(length > m_max_fragment) throw Internal_Error("Record is larger than allowed fragment size"); @@ -409,7 +418,7 @@ void Channel::write_record(byte record_type, const byte input[], size_t length) length, record_version, sequence_numbers(), - m_write_cipherstate.get(), + cipher_state, m_rng); m_output_fn(&m_writebuf[0], m_writebuf.size()); @@ -449,7 +458,6 @@ void Channel::send_alert(const Alert& alert) { m_active_state.reset(); m_pending_state.reset(); - m_write_cipherstate.reset(); m_connection_closed = true; } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index fa1fd3756..95420dcb3 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -21,6 +21,8 @@ namespace Botan { namespace TLS { +class Connection_Cipher_State; +class Connection_Sequence_Numbers; class Handshake_State; /** @@ -54,7 +56,7 @@ class BOTAN_DLL Channel /** * @return true iff the connection is active for sending application data */ - bool is_active() const { return m_active_state && !is_closed(); } + bool is_active() const { return m_active_state.get(); } /** * @return true iff the connection has been definitely closed @@ -160,13 +162,18 @@ class BOTAN_DLL Channel void send_record_array(byte type, const byte input[], size_t length); - void write_record(byte type, const byte input[], size_t length); + void write_record(Connection_Cipher_State* cipher_state, + byte type, const byte input[], size_t length); bool peer_supports_heartbeats() const; bool heartbeat_sending_allowed() const; - class Connection_Sequence_Numbers& sequence_numbers() const; + Connection_Sequence_Numbers& sequence_numbers() const; + + std::shared_ptr<Connection_Cipher_State> read_cipher_state() const; + + std::shared_ptr<Connection_Cipher_State> write_cipher_state() const; /* callbacks */ std::function<bool (const Session&)> m_handshake_fn; @@ -177,10 +184,8 @@ class BOTAN_DLL Channel RandomNumberGenerator& m_rng; Session_Manager& m_session_manager; - /* cipher/sequence state */ - std::unique_ptr<class Connection_Sequence_Numbers> m_sequence_numbers; - std::unique_ptr<class Connection_Cipher_State> m_write_cipherstate; - std::unique_ptr<class Connection_Cipher_State> m_read_cipherstate; + /* sequence number state */ + std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers; /* I/O buffers */ std::vector<byte> m_writebuf; diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 30474aae0..044e97366 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -7,6 +7,7 @@ #include <botan/internal/tls_handshake_state.h> #include <botan/internal/tls_messages.h> +#include <botan/internal/tls_record.h> #include <botan/internal/assert.h> #include <botan/lookup.h> @@ -87,8 +88,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) */ Handshake_State::Handshake_State(Handshake_IO* io, std::function<void (const Handshake_Message&)> msg_callback) : - m_handshake_io(io), m_msg_callback(msg_callback), + m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) { } @@ -199,6 +200,26 @@ void Handshake_State::compute_session_keys(const secure_vector<byte>& resume_mas m_session_keys = Session_Keys(this, resume_master_secret, true); } +void Handshake_State::copy_cipher_states(const Handshake_State& prev_state) + { + m_write_cipher_state = prev_state.m_write_cipher_state; + m_read_cipher_state = prev_state.m_read_cipher_state; + } + +void Handshake_State::new_read_cipher_state(Connection_Side side) + { + m_read_cipher_state.reset( + new Connection_Cipher_State(version(), side, ciphersuite(), session_keys()) + ); + } + +void Handshake_State::new_write_cipher_state(Connection_Side side) + { + m_write_cipher_state.reset( + new Connection_Cipher_State(version(), side, ciphersuite(), session_keys()) + ); + } + void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) { const u32bit mask = bitmask_for_handshake_type(handshake_msg); diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 5145958ef..fee9bd5d1 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -41,6 +41,8 @@ class Next_Protocol; class New_Session_Ticket; class Finished; +class Connection_Cipher_State; + /** * SSL/TLS Handshake State */ @@ -118,6 +120,10 @@ class Handshake_State void server_finished(Finished* server_finished); void client_finished(Finished* client_finished); + void new_read_cipher_state(Connection_Side side); + + void new_write_cipher_state(Connection_Side side); + const Client_Hello* client_hello() const { return m_client_hello.get(); } @@ -175,11 +181,22 @@ class Handshake_State m_msg_callback(msg); } + std::shared_ptr<Connection_Cipher_State> read_cipher_state() + { return m_read_cipher_state; } + + std::shared_ptr<Connection_Cipher_State> write_cipher_state() + { return m_write_cipher_state; } + + void copy_cipher_states(const Handshake_State& prev_state); + private: + std::function<void (const Handshake_Message&)> m_msg_callback; + std::unique_ptr<Handshake_IO> m_handshake_io; - std::function<void (const Handshake_Message&)> m_msg_callback; + std::shared_ptr<Connection_Cipher_State> m_write_cipher_state; + std::shared_ptr<Connection_Cipher_State> m_read_cipher_state; u32bit m_hand_expecting_mask = 0; u32bit m_hand_received_mask = 0; |