diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_channel.cpp | 109 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 19 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 20 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 17 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 15 | ||||
-rw-r--r-- | src/tls/tls_record.h | 9 | ||||
-rw-r--r-- | src/tls/tls_seq_numbers.h | 22 |
7 files changed, 145 insertions, 66 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index be3f1c784..7065064cc 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -44,6 +44,36 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const return *m_sequence_numbers; } +std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_epoch(u16bit epoch) const + { + auto i = m_read_cipher_states.find(epoch); + + BOTAN_ASSERT(i != m_read_cipher_states.end(), + "Have a cipher state for the specified epoch"); + + return i->second; + } + +std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_epoch(u16bit epoch) const + { + auto i = m_write_cipher_states.find(epoch); + + BOTAN_ASSERT(i != m_write_cipher_states.end(), + "Have a cipher state for the specified epoch"); + + return i->second; + } + +std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_current() const + { + return read_cipher_state_epoch(sequence_numbers().current_read_epoch()); + } + +std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_current() const + { + return write_cipher_state_epoch(sequence_numbers().current_write_epoch()); + } + std::vector<X509_Certificate> Channel::peer_cert_chain() const { if(!m_active_state) @@ -91,10 +121,7 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) m_pending_state.reset(new_handshake_state(io.release())); if(m_active_state) - { m_pending_state->set_version(m_active_state->version()); - m_pending_state->copy_cipher_states(*m_active_state); - } return *m_pending_state.get(); } @@ -129,8 +156,19 @@ void Channel::change_cipher_spec_reader(Connection_Side side) sequence_numbers().new_read_cipher_state(); + const u16bit epoch = sequence_numbers().current_read_epoch(); + + BOTAN_ASSERT(m_read_cipher_states.count(epoch) == 0, + "No read cipher state currently set for next epoch"); + // flip side as we are reading - m_pending_state->new_read_cipher_state((side == CLIENT) ? SERVER : CLIENT); + std::shared_ptr<Connection_Cipher_State> read_state( + new Connection_Cipher_State(m_pending_state->version(), + (side == CLIENT) ? SERVER : CLIENT, + m_pending_state->ciphersuite(), + m_pending_state->session_keys())); + + m_read_cipher_states[epoch] = read_state; } void Channel::change_cipher_spec_writer(Connection_Side side) @@ -143,7 +181,18 @@ void Channel::change_cipher_spec_writer(Connection_Side side) sequence_numbers().new_write_cipher_state(); - m_pending_state->new_write_cipher_state(side); + const u16bit epoch = sequence_numbers().current_write_epoch(); + + BOTAN_ASSERT(m_write_cipher_states.count(epoch) == 0, + "No write cipher state currently set for next epoch"); + + std::shared_ptr<Connection_Cipher_State> write_state( + new Connection_Cipher_State(m_pending_state->version(), + side, + m_pending_state->ciphersuite(), + m_pending_state->session_keys())); + + m_write_cipher_states[epoch] = write_state; } bool Channel::is_active() const @@ -160,6 +209,41 @@ void Channel::activate_session() { std::swap(m_active_state, m_pending_state); m_pending_state.reset(); + + const u16bit last_valid_epoch = get_last_valid_epoch(); + + const auto obsolete_epoch = + [last_valid_epoch](u16bit epoch) { return (epoch < last_valid_epoch); }; + + map_remove_if(obsolete_epoch, m_write_cipher_states); + map_remove_if(obsolete_epoch, m_read_cipher_states); + } + +u16bit Channel::get_last_valid_epoch() const + { + if(m_active_state->version().is_datagram_protocol()) + { + // DTLS: find first epoch less than TCP MSL + + // FIXME: what about lost/retransmitted flights? + const std::chrono::seconds tcp_msl(120); + + for(auto i : m_read_cipher_states) + { + if(i.second->age() <= tcp_msl) + return i.first; + + if(i.first == sequence_numbers().current_read_epoch()) + return i.first; + } + + throw std::logic_error("Could not find current DTLS epoch"); + } + else + { + // TLS is easy case + return sequence_numbers().current_write_epoch(); + } } bool Channel::peer_supports_heartbeats() const @@ -189,11 +273,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) size_t consumed = 0; - std::shared_ptr<Connection_Cipher_State> cipher_state; - if(m_pending_state) - cipher_state = m_pending_state->read_cipher_state(); - else if(m_active_state) - cipher_state = m_active_state->read_cipher_state(); + auto cipher_state = read_cipher_state_current(); const size_t needed = read_record(m_readbuf, @@ -231,7 +311,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(!m_pending_state) { create_handshake_state(record_version); - sequence_numbers().read_accept(record_sequence); + if(record_version.is_datagram_protocol()) + sequence_numbers().read_accept(record_sequence); } m_pending_state->handshake_io().add_input( @@ -372,12 +453,8 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) * * See http://www.openssl.org/~bodo/tls-cbc.txt for background. */ - std::shared_ptr<Connection_Cipher_State> cipher_state; - if(m_pending_state) - cipher_state = m_pending_state->write_cipher_state(); - else if(m_active_state) - cipher_state = m_active_state->write_cipher_state(); + auto cipher_state = write_cipher_state_current(); if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv()) { diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 1a30c5604..77e7c81f1 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -16,6 +16,7 @@ #include <vector> #include <string> #include <memory> +#include <map> namespace Botan { @@ -181,6 +182,16 @@ class BOTAN_DLL Channel Connection_Sequence_Numbers& sequence_numbers() const; + std::shared_ptr<Connection_Cipher_State> read_cipher_state_epoch(u16bit epoch) const; + + std::shared_ptr<Connection_Cipher_State> write_cipher_state_epoch(u16bit epoch) const; + + std::shared_ptr<Connection_Cipher_State> read_cipher_state_current() const; + + std::shared_ptr<Connection_Cipher_State> write_cipher_state_current() const; + + u16bit get_last_valid_epoch() const; + /* callbacks */ std::function<bool (const Session&)> m_handshake_fn; std::function<void (const byte[], size_t, Alert)> m_proc_fn; @@ -197,7 +208,13 @@ class BOTAN_DLL Channel std::vector<byte> m_writebuf; std::vector<byte> m_readbuf; - /* connection parameters */ + /* cipher states for each epoch - epoch 0 is plaintext, thus null cipher state */ + std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_write_cipher_states = + { { 0, nullptr } }; + std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_read_cipher_states = + { { 0, nullptr } }; + + /* pending and active connection states */ std::unique_ptr<Handshake_State> m_active_state; std::unique_ptr<Handshake_State> m_pending_state; diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 044e97366..8ff0fb585 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -200,26 +200,6 @@ void Handshake_State::compute_session_keys(const secure_vector<byte>& resume_mas m_session_keys = Session_Keys(this, resume_master_secret, true); } -void Handshake_State::copy_cipher_states(const Handshake_State& prev_state) - { - m_write_cipher_state = prev_state.m_write_cipher_state; - m_read_cipher_state = prev_state.m_read_cipher_state; - } - -void Handshake_State::new_read_cipher_state(Connection_Side side) - { - m_read_cipher_state.reset( - new Connection_Cipher_State(version(), side, ciphersuite(), session_keys()) - ); - } - -void Handshake_State::new_write_cipher_state(Connection_Side side) - { - m_write_cipher_state.reset( - new Connection_Cipher_State(version(), side, ciphersuite(), session_keys()) - ); - } - void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) { const u32bit mask = bitmask_for_handshake_type(handshake_msg); diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index fee9bd5d1..9afcd0374 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -41,8 +41,6 @@ class Next_Protocol; class New_Session_Ticket; class Finished; -class Connection_Cipher_State; - /** * SSL/TLS Handshake State */ @@ -120,10 +118,6 @@ class Handshake_State void server_finished(Finished* server_finished); void client_finished(Finished* client_finished); - void new_read_cipher_state(Connection_Side side); - - void new_write_cipher_state(Connection_Side side); - const Client_Hello* client_hello() const { return m_client_hello.get(); } @@ -181,23 +175,12 @@ class Handshake_State m_msg_callback(msg); } - std::shared_ptr<Connection_Cipher_State> read_cipher_state() - { return m_read_cipher_state; } - - std::shared_ptr<Connection_Cipher_State> write_cipher_state() - { return m_write_cipher_state; } - - void copy_cipher_states(const Handshake_State& prev_state); - private: std::function<void (const Handshake_Message&)> m_msg_callback; std::unique_ptr<Handshake_IO> m_handshake_io; - std::shared_ptr<Connection_Cipher_State> m_write_cipher_state; - std::shared_ptr<Connection_Cipher_State> m_read_cipher_state; - u32bit m_hand_expecting_mask = 0; u32bit m_hand_received_mask = 0; Protocol_Version m_version; diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index d0bc8bc69..b2addf116 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -24,6 +24,7 @@ Connection_Cipher_State::Connection_Cipher_State(Protocol_Version version, Connection_Side side, const Ciphersuite& suite, const Session_Keys& keys) : + m_start_time(std::chrono::system_clock::now()), m_is_ssl3(version == Protocol_Version::SSL_V3) { SymmetricKey mac_key, cipher_key; @@ -341,13 +342,6 @@ size_t read_record(std::vector<byte>& readbuf, const size_t header_size = (record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; - if(record_version.is_datagram_protocol()) - record_sequence = load_be<u64bit>(&readbuf[3], 0); - else if(sequence_numbers) - record_sequence = sequence_numbers->next_read_sequence(); - else - record_sequence = 0; // server initial handshake case - const size_t record_len = make_u16bit(readbuf[header_size-2], readbuf[header_size-1]); @@ -364,6 +358,13 @@ size_t read_record(std::vector<byte>& readbuf, readbuf.size(), "Have the full record"); + if(record_version.is_datagram_protocol()) + record_sequence = load_be<u64bit>(&readbuf[3], 0); + else if(sequence_numbers) + record_sequence = sequence_numbers->next_read_sequence(); + else + record_sequence = 0; // server initial handshake case + if(sequence_numbers && sequence_numbers->already_seen(record_sequence)) return 0; diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 5c5d64d0d..dbe77cbd2 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -15,6 +15,7 @@ #include <botan/mac.h> #include <vector> #include <memory> +#include <chrono> namespace Botan { @@ -59,7 +60,15 @@ class Connection_Cipher_State bool cbc_without_explicit_iv() const { return (m_block_size > 0) && (m_iv_size == 0); } + + std::chrono::seconds age() const + { + return std::chrono::duration_cast<std::chrono::seconds>( + std::chrono::system_clock::now() - m_start_time); + } + private: + std::chrono::system_clock::time_point m_start_time; std::unique_ptr<BlockCipher> m_block_cipher; secure_vector<byte> m_block_cipher_cbc_state; std::unique_ptr<StreamCipher> m_stream_cipher; diff --git a/src/tls/tls_seq_numbers.h b/src/tls/tls_seq_numbers.h index c9a334e4b..4a8a0fab8 100644 --- a/src/tls/tls_seq_numbers.h +++ b/src/tls/tls_seq_numbers.h @@ -20,9 +20,12 @@ class Connection_Sequence_Numbers virtual void new_read_cipher_state() = 0; virtual void new_write_cipher_state() = 0; - virtual u64bit next_write_sequence() = 0; + virtual u16bit current_read_epoch() const = 0; + virtual u16bit current_write_epoch() const = 0; + virtual u64bit next_write_sequence() = 0; virtual u64bit next_read_sequence() = 0; + virtual bool already_seen(u64bit seq) const = 0; virtual void read_accept(u64bit seq) = 0; }; @@ -30,23 +33,28 @@ class Connection_Sequence_Numbers class Stream_Sequence_Numbers : public Connection_Sequence_Numbers { public: - void new_read_cipher_state() override { m_read_seq_no = 0; } - void new_write_cipher_state() override { m_write_seq_no = 0; } + void new_read_cipher_state() override { m_read_seq_no = 0; m_read_epoch += 1; } + void new_write_cipher_state() override { m_write_seq_no = 0; m_write_epoch += 1; } - u64bit next_write_sequence() override { return m_write_seq_no++; } + u16bit current_read_epoch() const override { return m_read_epoch; } + u16bit current_write_epoch() const override { return m_write_epoch; } + u64bit next_write_sequence() override { return m_write_seq_no++; } u64bit next_read_sequence() override { return m_read_seq_no; } + bool already_seen(u64bit) const override { return false; } void read_accept(u64bit) override { m_read_seq_no++; } private: u64bit m_write_seq_no = 0; u64bit m_read_seq_no = 0; + u16bit m_read_epoch = 0; + u16bit m_write_epoch = 0; }; class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers { public: - void new_read_cipher_state() override {} + void new_read_cipher_state() override { m_read_epoch += 1; } void new_write_cipher_state() override { @@ -54,6 +62,9 @@ class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48; } + u16bit current_read_epoch() const override { return m_read_epoch; } + u16bit current_write_epoch() const override { return (m_write_seq_no >> 48); } + u64bit next_write_sequence() override { return m_write_seq_no++; } u64bit next_read_sequence() override @@ -101,6 +112,7 @@ class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers private: u64bit m_write_seq_no = 0; + u16bit m_read_epoch = 0; u64bit m_window_highest = 0; u64bit m_window_bits = 0; }; |