diff options
author | lloyd <[email protected]> | 2014-10-06 01:29:13 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2014-10-06 01:29:13 +0000 |
commit | 2d6a5e530c8db496aad61b5a9ab3107dd1ed646b (patch) | |
tree | 29d92fc311f65ca88b812dadf3462c3ad1fdb0f9 /src/lib | |
parent | 97010abaf527fdbe6e308cb3570f9167c1dc9ec1 (diff) |
Add support for DTLS handshake timeouts and retransmissions.
Diffstat (limited to 'src/lib')
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 93 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 101 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 43 | ||||
-rw-r--r-- | src/lib/tls/tls_reader.h | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 9 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_seq_numbers.h | 22 |
8 files changed, 224 insertions, 65 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 30f30d623..0617f992c 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -1,6 +1,6 @@ /* * TLS Channels -* (C) 2011-2012 Jack Lloyd +* (C) 2011,2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ @@ -113,19 +113,22 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) m_sequence_numbers.reset(new Stream_Sequence_Numbers); } + using namespace std::placeholders; + std::unique_ptr<Handshake_IO> io; if(version.is_datagram_protocol()) + { + // default MTU is IPv6 min MTU minus UDP/IP headers (TODO: make configurable) + const u16bit mtu = 1280 - 40 - 8; + io.reset(new Datagram_Handshake_IO( sequence_numbers(), - std::bind(&Channel::send_record_under_epoch, this, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3))); + std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3), + mtu)); + } else io.reset(new Stream_Handshake_IO( - std::bind(&Channel::send_record, this, - std::placeholders::_1, - std::placeholders::_2))); + std::bind(&Channel::send_record, this, _1, _2))); m_pending_state.reset(new_handshake_state(io.release())); @@ -135,6 +138,13 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) return *m_pending_state.get(); } +bool Channel::timeout_check() + { + if(m_pending_state) + return m_pending_state->handshake_io().timeout_check(); + return false; + } + void Channel::renegotiate(bool force_full_renegotiation) { if(pending_state()) // currently in handshake? @@ -280,9 +290,6 @@ size_t Channel::received_data(const std::vector<byte>& buf) size_t Channel::received_data(const byte input[], size_t input_size) { - const auto get_cipherstate = [this](u16bit epoch) - { return this->read_cipher_state_epoch(epoch).get(); }; - const size_t max_fragment_size = maximum_fragment_size(); try @@ -306,7 +313,10 @@ size_t Channel::received_data(const byte input[], size_t input_size) &record_version, &record_type, m_sequence_numbers.get(), - get_cipherstate); + std::bind(&TLS::Channel::read_cipher_state_epoch, this, + std::placeholders::_1)); + + BOTAN_ASSERT(consumed > 0, "Got to eat something"); BOTAN_ASSERT(consumed <= input_size, "Record reader consumed sane amount"); @@ -328,24 +338,50 @@ size_t Channel::received_data(const byte input[], size_t input_size) { if(!m_pending_state) { - create_handshake_state(record_version); if(record_version.is_datagram_protocol()) + { sequence_numbers().read_accept(record_sequence); - } - m_pending_state->handshake_io().add_record(unlock(record), - record_type, - record_sequence); + /* + * Might be a peer retransmit under epoch - 1 in which + * case we must retransmit last flight + */ + + const u16bit epoch = record_sequence >> 48; + + if(epoch == sequence_numbers().current_read_epoch()) + { + create_handshake_state(record_version); + } + else if(epoch == sequence_numbers().current_read_epoch() - 1) + { + m_active_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + } + } + else + { + create_handshake_state(record_version); + } + } - while(auto pending = m_pending_state.get()) + if(m_pending_state) { - auto msg = pending->get_next_handshake_msg(); + m_pending_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + + while(auto pending = m_pending_state.get()) + { + auto msg = pending->get_next_handshake_msg(); - if(msg.first == HANDSHAKE_NONE) // no full handshake yet - break; + if(msg.first == HANDSHAKE_NONE) // no full handshake yet + break; - process_handshake_msg(active_state(), *pending, - msg.first, msg.second); + process_handshake_msg(active_state(), *pending, + msg.first, msg.second); + } } } else if(record_type == HEARTBEAT && peer_supports_heartbeats()) @@ -450,11 +486,10 @@ void Channel::heartbeat(const byte payload[], size_t payload_size) } } -void Channel::write_record(Connection_Cipher_State* cipher_state, +void Channel::write_record(Connection_Cipher_State* cipher_state, u16bit epoch, byte record_type, const byte input[], size_t length) { - BOTAN_ASSERT(m_pending_state || m_active_state, - "Some connection state exists"); + BOTAN_ASSERT(m_pending_state || m_active_state, "Some connection state exists"); Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); @@ -464,7 +499,7 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, input, length, record_version, - sequence_numbers().next_write_sequence(), + sequence_numbers().next_write_sequence(epoch), cipher_state, m_rng); @@ -492,7 +527,7 @@ void Channel::send_record_array(u16bit epoch, byte type, const byte input[], siz if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv()) { - write_record(cipher_state.get(), type, &input[0], 1); + write_record(cipher_state.get(), epoch, type, &input[0], 1); input += 1; length -= 1; } @@ -502,7 +537,7 @@ void Channel::send_record_array(u16bit epoch, byte type, const byte input[], siz while(length) { const size_t sending = std::min(length, max_fragment_size); - write_record(cipher_state.get(), type, &input[0], sending); + write_record(cipher_state.get(), epoch, type, &input[0], sending); input += sending; length -= sending; diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 6c159689a..3cdfe3d5e 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -1,6 +1,6 @@ /* * TLS Channel -* (C) 2011,2012 Jack Lloyd +* (C) 2011,2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ @@ -46,17 +46,28 @@ class BOTAN_DLL Channel size_t received_data(const std::vector<byte>& buf); /** + * Perform a handshake timeout check. This does nothing unless + * this is a DTLS channel with a pending handshake state, in + * which case we check for timeout and potentially retransmit + * handshake packets. + */ + bool timeout_check(); + + /** * Inject plaintext intended for counterparty + * Throws an exception if is_active() is false */ void send(const byte buf[], size_t buf_size); /** * Inject plaintext intended for counterparty + * Throws an exception if is_active() is false */ void send(const std::string& val); /** * Inject plaintext intended for counterparty + * Throws an exception if is_active() is false */ template<typename Alloc> void send(const std::vector<unsigned char, Alloc>& val) @@ -209,7 +220,7 @@ class BOTAN_DLL Channel const byte input[], size_t length); void write_record(Connection_Cipher_State* cipher_state, - byte type, const byte input[], size_t length); + u16bit epoch, byte type, const byte input[], size_t length); Connection_Sequence_Numbers& sequence_numbers() const; diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index 287918841..da27cc4ce 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -1,6 +1,6 @@ /* * TLS Handshake IO -* (C) 2012 Jack Lloyd +* (C) 2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ @@ -10,6 +10,7 @@ #include <botan/internal/tls_record.h> #include <botan/internal/tls_seq_numbers.h> #include <botan/exceptn.h> +#include <chrono> namespace Botan { @@ -56,7 +57,7 @@ void Stream_Handshake_IO::add_record(const std::vector<byte>& record, m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs)); } else - throw Decoding_Error("Unknown message type in handshake processing"); + throw Decoding_Error("Unknown message type " + std::to_string(record_type) + " in handshake processing"); } std::pair<Handshake_Type, std::vector<byte>> @@ -119,6 +120,65 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const return Protocol_Version::DTLS_V10; } +namespace { + +// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 +const u64bit INITIAL_TIMEOUT = 1*1000; +const u64bit MAXIMUM_TIMEOUT = 60*1000; + +u64bit steady_clock_ms() + { + return std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()).count(); + } + +} + +bool Datagram_Handshake_IO::timeout_check() + { + if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) + { + /* + If we haven't written anything yet obviously no timeout. + Also no timeout possible if we are mid-flight, + */ + return false; + } + + const u64bit ms_since_write = steady_clock_ms() - m_last_write; + + if(ms_since_write < m_next_timeout) + return false; + + std::vector<u16bit> flight; + if(m_flights.size() == 1) + flight = m_flights.at(0); // lost initial client hello + else + flight = m_flights.at(m_flights.size() - 2); + + BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit"); + + u16bit epoch = m_flight_data[flight[0]].epoch; + + for(auto msg_seq : flight) + { + auto& msg = m_flight_data[msg_seq]; + + if(msg.epoch != epoch) + { + // Epoch gap: insert the CCS + std::vector<byte> ccs(1, 1); + m_send_hs(epoch, CHANGE_CIPHER_SPEC, ccs); + } + + send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits); + epoch = msg.epoch; + } + + m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT); + return true; + } + void Datagram_Handshake_IO::add_record(const std::vector<byte>& record, Record_Type record_type, u64bit record_sequence) @@ -127,6 +187,7 @@ void Datagram_Handshake_IO::add_record(const std::vector<byte>& record, if(record_type == CHANGE_CIPHER_SPEC) { + // TODO: check this is otherwise empty m_ccs_epochs.insert(epoch); return; } @@ -161,6 +222,10 @@ void Datagram_Handshake_IO::add_record(const std::vector<byte>& record, msg_type, msg_len); } + else + { + // TODO: detect retransmitted flight + } record_bits += total_size; record_size -= total_size; @@ -170,6 +235,7 @@ void Datagram_Handshake_IO::add_record(const std::vector<byte>& record, std::pair<Handshake_Type, std::vector<byte>> Datagram_Handshake_IO::get_next_record(bool expecting_ccs) { + // Expecting a message means the last flight is concluded if(!m_flights.rbegin()->empty()) m_flights.push_back(std::vector<u16bit>()); @@ -215,7 +281,7 @@ void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment( } if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch) - throw Decoding_Error("Inconsistent values in DTLS handshake header"); + throw Decoding_Error("Inconsistent values in fragmented DTLS handshake header"); if(fragment_offset > m_msg_length) throw Decoding_Error("Fragment offset past end of message"); @@ -327,16 +393,30 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) const u16bit epoch = m_seqs.current_write_epoch(); const Handshake_Type msg_type = msg.type(); - std::tuple<u16bit, byte, std::vector<byte>> msg_info(epoch, msg_type, msg_bits); - if(msg_type == HANDSHAKE_CCS) { m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits); return std::vector<byte>(); // not included in handshake hashes } + // Note: not saving CCS, instead we know it was there due to change in epoch + m_flights.rbegin()->push_back(m_out_message_seq); + m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits); + + m_out_message_seq += 1; + m_last_write = steady_clock_ms(); + m_next_timeout = INITIAL_TIMEOUT; + + return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); + } + +std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq, + u16bit epoch, + Handshake_Type msg_type, + const std::vector<byte>& msg_bits) + { const std::vector<byte> no_fragment = - format_w_seq(msg_bits, msg_type, m_out_message_seq); + format_w_seq(msg_bits, msg_type, msg_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) m_send_hs(epoch, HANDSHAKE, no_fragment); @@ -361,21 +441,14 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) frag_offset, msg_bits.size(), msg_type, - m_out_message_seq)); + msg_seq)); frag_offset += frag_len; } } - // Note: not saving CCS, instead we know it was there due to change in epoch - m_flights.rbegin()->push_back(m_out_message_seq); - m_flight_data[m_out_message_seq] = msg_info; - - m_out_message_seq += 1; - return no_fragment; } } - } diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 36c605c30..b13a81700 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -1,6 +1,6 @@ /* * TLS Handshake Serialization -* (C) 2012 Jack Lloyd +* (C) 2012,2014 Jack Lloyd * * Released under the terms of the Botan license */ @@ -17,7 +17,6 @@ #include <map> #include <set> #include <utility> -#include <tuple> namespace Botan { @@ -35,6 +34,8 @@ class Handshake_IO virtual std::vector<byte> send(const Handshake_Message& msg) = 0; + virtual bool timeout_check() = 0; + virtual std::vector<byte> format( const std::vector<byte>& handshake_msg, Handshake_Type handshake_type) const = 0; @@ -69,6 +70,8 @@ class Stream_Handshake_IO : public Handshake_IO Protocol_Version initial_record_version() const override; + bool timeout_check() override { return false; } + std::vector<byte> send(const Handshake_Message& msg) override; std::vector<byte> format( @@ -93,11 +96,14 @@ class Datagram_Handshake_IO : public Handshake_IO { public: Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq, - std::function<void (u16bit, byte, const std::vector<byte>&)> writer) : - m_seqs(seq), m_flights(1), m_send_hs(writer) {} + std::function<void (u16bit, byte, const std::vector<byte>&)> writer, + u16bit mtu) : + m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {} Protocol_Version initial_record_version() const override; + bool timeout_check() override; + std::vector<byte> send(const Handshake_Message& msg) override; std::vector<byte> format( @@ -124,6 +130,10 @@ class Datagram_Handshake_IO : public Handshake_IO Handshake_Type handshake_type, u16bit msg_sequence) const; + std::vector<byte> send_message(u16bit msg_seq, u16bit epoch, + Handshake_Type msg_type, + const std::vector<byte>& msg); + class Handshake_Reassembly { public: @@ -144,21 +154,40 @@ class Datagram_Handshake_IO : public Handshake_IO size_t m_msg_length = 0; u16bit m_epoch = 0; + // vector<bool> m_seen; + // vector<byte> m_fragments std::map<size_t, byte> m_fragments; std::vector<byte> m_message; }; + struct Message_Info + { + Message_Info(u16bit e, Handshake_Type mt, const std::vector<byte>& msg) : + epoch(e), msg_type(mt), msg_bits(msg) {} + + Message_Info(const Message_Info& other) = default; + + Message_Info() : epoch(0xFFFF), msg_type(HANDSHAKE_NONE) {} + + u16bit epoch; + Handshake_Type msg_type; + std::vector<byte> msg_bits; + }; + class Connection_Sequence_Numbers& m_seqs; std::map<u16bit, Handshake_Reassembly> m_messages; std::set<u16bit> m_ccs_epochs; std::vector<std::vector<u16bit>> m_flights; - std::map<u16bit, std::tuple<u16bit, byte, std::vector<byte>>> m_flight_data; + std::map<u16bit, Message_Info> m_flight_data; + + u64bit m_last_write = 0; + u64bit m_next_timeout = 0; - // default MTU is IPv6 min MTU minus UDP/IP headers - u16bit m_mtu = 1280 - 40 - 8; u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; + std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs; + u16bit m_mtu; }; } diff --git a/src/lib/tls/tls_reader.h b/src/lib/tls/tls_reader.h index 1c4f0f456..028893cc1 100644 --- a/src/lib/tls/tls_reader.h +++ b/src/lib/tls/tls_reader.h @@ -92,8 +92,8 @@ class TLS_Data_Reader template<typename T> std::vector<T> get_range(size_t len_bytes, - size_t min_elems, - size_t max_elems) + size_t min_elems, + size_t max_elems) { const size_t num_elems = get_num_elems(len_bytes, sizeof(T), min_elems, max_elems); diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index fc4908dc5..be0777573 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -1,6 +1,6 @@ /* * TLS Record Handling -* (C) 2012,2013 Jack Lloyd +* (C) 2012,2013,2014 Jack Lloyd * * Released under the terms of the Botan license */ @@ -477,7 +477,7 @@ size_t read_record(secure_vector<byte>& readbuf, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<Connection_Cipher_State* (u16bit)> get_cipherstate) + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) { consumed = 0; @@ -584,7 +584,10 @@ size_t read_record(secure_vector<byte>& readbuf, } if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) + { + readbuf.clear(); return 0; + } byte* record_contents = &readbuf[header_size]; @@ -596,7 +599,7 @@ size_t read_record(secure_vector<byte>& readbuf, } // Otherwise, decrypt, check MAC, return plaintext - Connection_Cipher_State* cipherstate = get_cipherstate(epoch); + auto cipherstate = get_cipherstate(epoch); // FIXME: DTLS reordering might cause us not to have the cipher state diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index c9f164407..fb727753a 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -125,7 +125,7 @@ size_t read_record(secure_vector<byte>& read_buffer, Protocol_Version* record_version, Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, - std::function<Connection_Cipher_State* (u16bit)> get_cipherstate); + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate); } diff --git a/src/lib/tls/tls_seq_numbers.h b/src/lib/tls/tls_seq_numbers.h index 87edf3130..d7b8c919c 100644 --- a/src/lib/tls/tls_seq_numbers.h +++ b/src/lib/tls/tls_seq_numbers.h @@ -24,7 +24,7 @@ class Connection_Sequence_Numbers virtual u16bit current_read_epoch() const = 0; virtual u16bit current_write_epoch() const = 0; - virtual u64bit next_write_sequence() = 0; + virtual u64bit next_write_sequence(u16bit) = 0; virtual u64bit next_read_sequence() = 0; virtual bool already_seen(u64bit seq) const = 0; @@ -40,7 +40,7 @@ class Stream_Sequence_Numbers : public Connection_Sequence_Numbers u16bit current_read_epoch() const override { return m_read_epoch; } u16bit current_write_epoch() const override { return m_write_epoch; } - u64bit next_write_sequence() override { return m_write_seq_no++; } + u64bit next_write_sequence(u16bit) override { return m_write_seq_no++; } u64bit next_read_sequence() override { return m_read_seq_no; } bool already_seen(u64bit) const override { return false; } @@ -55,18 +55,25 @@ class Stream_Sequence_Numbers : public Connection_Sequence_Numbers class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers { public: + Datagram_Sequence_Numbers() { m_write_seqs[0] = 0; } + void new_read_cipher_state() override { m_read_epoch += 1; } void new_write_cipher_state() override { - // increment epoch - m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48; + m_write_epoch += 1; + m_write_seqs[m_write_epoch] = 0; } u16bit current_read_epoch() const override { return m_read_epoch; } - u16bit current_write_epoch() const override { return (m_write_seq_no >> 48); } + u16bit current_write_epoch() const override { return m_write_epoch; } - u64bit next_write_sequence() override { return m_write_seq_no++; } + u64bit next_write_sequence(u16bit epoch) override + { + auto i = m_write_seqs.find(epoch); + BOTAN_ASSERT(i != m_write_seqs.end(), "Found epoch"); + return (static_cast<u64bit>(epoch) << 48) | i->second++; + } u64bit next_read_sequence() override { @@ -112,7 +119,8 @@ class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers } private: - u64bit m_write_seq_no = 0; + std::map<u16bit, u64bit> m_write_seqs; + u16bit m_write_epoch = 0; u16bit m_read_epoch = 0; u64bit m_window_highest = 0; u64bit m_window_bits = 0; |