diff options
author | lloyd <[email protected]> | 2012-11-12 22:05:09 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-11-12 22:05:09 +0000 |
commit | 58461a900aea49e5230b7b748fc481114d31904a (patch) | |
tree | 1a8d54f5368d5109845f6d6fee32b32b0c1d8d12 | |
parent | 579158e826daed42963db0c8b987d51ba7831fb6 (diff) |
Changes so DTLS handshake can send messages under different epochs, eg
for retransmitting a flight.
-rw-r--r-- | src/tls/tls_channel.cpp | 71 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 6 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 4 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 24 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 19 | ||||
-rw-r--r-- | src/tls/tls_policy.cpp | 5 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 11 |
7 files changed, 85 insertions, 55 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 1be336fc5..5858f5d90 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -73,8 +73,6 @@ std::vector<X509_Certificate> Channel::peer_cert_chain() const Handshake_State& Channel::create_handshake_state(Protocol_Version version) { - const size_t dtls_mtu = 1400; // fixme should be settable - if(pending_state()) throw Internal_Error("create_handshake_state called during handshake"); @@ -98,15 +96,19 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) m_sequence_numbers.reset(new Stream_Sequence_Numbers); } - auto send_rec = std::bind(&Channel::send_record, this, - std::placeholders::_1, - std::placeholders::_2); - std::unique_ptr<Handshake_IO> io; if(version.is_datagram_protocol()) - io.reset(new Datagram_Handshake_IO(send_rec, dtls_mtu)); + io.reset(new Datagram_Handshake_IO( + sequence_numbers(), + std::bind(&Channel::send_record_under_epoch, this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3))); else - io.reset(new Stream_Handshake_IO(send_rec)); + io.reset(new Stream_Handshake_IO( + std::bind(&Channel::send_record, this, + std::placeholders::_1, + std::placeholders::_2))); m_pending_state.reset(new_handshake_state(io.release())); @@ -429,7 +431,28 @@ void Channel::heartbeat(const byte payload[], size_t payload_size) } } -void Channel::send_record_array(byte type, const byte input[], size_t length) +void Channel::write_record(Connection_Cipher_State* cipher_state, + byte record_type, const byte input[], size_t length) + { + BOTAN_ASSERT(m_pending_state || m_active_state, + "Some connection state exists"); + + Protocol_Version record_version = + (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); + + TLS::write_record(m_writebuf, + record_type, + input, + length, + record_version, + sequence_numbers(), + cipher_state, + m_rng); + + m_output_fn(&m_writebuf[0], m_writebuf.size()); + } + +void Channel::send_record_array(u16bit epoch, byte type, const byte input[], size_t length) { if(length == 0) return; @@ -446,8 +469,7 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) * See http://www.openssl.org/~bodo/tls-cbc.txt for background. */ - auto cipher_state = - write_cipher_state_epoch(sequence_numbers().current_write_epoch()); + auto cipher_state = write_cipher_state_epoch(epoch); if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv()) { @@ -470,28 +492,14 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) void Channel::send_record(byte record_type, const std::vector<byte>& record) { - send_record_array(record_type, &record[0], record.size()); + send_record_array(sequence_numbers().current_write_epoch(), + record_type, &record[0], record.size()); } -void Channel::write_record(Connection_Cipher_State* cipher_state, - byte record_type, const byte input[], size_t length) +void Channel::send_record_under_epoch(u16bit epoch, byte record_type, + const std::vector<byte>& record) { - BOTAN_ASSERT(m_pending_state || m_active_state, - "Some connection state exists"); - - Protocol_Version record_version = - (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); - - TLS::write_record(m_writebuf, - record_type, - input, - length, - record_version, - sequence_numbers(), - cipher_state, - m_rng); - - m_output_fn(&m_writebuf[0], m_writebuf.size()); + send_record_array(epoch, record_type, &record[0], record.size()); } void Channel::send(const byte buf[], size_t buf_size) @@ -499,7 +507,8 @@ void Channel::send(const byte buf[], size_t buf_size) if(!is_active()) throw std::runtime_error("Data cannot be sent on inactive TLS connection"); - send_record_array(APPLICATION_DATA, buf, buf_size); + send_record_array(sequence_numbers().current_write_epoch(), + APPLICATION_DATA, buf, buf_size); } void Channel::send(const std::string& string) diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 10ecd296f..d27f8f2f5 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -175,7 +175,11 @@ class BOTAN_DLL Channel void send_record(byte record_type, const std::vector<byte>& record); - void send_record_array(byte type, const byte input[], size_t length); + void send_record_under_epoch(u16bit epoch, byte record_type, + const std::vector<byte>& record); + + void send_record_array(u16bit epoch, byte record_type, + const byte input[], size_t length); void write_record(Connection_Cipher_State* cipher_state, byte type, const byte input[], size_t length); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index aae3a65c5..b0724b03c 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -254,6 +254,10 @@ void Client::process_handshake_msg(const Handshake_State* active_state, { // new session + BOTAN_ASSERT_EQUAL(state.client_hello()->version().is_datagram_protocol(), + state.server_hello()->version().is_datagram_protocol(), + "Server replied with same protocol type client offered"); + if(state.version() > state.client_hello()->version()) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index b83d9e044..1fae7b5b7 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -7,6 +7,7 @@ #include <botan/internal/tls_handshake_io.h> #include <botan/internal/tls_messages.h> +#include <botan/internal/tls_seq_numbers.h> #include <botan/exceptn.h> namespace Botan { @@ -231,6 +232,10 @@ void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment( /* * FIXME. This is a pretty lame way to do defragmentation, huge * overhead with a tree node per byte. + * + * Also should confirm that all overlaps have no changes, + * otherwise we expose ourselves to the classic fingerprinting + * and IDS evasion attacks on IP fragmentation. */ for(size_t i = 0; i != fragment_length; ++i) m_fragments[fragment_offset+i] = fragment[i]; @@ -318,18 +323,22 @@ std::vector<byte> Datagram_Handshake_IO::send(const Handshake_Message& msg) { const std::vector<byte> msg_bits = msg.serialize(); + const u16bit epoch = m_seqs.current_write_epoch(); + const Handshake_Type msg_type = msg.type(); - if(msg.type() == HANDSHAKE_CCS) + std::tuple<u16bit, byte, std::vector<byte>> msg_info(epoch, msg_type, msg_bits); + + if(msg_type == HANDSHAKE_CCS) { - m_send_hs(CHANGE_CIPHER_SPEC, msg_bits); + m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits); return std::vector<byte>(); // not included in handshake hashes } const std::vector<byte> no_fragment = - format_w_seq(msg_bits, msg.type(), m_out_message_seq); + format_w_seq(msg_bits, msg_type, m_out_message_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) - m_send_hs(HANDSHAKE, no_fragment); + m_send_hs(epoch, HANDSHAKE, no_fragment); else { const size_t parts = split_for_mtu(m_mtu, msg_bits.size()); @@ -344,19 +353,22 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) std::min<size_t>(msg_bits.size() - frag_offset, parts_size); - m_send_hs(HANDSHAKE, + m_send_hs(epoch, + HANDSHAKE, format_fragment(&msg_bits[frag_offset], frag_len, frag_offset, msg_bits.size(), - msg.type(), + msg_type, m_out_message_seq)); frag_offset += frag_len; } } + // Note: not saving CCS, instead we know it was there due to change in epoch m_flights.rbegin()->push_back(m_out_message_seq); + m_flight_data[m_out_message_seq] = msg_info; m_out_message_seq += 1; diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index b026d4160..18fde1a83 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -17,6 +17,7 @@ #include <map> #include <set> #include <utility> +#include <tuple> namespace Botan { @@ -24,8 +25,6 @@ namespace TLS { class Handshake_Message; -typedef std::function<void (byte, const std::vector<byte>&)> handshake_write_fn; - /** * Handshake IO Interface */ @@ -66,7 +65,7 @@ class Handshake_IO class Stream_Handshake_IO : public Handshake_IO { public: - Stream_Handshake_IO(handshake_write_fn writer) : + Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) : m_send_hs(writer) {} Protocol_Version initial_record_version() const override; @@ -86,7 +85,7 @@ class Stream_Handshake_IO : public Handshake_IO get_next_record(bool expecting_ccs) override; private: std::deque<byte> m_queue; - handshake_write_fn m_send_hs; + std::function<void (byte, const std::vector<byte>&)> m_send_hs; }; /** @@ -95,8 +94,9 @@ class Stream_Handshake_IO : public Handshake_IO class Datagram_Handshake_IO : public Handshake_IO { public: - Datagram_Handshake_IO(handshake_write_fn writer, u16bit mtu) : - m_flights(1), m_mtu(mtu), m_send_hs(writer) {} + Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq, + std::function<void (u16bit, byte, const std::vector<byte>&)> writer) : + m_seqs(seq), m_flights(1), m_send_hs(writer) {} Protocol_Version initial_record_version() const override; @@ -151,14 +151,17 @@ class Datagram_Handshake_IO : public Handshake_IO std::vector<byte> m_message; }; + class Connection_Sequence_Numbers& m_seqs; std::map<u16bit, Handshake_Reassembly> m_messages; std::set<u16bit> m_ccs_epochs; std::vector<std::vector<u16bit>> m_flights; + std::map<u16bit, std::tuple<u16bit, byte, std::vector<byte>>> m_flight_data; - u16bit m_mtu = 0; + // default MTU is IPv6 min MTU minus UDP/IP headers + u16bit m_mtu = 1280 - 40 - 8; u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; - handshake_write_fn m_send_hs; + std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs; }; } diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp index c76fe30a5..e98fe66b2 100644 --- a/src/tls/tls_policy.cpp +++ b/src/tls/tls_policy.cpp @@ -130,10 +130,7 @@ u32bit Policy::session_ticket_lifetime() const bool Policy::acceptable_protocol_version(Protocol_Version version) const { - return (version == Protocol_Version::SSL_V3 || - version == Protocol_Version::TLS_V10 || - version == Protocol_Version::TLS_V11 || - version == Protocol_Version::TLS_V12); + return version.known_version(); // accept any version we know about } namespace { diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index e11ba31b1..fab966e72 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -328,7 +328,9 @@ size_t read_record(std::vector<byte>& readbuf, record_version = Protocol_Version(readbuf[1], readbuf[2]); - if(record_version.is_datagram_protocol() && readbuf.size() < DTLS_HEADER_SIZE) + const bool is_dtls = record_version.is_datagram_protocol(); + + if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE) { if(size_t needed = fill_buffer_to(readbuf, input, input_sz, consumed, @@ -339,8 +341,7 @@ size_t read_record(std::vector<byte>& readbuf, "Have an entire header"); } - const size_t header_size = - (record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; + const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; const size_t record_len = make_u16bit(readbuf[header_size-2], readbuf[header_size-1]); @@ -352,7 +353,7 @@ size_t read_record(std::vector<byte>& readbuf, if(size_t needed = fill_buffer_to(readbuf, input, input_sz, consumed, header_size + record_len)) - return needed; + return needed; // wrong for DTLS? BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len, readbuf.size(), @@ -360,7 +361,7 @@ size_t read_record(std::vector<byte>& readbuf, u16bit epoch = 0; - if(record_version.is_datagram_protocol()) + if(is_dtls) { record_sequence = load_be<u64bit>(&readbuf[3], 0); epoch = (record_sequence >> 48); |