diff options
author | lloyd <[email protected]> | 2012-09-10 19:42:44 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-10 19:42:44 +0000 |
commit | 8c0160098e9bffa1a124a8951ba1a9c074f5509c (patch) | |
tree | 438bcd9b17bd0f3f122cb2450cc6851ded4faeac /src | |
parent | ad949688f2903d6b59e3178fc2d6a0022bdfa79f (diff) |
New logic for DTLS replay detection. Abstracts the sequence handling
out a bit. Handling of initial server record is pretty nasty.
Diffstat (limited to 'src')
-rw-r--r-- | src/tls/info.txt | 1 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 47 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 9 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 23 | ||||
-rw-r--r-- | src/tls/tls_record.h | 6 | ||||
-rw-r--r-- | src/tls/tls_seq_numbers.h | 112 |
6 files changed, 168 insertions, 30 deletions
diff --git a/src/tls/info.txt b/src/tls/info.txt index 7ffc26438..e61b2c0da 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -31,6 +31,7 @@ tls_heartbeats.h tls_messages.h tls_reader.h tls_record.h +tls_seq_numbers.h tls_session_key.h </header:internal> diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 71dae73b3..382671afe 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -10,6 +10,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_heartbeats.h> #include <botan/internal/tls_record.h> +#include <botan/internal/tls_seq_numbers.h> #include <botan/internal/assert.h> #include <botan/internal/rounding.h> #include <botan/internal/stl_util.h> @@ -39,6 +40,12 @@ Channel::~Channel() // So unique_ptr destructors run correctly } +Connection_Sequence_Numbers& Channel::sequence_numbers() const + { + BOTAN_ASSERT(m_sequence_numbers, "Have a sequence numbers object"); + return *m_sequence_numbers; + } + std::vector<X509_Certificate> Channel::peer_cert_chain() const { if(!m_active_state) @@ -65,6 +72,14 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) " in pending"); } + if(!m_sequence_numbers) + { + if(version.is_datagram_protocol()) + m_sequence_numbers.reset(new Datagram_Sequence_Numbers); + else + m_sequence_numbers.reset(new Stream_Sequence_Numbers); + } + auto send_rec = std::bind(&Channel::send_record, this, std::placeholders::_1, std::placeholders::_2); @@ -111,7 +126,7 @@ void Channel::change_cipher_spec_reader(Connection_Side side) if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); - m_read_seq_no = 0; + sequence_numbers().new_read_cipher_state(); // flip side as we are reading side = (side == CLIENT) ? SERVER : CLIENT; @@ -132,18 +147,7 @@ void Channel::change_cipher_spec_writer(Connection_Side side) if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION) throw Internal_Error("Negotiated unknown compression algorithm"); - /* - RFC 4346: - A sequence number is incremented after each record: specifically, - the first record transmitted under a particular connection state - MUST use sequence number 0 - - For DTLS, increment the epoch - */ - if(m_pending_state->version().is_datagram_protocol()) - m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48; - else - m_write_seq_no = 0; + sequence_numbers().new_write_cipher_state(); m_write_cipherstate.reset( new Connection_Cipher_State(m_pending_state->version(), @@ -194,8 +198,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) consumed, rec_type, record, - m_read_seq_no, record_version, + m_sequence_numbers.get(), m_read_cipherstate.get()); BOTAN_ASSERT(consumed <= buf_size, @@ -214,13 +218,17 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); - record_number = m_read_seq_no; - m_read_seq_no += 1; - - if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) + if(rec_type == CONNECTION_CLOSED) + { + continue; + } + else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) + { create_handshake_state(record_version); + sequence_numbers().read_accept(0); + } m_pending_state->handshake_io().add_input( rec_type, &record[0], record.size(), record_number); @@ -403,12 +411,11 @@ void Channel::write_record(byte record_type, const byte input[], size_t length) record_type, input, length, - m_write_seq_no, record_version, + sequence_numbers(), m_write_cipherstate.get(), m_rng); - m_write_seq_no += 1; m_output_fn(&m_writebuf[0], m_writebuf.size()); } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 819279b69..367c3560d 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -21,7 +21,6 @@ namespace Botan { namespace TLS { -class Handshake_IO; class Handshake_State; /** @@ -121,7 +120,7 @@ class BOTAN_DLL Channel virtual std::vector<X509_Certificate> get_peer_cert_chain(const Handshake_State& state) const = 0; - virtual Handshake_State* new_handshake_state(Handshake_IO* io) = 0; + virtual Handshake_State* new_handshake_state(class Handshake_IO* io) = 0; Handshake_State& create_handshake_state(Protocol_Version version); @@ -167,6 +166,8 @@ class BOTAN_DLL Channel bool heartbeat_sending_allowed() const; + class Connection_Sequence_Numbers& sequence_numbers() const; + /* callbacks */ std::function<bool (const Session&)> m_handshake_fn; std::function<void (const byte[], size_t, Alert)> m_proc_fn; @@ -176,16 +177,16 @@ class BOTAN_DLL Channel RandomNumberGenerator& m_rng; Session_Manager& m_session_manager; + std::unique_ptr<class Connection_Sequence_Numbers> m_sequence_numbers; + /* writing cipher state */ std::vector<byte> m_writebuf; std::unique_ptr<class Connection_Cipher_State> m_write_cipherstate; - u64bit m_write_seq_no = 0; /* reading cipher state */ std::vector<byte> m_readbuf; size_t m_readbuf_pos = 0; std::unique_ptr<class Connection_Cipher_State> m_read_cipherstate; - u64bit m_read_seq_no = 0; /* connection parameters */ std::unique_ptr<Handshake_State> m_active_state; diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index 33903f1df..5aa69747f 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -10,6 +10,7 @@ #include <botan/tls_exceptn.h> #include <botan/libstate.h> #include <botan/loadstor.h> +#include <botan/internal/tls_seq_numbers.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/rounding.h> #include <botan/internal/assert.h> @@ -78,8 +79,8 @@ Connection_Cipher_State::Connection_Cipher_State(Protocol_Version version, void write_record(std::vector<byte>& output, byte msg_type, const byte msg[], size_t msg_length, - u64bit msg_sequence_number, Protocol_Version version, + Connection_Sequence_Numbers& sequence_numbers, Connection_Cipher_State* cipherstate, RandomNumberGenerator& rng) { @@ -89,10 +90,12 @@ void write_record(std::vector<byte>& output, output.push_back(version.major_version()); output.push_back(version.minor_version()); + const u64bit msg_sequence = sequence_numbers.next_write_sequence(); + if(version.is_datagram_protocol()) { for(size_t i = 0; i != 8; ++i) - output.push_back(get_byte(i, msg_sequence_number)); + output.push_back(get_byte(i, msg_sequence)); } if(!cipherstate) // initial unencrypted handshake records @@ -105,7 +108,7 @@ void write_record(std::vector<byte>& output, return; } - cipherstate->mac()->update_be(msg_sequence_number); + cipherstate->mac()->update_be(msg_sequence); cipherstate->mac()->update(msg_type); if(cipherstate->mac_includes_record_version()) @@ -273,8 +276,8 @@ size_t read_record(std::vector<byte>& readbuf, size_t& consumed, byte& msg_type, std::vector<byte>& msg, - u64bit msg_sequence, Protocol_Version& record_version, + Connection_Sequence_Numbers* sequence_numbers, Connection_Cipher_State* cipherstate) { consumed = 0; @@ -352,8 +355,14 @@ size_t read_record(std::vector<byte>& readbuf, const size_t header_size = (record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; + u64bit msg_sequence = 0; + if(record_version.is_datagram_protocol()) msg_sequence = load_be<u64bit>(&readbuf[3], 0); + else if(sequence_numbers) + msg_sequence = sequence_numbers->next_read_sequence(); + else + msg_sequence = 0; // server initial handshake case const size_t record_len = make_u16bit(readbuf[header_size-2], readbuf[header_size-1]); @@ -371,6 +380,9 @@ size_t read_record(std::vector<byte>& readbuf, readbuf_pos, "Have the full record"); + if(sequence_numbers && sequence_numbers->already_seen(msg_sequence)) + return 0; + byte* record_contents = &readbuf[header_size]; if(!cipherstate) // Only handshake messages allowed during initial handshake @@ -469,6 +481,9 @@ size_t read_record(std::vector<byte>& readbuf, if(mac_bad || padding_bad) throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); + if(sequence_numbers) + sequence_numbers->read_accept(msg_sequence); + msg_type = readbuf[0]; msg.assign(&record_contents[iv_size], &record_contents[iv_size + plain_length]); diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 2b415edb8..841244733 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -23,6 +23,8 @@ namespace TLS { class Ciphersuite; class Session_Keys; +class Connection_Sequence_Numbers; + /** * TLS Cipher State */ @@ -78,8 +80,8 @@ class Connection_Cipher_State */ void write_record(std::vector<byte>& write_buffer, byte msg_type, const byte msg[], size_t msg_length, - u64bit msg_sequence_number, Protocol_Version version, + Connection_Sequence_Numbers& sequence_numbers, Connection_Cipher_State* cipherstate, RandomNumberGenerator& rng); @@ -94,8 +96,8 @@ size_t read_record(std::vector<byte>& read_buffer, size_t& input_consumed, byte& msg_type, std::vector<byte>& msg, - u64bit msg_sequence, Protocol_Version& record_version, + Connection_Sequence_Numbers* sequence_numbers, Connection_Cipher_State* cipherstate); } diff --git a/src/tls/tls_seq_numbers.h b/src/tls/tls_seq_numbers.h new file mode 100644 index 000000000..c9a334e4b --- /dev/null +++ b/src/tls/tls_seq_numbers.h @@ -0,0 +1,112 @@ +/* +* TLS Sequence Number Handling +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_SEQ_NUMBERS_H__ +#define BOTAN_TLS_SEQ_NUMBERS_H__ + +#include <stdexcept> + +namespace Botan { + +namespace TLS { + +class Connection_Sequence_Numbers + { + public: + virtual void new_read_cipher_state() = 0; + virtual void new_write_cipher_state() = 0; + + virtual u64bit next_write_sequence() = 0; + + virtual u64bit next_read_sequence() = 0; + virtual bool already_seen(u64bit seq) const = 0; + virtual void read_accept(u64bit seq) = 0; + }; + +class Stream_Sequence_Numbers : public Connection_Sequence_Numbers + { + public: + void new_read_cipher_state() override { m_read_seq_no = 0; } + void new_write_cipher_state() override { m_write_seq_no = 0; } + + u64bit next_write_sequence() override { return m_write_seq_no++; } + + u64bit next_read_sequence() override { return m_read_seq_no; } + bool already_seen(u64bit) const override { return false; } + void read_accept(u64bit) override { m_read_seq_no++; } + private: + u64bit m_write_seq_no = 0; + u64bit m_read_seq_no = 0; + }; + +class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers + { + public: + void new_read_cipher_state() override {} + + void new_write_cipher_state() override + { + // increment epoch + m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48; + } + + u64bit next_write_sequence() override { return m_write_seq_no++; } + + u64bit next_read_sequence() override + { + throw std::runtime_error("DTLS uses explicit sequence numbers"); + } + + bool already_seen(u64bit sequence) const override + { + const size_t window_size = sizeof(m_window_bits) * 8; + + if(sequence > m_window_highest) + return false; + + const u64bit offset = m_window_highest - sequence; + + if(offset >= window_size) + return true; // really old? + + return (((m_window_bits >> offset) & 1) == 1); + } + + void read_accept(u64bit sequence) override + { + const size_t window_size = sizeof(m_window_bits) * 8; + + if(sequence > m_window_highest) + { + const size_t offset = sequence - m_window_highest; + m_window_highest += offset; + + if(offset >= window_size) + m_window_bits = 0; + else + m_window_bits <<= offset; + + m_window_bits |= 0x01; + } + else + { + const u64bit offset = m_window_highest - sequence; + m_window_bits |= (static_cast<u64bit>(1) << offset); + } + } + + private: + u64bit m_write_seq_no = 0; + u64bit m_window_highest = 0; + u64bit m_window_bits = 0; + }; + +} + +} + +#endif |