aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-11-12 22:05:09 +0000
committerlloyd <[email protected]>2012-11-12 22:05:09 +0000
commit58461a900aea49e5230b7b748fc481114d31904a (patch)
tree1a8d54f5368d5109845f6d6fee32b32b0c1d8d12
parent579158e826daed42963db0c8b987d51ba7831fb6 (diff)
Changes so DTLS handshake can send messages under different epochs, eg
for retransmitting a flight.
-rw-r--r--src/tls/tls_channel.cpp71
-rw-r--r--src/tls/tls_channel.h6
-rw-r--r--src/tls/tls_client.cpp4
-rw-r--r--src/tls/tls_handshake_io.cpp24
-rw-r--r--src/tls/tls_handshake_io.h19
-rw-r--r--src/tls/tls_policy.cpp5
-rw-r--r--src/tls/tls_record.cpp11
7 files changed, 85 insertions, 55 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 1be336fc5..5858f5d90 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -73,8 +73,6 @@ std::vector<X509_Certificate> Channel::peer_cert_chain() const
Handshake_State& Channel::create_handshake_state(Protocol_Version version)
{
- const size_t dtls_mtu = 1400; // fixme should be settable
-
if(pending_state())
throw Internal_Error("create_handshake_state called during handshake");
@@ -98,15 +96,19 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
m_sequence_numbers.reset(new Stream_Sequence_Numbers);
}
- auto send_rec = std::bind(&Channel::send_record, this,
- std::placeholders::_1,
- std::placeholders::_2);
-
std::unique_ptr<Handshake_IO> io;
if(version.is_datagram_protocol())
- io.reset(new Datagram_Handshake_IO(send_rec, dtls_mtu));
+ io.reset(new Datagram_Handshake_IO(
+ sequence_numbers(),
+ std::bind(&Channel::send_record_under_epoch, this,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3)));
else
- io.reset(new Stream_Handshake_IO(send_rec));
+ io.reset(new Stream_Handshake_IO(
+ std::bind(&Channel::send_record, this,
+ std::placeholders::_1,
+ std::placeholders::_2)));
m_pending_state.reset(new_handshake_state(io.release()));
@@ -429,7 +431,28 @@ void Channel::heartbeat(const byte payload[], size_t payload_size)
}
}
-void Channel::send_record_array(byte type, const byte input[], size_t length)
+void Channel::write_record(Connection_Cipher_State* cipher_state,
+ byte record_type, const byte input[], size_t length)
+ {
+ BOTAN_ASSERT(m_pending_state || m_active_state,
+ "Some connection state exists");
+
+ Protocol_Version record_version =
+ (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version());
+
+ TLS::write_record(m_writebuf,
+ record_type,
+ input,
+ length,
+ record_version,
+ sequence_numbers(),
+ cipher_state,
+ m_rng);
+
+ m_output_fn(&m_writebuf[0], m_writebuf.size());
+ }
+
+void Channel::send_record_array(u16bit epoch, byte type, const byte input[], size_t length)
{
if(length == 0)
return;
@@ -446,8 +469,7 @@ void Channel::send_record_array(byte type, const byte input[], size_t length)
* See http://www.openssl.org/~bodo/tls-cbc.txt for background.
*/
- auto cipher_state =
- write_cipher_state_epoch(sequence_numbers().current_write_epoch());
+ auto cipher_state = write_cipher_state_epoch(epoch);
if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv())
{
@@ -470,28 +492,14 @@ void Channel::send_record_array(byte type, const byte input[], size_t length)
void Channel::send_record(byte record_type, const std::vector<byte>& record)
{
- send_record_array(record_type, &record[0], record.size());
+ send_record_array(sequence_numbers().current_write_epoch(),
+ record_type, &record[0], record.size());
}
-void Channel::write_record(Connection_Cipher_State* cipher_state,
- byte record_type, const byte input[], size_t length)
+void Channel::send_record_under_epoch(u16bit epoch, byte record_type,
+ const std::vector<byte>& record)
{
- BOTAN_ASSERT(m_pending_state || m_active_state,
- "Some connection state exists");
-
- Protocol_Version record_version =
- (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version());
-
- TLS::write_record(m_writebuf,
- record_type,
- input,
- length,
- record_version,
- sequence_numbers(),
- cipher_state,
- m_rng);
-
- m_output_fn(&m_writebuf[0], m_writebuf.size());
+ send_record_array(epoch, record_type, &record[0], record.size());
}
void Channel::send(const byte buf[], size_t buf_size)
@@ -499,7 +507,8 @@ void Channel::send(const byte buf[], size_t buf_size)
if(!is_active())
throw std::runtime_error("Data cannot be sent on inactive TLS connection");
- send_record_array(APPLICATION_DATA, buf, buf_size);
+ send_record_array(sequence_numbers().current_write_epoch(),
+ APPLICATION_DATA, buf, buf_size);
}
void Channel::send(const std::string& string)
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 10ecd296f..d27f8f2f5 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -175,7 +175,11 @@ class BOTAN_DLL Channel
void send_record(byte record_type, const std::vector<byte>& record);
- void send_record_array(byte type, const byte input[], size_t length);
+ void send_record_under_epoch(u16bit epoch, byte record_type,
+ const std::vector<byte>& record);
+
+ void send_record_array(u16bit epoch, byte record_type,
+ const byte input[], size_t length);
void write_record(Connection_Cipher_State* cipher_state,
byte type, const byte input[], size_t length);
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index aae3a65c5..b0724b03c 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -254,6 +254,10 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
{
// new session
+ BOTAN_ASSERT_EQUAL(state.client_hello()->version().is_datagram_protocol(),
+ state.server_hello()->version().is_datagram_protocol(),
+ "Server replied with same protocol type client offered");
+
if(state.version() > state.client_hello()->version())
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index b83d9e044..1fae7b5b7 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -7,6 +7,7 @@
#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/tls_messages.h>
+#include <botan/internal/tls_seq_numbers.h>
#include <botan/exceptn.h>
namespace Botan {
@@ -231,6 +232,10 @@ void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
/*
* FIXME. This is a pretty lame way to do defragmentation, huge
* overhead with a tree node per byte.
+ *
+ * Also should confirm that all overlaps have no changes,
+ * otherwise we expose ourselves to the classic fingerprinting
+ * and IDS evasion attacks on IP fragmentation.
*/
for(size_t i = 0; i != fragment_length; ++i)
m_fragments[fragment_offset+i] = fragment[i];
@@ -318,18 +323,22 @@ std::vector<byte>
Datagram_Handshake_IO::send(const Handshake_Message& msg)
{
const std::vector<byte> msg_bits = msg.serialize();
+ const u16bit epoch = m_seqs.current_write_epoch();
+ const Handshake_Type msg_type = msg.type();
- if(msg.type() == HANDSHAKE_CCS)
+ std::tuple<u16bit, byte, std::vector<byte>> msg_info(epoch, msg_type, msg_bits);
+
+ if(msg_type == HANDSHAKE_CCS)
{
- m_send_hs(CHANGE_CIPHER_SPEC, msg_bits);
+ m_send_hs(epoch, CHANGE_CIPHER_SPEC, msg_bits);
return std::vector<byte>(); // not included in handshake hashes
}
const std::vector<byte> no_fragment =
- format_w_seq(msg_bits, msg.type(), m_out_message_seq);
+ format_w_seq(msg_bits, msg_type, m_out_message_seq);
if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
- m_send_hs(HANDSHAKE, no_fragment);
+ m_send_hs(epoch, HANDSHAKE, no_fragment);
else
{
const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
@@ -344,19 +353,22 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg)
std::min<size_t>(msg_bits.size() - frag_offset,
parts_size);
- m_send_hs(HANDSHAKE,
+ m_send_hs(epoch,
+ HANDSHAKE,
format_fragment(&msg_bits[frag_offset],
frag_len,
frag_offset,
msg_bits.size(),
- msg.type(),
+ msg_type,
m_out_message_seq));
frag_offset += frag_len;
}
}
+ // Note: not saving CCS, instead we know it was there due to change in epoch
m_flights.rbegin()->push_back(m_out_message_seq);
+ m_flight_data[m_out_message_seq] = msg_info;
m_out_message_seq += 1;
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index b026d4160..18fde1a83 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -17,6 +17,7 @@
#include <map>
#include <set>
#include <utility>
+#include <tuple>
namespace Botan {
@@ -24,8 +25,6 @@ namespace TLS {
class Handshake_Message;
-typedef std::function<void (byte, const std::vector<byte>&)> handshake_write_fn;
-
/**
* Handshake IO Interface
*/
@@ -66,7 +65,7 @@ class Handshake_IO
class Stream_Handshake_IO : public Handshake_IO
{
public:
- Stream_Handshake_IO(handshake_write_fn writer) :
+ Stream_Handshake_IO(std::function<void (byte, const std::vector<byte>&)> writer) :
m_send_hs(writer) {}
Protocol_Version initial_record_version() const override;
@@ -86,7 +85,7 @@ class Stream_Handshake_IO : public Handshake_IO
get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
- handshake_write_fn m_send_hs;
+ std::function<void (byte, const std::vector<byte>&)> m_send_hs;
};
/**
@@ -95,8 +94,9 @@ class Stream_Handshake_IO : public Handshake_IO
class Datagram_Handshake_IO : public Handshake_IO
{
public:
- Datagram_Handshake_IO(handshake_write_fn writer, u16bit mtu) :
- m_flights(1), m_mtu(mtu), m_send_hs(writer) {}
+ Datagram_Handshake_IO(class Connection_Sequence_Numbers& seq,
+ std::function<void (u16bit, byte, const std::vector<byte>&)> writer) :
+ m_seqs(seq), m_flights(1), m_send_hs(writer) {}
Protocol_Version initial_record_version() const override;
@@ -151,14 +151,17 @@ class Datagram_Handshake_IO : public Handshake_IO
std::vector<byte> m_message;
};
+ class Connection_Sequence_Numbers& m_seqs;
std::map<u16bit, Handshake_Reassembly> m_messages;
std::set<u16bit> m_ccs_epochs;
std::vector<std::vector<u16bit>> m_flights;
+ std::map<u16bit, std::tuple<u16bit, byte, std::vector<byte>>> m_flight_data;
- u16bit m_mtu = 0;
+ // default MTU is IPv6 min MTU minus UDP/IP headers
+ u16bit m_mtu = 1280 - 40 - 8;
u16bit m_in_message_seq = 0;
u16bit m_out_message_seq = 0;
- handshake_write_fn m_send_hs;
+ std::function<void (u16bit, byte, const std::vector<byte>&)> m_send_hs;
};
}
diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp
index c76fe30a5..e98fe66b2 100644
--- a/src/tls/tls_policy.cpp
+++ b/src/tls/tls_policy.cpp
@@ -130,10 +130,7 @@ u32bit Policy::session_ticket_lifetime() const
bool Policy::acceptable_protocol_version(Protocol_Version version) const
{
- return (version == Protocol_Version::SSL_V3 ||
- version == Protocol_Version::TLS_V10 ||
- version == Protocol_Version::TLS_V11 ||
- version == Protocol_Version::TLS_V12);
+ return version.known_version(); // accept any version we know about
}
namespace {
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index e11ba31b1..fab966e72 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -328,7 +328,9 @@ size_t read_record(std::vector<byte>& readbuf,
record_version = Protocol_Version(readbuf[1], readbuf[2]);
- if(record_version.is_datagram_protocol() && readbuf.size() < DTLS_HEADER_SIZE)
+ const bool is_dtls = record_version.is_datagram_protocol();
+
+ if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE)
{
if(size_t needed = fill_buffer_to(readbuf,
input, input_sz, consumed,
@@ -339,8 +341,7 @@ size_t read_record(std::vector<byte>& readbuf,
"Have an entire header");
}
- const size_t header_size =
- (record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
+ const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
const size_t record_len = make_u16bit(readbuf[header_size-2],
readbuf[header_size-1]);
@@ -352,7 +353,7 @@ size_t read_record(std::vector<byte>& readbuf,
if(size_t needed = fill_buffer_to(readbuf,
input, input_sz, consumed,
header_size + record_len))
- return needed;
+ return needed; // wrong for DTLS?
BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len,
readbuf.size(),
@@ -360,7 +361,7 @@ size_t read_record(std::vector<byte>& readbuf,
u16bit epoch = 0;
- if(record_version.is_datagram_protocol())
+ if(is_dtls)
{
record_sequence = load_be<u64bit>(&readbuf[3], 0);
epoch = (record_sequence >> 48);