From 9edb16ec3bb572bad5e51582dab5efea73c4fe14 Mon Sep 17 00:00:00 2001 From: lloyd Date: Tue, 6 Nov 2012 22:29:59 +0000 Subject: Pass read_record a callback mapping epoch to cipher state so it can read out of order messages in DTLS. --- src/tls/tls_channel.cpp | 57 +++++++++++++------------------------------------ src/tls/tls_channel.h | 6 ------ src/tls/tls_record.cpp | 26 ++++++++++++++++++---- src/tls/tls_record.h | 2 +- 4 files changed, 38 insertions(+), 53 deletions(-) diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 46c5f4c74..88ca474f6 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -64,16 +64,6 @@ std::shared_ptr Channel::write_cipher_state_epoch(u16bi return i->second; } -std::shared_ptr Channel::read_cipher_state_current() const - { - return read_cipher_state_epoch(sequence_numbers().current_read_epoch()); - } - -std::shared_ptr Channel::write_cipher_state_current() const - { - return write_cipher_state_epoch(sequence_numbers().current_write_epoch()); - } - std::vector Channel::peer_cert_chain() const { if(auto active = active_state()) @@ -214,39 +204,20 @@ void Channel::activate_session() std::swap(m_active_state, m_pending_state); m_pending_state.reset(); - const u16bit last_valid_epoch = get_last_valid_epoch(); - - const auto obsolete_epoch = - [last_valid_epoch](u16bit epoch) { return (epoch < last_valid_epoch); }; - - map_remove_if(obsolete_epoch, m_write_cipher_states); - map_remove_if(obsolete_epoch, m_read_cipher_states); - } - -u16bit Channel::get_last_valid_epoch() const - { if(m_active_state->version().is_datagram_protocol()) { - // DTLS: find first epoch less than TCP MSL - - // FIXME: what about lost/retransmitted flights? - const std::chrono::seconds tcp_msl(120); - - for(auto i : m_read_cipher_states) - { - if(i.second->age() <= tcp_msl) - return i.first; - - if(i.first == sequence_numbers().current_read_epoch()) - return i.first; - } - - throw std::logic_error("Could not find current DTLS epoch"); + // FIXME, remove old states when we are sure not needed anymore } else { - // TLS is easy case - return sequence_numbers().current_write_epoch(); + // TLS is easy just remove all but the current state + auto current_epoch = sequence_numbers().current_write_epoch(); + + const auto not_current_epoch = + [current_epoch](u16bit epoch) { return (epoch != current_epoch); }; + + map_remove_if(not_current_epoch, m_write_cipher_states); + map_remove_if(not_current_epoch, m_read_cipher_states); } } @@ -266,6 +237,9 @@ bool Channel::heartbeat_sending_allowed() const size_t Channel::received_data(const byte buf[], size_t buf_size) { + const auto get_cipherstate = [this](u16bit epoch) + { return this->read_cipher_state_epoch(epoch).get(); }; + try { while(!is_closed() && buf_size) @@ -277,8 +251,6 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) size_t consumed = 0; - auto cipher_state = read_cipher_state_current(); - const size_t needed = read_record(m_readbuf, buf, @@ -289,7 +261,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) record_version, record_sequence, m_sequence_numbers.get(), - cipher_state.get()); + get_cipherstate); BOTAN_ASSERT(consumed <= buf_size, "Record reader consumed sane amount"); @@ -457,7 +429,8 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) * See http://www.openssl.org/~bodo/tls-cbc.txt for background. */ - auto cipher_state = write_cipher_state_current(); + auto cipher_state = + write_cipher_state_epoch(sequence_numbers().current_write_epoch()); if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv()) { diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 3e400ed5d..b7b3d35de 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -186,12 +186,6 @@ class BOTAN_DLL Channel std::shared_ptr write_cipher_state_epoch(u16bit epoch) const; - std::shared_ptr read_cipher_state_current() const; - - std::shared_ptr write_cipher_state_current() const; - - u16bit get_last_valid_epoch() const; - const Handshake_State* active_state() const { return m_active_state.get(); } const Handshake_State* pending_state() const { return m_pending_state.get(); } diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index b2addf116..e11ba31b1 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -275,7 +275,7 @@ size_t read_record(std::vector& readbuf, Protocol_Version& record_version, u64bit& record_sequence, Connection_Sequence_Numbers* sequence_numbers, - Connection_Cipher_State* cipherstate) + std::function get_cipherstate) { consumed = 0; @@ -291,7 +291,7 @@ size_t read_record(std::vector& readbuf, } // Possible SSLv2 format client hello - if((!cipherstate) && (readbuf[0] & 0x80) && (readbuf[2] == 1)) + if(!sequence_numbers && (readbuf[0] & 0x80) && (readbuf[2] == 1)) { if(readbuf[3] == 0 && readbuf[4] == 2) throw TLS_Exception(Alert::PROTOCOL_VERSION, @@ -358,19 +358,31 @@ size_t read_record(std::vector& readbuf, readbuf.size(), "Have the full record"); + u16bit epoch = 0; + if(record_version.is_datagram_protocol()) + { record_sequence = load_be(&readbuf[3], 0); + epoch = (record_sequence >> 48); + } else if(sequence_numbers) + { record_sequence = sequence_numbers->next_read_sequence(); + epoch = sequence_numbers->current_read_epoch(); + } else - record_sequence = 0; // server initial handshake case + { + // server initial handshake case + record_sequence = 0; + epoch = 0; + } if(sequence_numbers && sequence_numbers->already_seen(record_sequence)) return 0; byte* record_contents = &readbuf[header_size]; - if(!cipherstate) // Unencrypted initial handshake + if(epoch == 0) // Unencrypted initial handshake { msg_type = readbuf[0]; msg.assign(&record_contents[0], &record_contents[record_len]); @@ -380,6 +392,12 @@ size_t read_record(std::vector& readbuf, } // Otherwise, decrypt, check MAC, return plaintext + Connection_Cipher_State* cipherstate = get_cipherstate(epoch); + + // FIXME: DTLS reordering might cause us not to have the cipher state + + BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); + const size_t block_size = cipherstate->block_size(); const size_t iv_size = cipherstate->iv_size(); const size_t mac_size = cipherstate->mac_size(); diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index dbe77cbd2..a73efccb1 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -110,7 +110,7 @@ size_t read_record(std::vector& read_buffer, Protocol_Version& record_version, u64bit& record_sequence, Connection_Sequence_Numbers* sequence_numbers, - Connection_Cipher_State* cipherstate); + std::function get_cipherstate); } -- cgit v1.2.3