aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-10 19:42:44 +0000
committerlloyd <[email protected]>2012-09-10 19:42:44 +0000
commit8c0160098e9bffa1a124a8951ba1a9c074f5509c (patch)
tree438bcd9b17bd0f3f122cb2450cc6851ded4faeac /src
parentad949688f2903d6b59e3178fc2d6a0022bdfa79f (diff)
New logic for DTLS replay detection. Abstracts the sequence handling
out a bit. Handling of initial server record is pretty nasty.
Diffstat (limited to 'src')
-rw-r--r--src/tls/info.txt1
-rw-r--r--src/tls/tls_channel.cpp47
-rw-r--r--src/tls/tls_channel.h9
-rw-r--r--src/tls/tls_record.cpp23
-rw-r--r--src/tls/tls_record.h6
-rw-r--r--src/tls/tls_seq_numbers.h112
6 files changed, 168 insertions, 30 deletions
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 7ffc26438..e61b2c0da 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -31,6 +31,7 @@ tls_heartbeats.h
tls_messages.h
tls_reader.h
tls_record.h
+tls_seq_numbers.h
tls_session_key.h
</header:internal>
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 71dae73b3..382671afe 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -10,6 +10,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_heartbeats.h>
#include <botan/internal/tls_record.h>
+#include <botan/internal/tls_seq_numbers.h>
#include <botan/internal/assert.h>
#include <botan/internal/rounding.h>
#include <botan/internal/stl_util.h>
@@ -39,6 +40,12 @@ Channel::~Channel()
// So unique_ptr destructors run correctly
}
+Connection_Sequence_Numbers& Channel::sequence_numbers() const
+ {
+ BOTAN_ASSERT(m_sequence_numbers, "Have a sequence numbers object");
+ return *m_sequence_numbers;
+ }
+
std::vector<X509_Certificate> Channel::peer_cert_chain() const
{
if(!m_active_state)
@@ -65,6 +72,14 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
" in pending");
}
+ if(!m_sequence_numbers)
+ {
+ if(version.is_datagram_protocol())
+ m_sequence_numbers.reset(new Datagram_Sequence_Numbers);
+ else
+ m_sequence_numbers.reset(new Stream_Sequence_Numbers);
+ }
+
auto send_rec = std::bind(&Channel::send_record, this,
std::placeholders::_1,
std::placeholders::_2);
@@ -111,7 +126,7 @@ void Channel::change_cipher_spec_reader(Connection_Side side)
if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
- m_read_seq_no = 0;
+ sequence_numbers().new_read_cipher_state();
// flip side as we are reading
side = (side == CLIENT) ? SERVER : CLIENT;
@@ -132,18 +147,7 @@ void Channel::change_cipher_spec_writer(Connection_Side side)
if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
- /*
- RFC 4346:
- A sequence number is incremented after each record: specifically,
- the first record transmitted under a particular connection state
- MUST use sequence number 0
-
- For DTLS, increment the epoch
- */
- if(m_pending_state->version().is_datagram_protocol())
- m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48;
- else
- m_write_seq_no = 0;
+ sequence_numbers().new_write_cipher_state();
m_write_cipherstate.reset(
new Connection_Cipher_State(m_pending_state->version(),
@@ -194,8 +198,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
consumed,
rec_type,
record,
- m_read_seq_no,
record_version,
+ m_sequence_numbers.get(),
m_read_cipherstate.get());
BOTAN_ASSERT(consumed <= buf_size,
@@ -214,13 +218,17 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
"Plaintext record is too large");
- record_number = m_read_seq_no;
- m_read_seq_no += 1;
-
- if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
+ if(rec_type == CONNECTION_CLOSED)
+ {
+ continue;
+ }
+ else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
{
if(!m_pending_state)
+ {
create_handshake_state(record_version);
+ sequence_numbers().read_accept(0);
+ }
m_pending_state->handshake_io().add_input(
rec_type, &record[0], record.size(), record_number);
@@ -403,12 +411,11 @@ void Channel::write_record(byte record_type, const byte input[], size_t length)
record_type,
input,
length,
- m_write_seq_no,
record_version,
+ sequence_numbers(),
m_write_cipherstate.get(),
m_rng);
- m_write_seq_no += 1;
m_output_fn(&m_writebuf[0], m_writebuf.size());
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 819279b69..367c3560d 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -21,7 +21,6 @@ namespace Botan {
namespace TLS {
-class Handshake_IO;
class Handshake_State;
/**
@@ -121,7 +120,7 @@ class BOTAN_DLL Channel
virtual std::vector<X509_Certificate>
get_peer_cert_chain(const Handshake_State& state) const = 0;
- virtual Handshake_State* new_handshake_state(Handshake_IO* io) = 0;
+ virtual Handshake_State* new_handshake_state(class Handshake_IO* io) = 0;
Handshake_State& create_handshake_state(Protocol_Version version);
@@ -167,6 +166,8 @@ class BOTAN_DLL Channel
bool heartbeat_sending_allowed() const;
+ class Connection_Sequence_Numbers& sequence_numbers() const;
+
/* callbacks */
std::function<bool (const Session&)> m_handshake_fn;
std::function<void (const byte[], size_t, Alert)> m_proc_fn;
@@ -176,16 +177,16 @@ class BOTAN_DLL Channel
RandomNumberGenerator& m_rng;
Session_Manager& m_session_manager;
+ std::unique_ptr<class Connection_Sequence_Numbers> m_sequence_numbers;
+
/* writing cipher state */
std::vector<byte> m_writebuf;
std::unique_ptr<class Connection_Cipher_State> m_write_cipherstate;
- u64bit m_write_seq_no = 0;
/* reading cipher state */
std::vector<byte> m_readbuf;
size_t m_readbuf_pos = 0;
std::unique_ptr<class Connection_Cipher_State> m_read_cipherstate;
- u64bit m_read_seq_no = 0;
/* connection parameters */
std::unique_ptr<Handshake_State> m_active_state;
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index 33903f1df..5aa69747f 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -10,6 +10,7 @@
#include <botan/tls_exceptn.h>
#include <botan/libstate.h>
#include <botan/loadstor.h>
+#include <botan/internal/tls_seq_numbers.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/rounding.h>
#include <botan/internal/assert.h>
@@ -78,8 +79,8 @@ Connection_Cipher_State::Connection_Cipher_State(Protocol_Version version,
void write_record(std::vector<byte>& output,
byte msg_type, const byte msg[], size_t msg_length,
- u64bit msg_sequence_number,
Protocol_Version version,
+ Connection_Sequence_Numbers& sequence_numbers,
Connection_Cipher_State* cipherstate,
RandomNumberGenerator& rng)
{
@@ -89,10 +90,12 @@ void write_record(std::vector<byte>& output,
output.push_back(version.major_version());
output.push_back(version.minor_version());
+ const u64bit msg_sequence = sequence_numbers.next_write_sequence();
+
if(version.is_datagram_protocol())
{
for(size_t i = 0; i != 8; ++i)
- output.push_back(get_byte(i, msg_sequence_number));
+ output.push_back(get_byte(i, msg_sequence));
}
if(!cipherstate) // initial unencrypted handshake records
@@ -105,7 +108,7 @@ void write_record(std::vector<byte>& output,
return;
}
- cipherstate->mac()->update_be(msg_sequence_number);
+ cipherstate->mac()->update_be(msg_sequence);
cipherstate->mac()->update(msg_type);
if(cipherstate->mac_includes_record_version())
@@ -273,8 +276,8 @@ size_t read_record(std::vector<byte>& readbuf,
size_t& consumed,
byte& msg_type,
std::vector<byte>& msg,
- u64bit msg_sequence,
Protocol_Version& record_version,
+ Connection_Sequence_Numbers* sequence_numbers,
Connection_Cipher_State* cipherstate)
{
consumed = 0;
@@ -352,8 +355,14 @@ size_t read_record(std::vector<byte>& readbuf,
const size_t header_size =
(record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
+ u64bit msg_sequence = 0;
+
if(record_version.is_datagram_protocol())
msg_sequence = load_be<u64bit>(&readbuf[3], 0);
+ else if(sequence_numbers)
+ msg_sequence = sequence_numbers->next_read_sequence();
+ else
+ msg_sequence = 0; // server initial handshake case
const size_t record_len = make_u16bit(readbuf[header_size-2],
readbuf[header_size-1]);
@@ -371,6 +380,9 @@ size_t read_record(std::vector<byte>& readbuf,
readbuf_pos,
"Have the full record");
+ if(sequence_numbers && sequence_numbers->already_seen(msg_sequence))
+ return 0;
+
byte* record_contents = &readbuf[header_size];
if(!cipherstate) // Only handshake messages allowed during initial handshake
@@ -469,6 +481,9 @@ size_t read_record(std::vector<byte>& readbuf,
if(mac_bad || padding_bad)
throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure");
+ if(sequence_numbers)
+ sequence_numbers->read_accept(msg_sequence);
+
msg_type = readbuf[0];
msg.assign(&record_contents[iv_size],
&record_contents[iv_size + plain_length]);
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 2b415edb8..841244733 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -23,6 +23,8 @@ namespace TLS {
class Ciphersuite;
class Session_Keys;
+class Connection_Sequence_Numbers;
+
/**
* TLS Cipher State
*/
@@ -78,8 +80,8 @@ class Connection_Cipher_State
*/
void write_record(std::vector<byte>& write_buffer,
byte msg_type, const byte msg[], size_t msg_length,
- u64bit msg_sequence_number,
Protocol_Version version,
+ Connection_Sequence_Numbers& sequence_numbers,
Connection_Cipher_State* cipherstate,
RandomNumberGenerator& rng);
@@ -94,8 +96,8 @@ size_t read_record(std::vector<byte>& read_buffer,
size_t& input_consumed,
byte& msg_type,
std::vector<byte>& msg,
- u64bit msg_sequence,
Protocol_Version& record_version,
+ Connection_Sequence_Numbers* sequence_numbers,
Connection_Cipher_State* cipherstate);
}
diff --git a/src/tls/tls_seq_numbers.h b/src/tls/tls_seq_numbers.h
new file mode 100644
index 000000000..c9a334e4b
--- /dev/null
+++ b/src/tls/tls_seq_numbers.h
@@ -0,0 +1,112 @@
+/*
+* TLS Sequence Number Handling
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_SEQ_NUMBERS_H__
+#define BOTAN_TLS_SEQ_NUMBERS_H__
+
+#include <stdexcept>
+
+namespace Botan {
+
+namespace TLS {
+
+class Connection_Sequence_Numbers
+ {
+ public:
+ virtual void new_read_cipher_state() = 0;
+ virtual void new_write_cipher_state() = 0;
+
+ virtual u64bit next_write_sequence() = 0;
+
+ virtual u64bit next_read_sequence() = 0;
+ virtual bool already_seen(u64bit seq) const = 0;
+ virtual void read_accept(u64bit seq) = 0;
+ };
+
+class Stream_Sequence_Numbers : public Connection_Sequence_Numbers
+ {
+ public:
+ void new_read_cipher_state() override { m_read_seq_no = 0; }
+ void new_write_cipher_state() override { m_write_seq_no = 0; }
+
+ u64bit next_write_sequence() override { return m_write_seq_no++; }
+
+ u64bit next_read_sequence() override { return m_read_seq_no; }
+ bool already_seen(u64bit) const override { return false; }
+ void read_accept(u64bit) override { m_read_seq_no++; }
+ private:
+ u64bit m_write_seq_no = 0;
+ u64bit m_read_seq_no = 0;
+ };
+
+class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers
+ {
+ public:
+ void new_read_cipher_state() override {}
+
+ void new_write_cipher_state() override
+ {
+ // increment epoch
+ m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48;
+ }
+
+ u64bit next_write_sequence() override { return m_write_seq_no++; }
+
+ u64bit next_read_sequence() override
+ {
+ throw std::runtime_error("DTLS uses explicit sequence numbers");
+ }
+
+ bool already_seen(u64bit sequence) const override
+ {
+ const size_t window_size = sizeof(m_window_bits) * 8;
+
+ if(sequence > m_window_highest)
+ return false;
+
+ const u64bit offset = m_window_highest - sequence;
+
+ if(offset >= window_size)
+ return true; // really old?
+
+ return (((m_window_bits >> offset) & 1) == 1);
+ }
+
+ void read_accept(u64bit sequence) override
+ {
+ const size_t window_size = sizeof(m_window_bits) * 8;
+
+ if(sequence > m_window_highest)
+ {
+ const size_t offset = sequence - m_window_highest;
+ m_window_highest += offset;
+
+ if(offset >= window_size)
+ m_window_bits = 0;
+ else
+ m_window_bits <<= offset;
+
+ m_window_bits |= 0x01;
+ }
+ else
+ {
+ const u64bit offset = m_window_highest - sequence;
+ m_window_bits |= (static_cast<u64bit>(1) << offset);
+ }
+ }
+
+ private:
+ u64bit m_write_seq_no = 0;
+ u64bit m_window_highest = 0;
+ u64bit m_window_bits = 0;
+ };
+
+}
+
+}
+
+#endif