aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2019-07-14 19:30:47 -0400
committerJack Lloyd <[email protected]>2019-07-14 19:30:47 -0400
commit67ce92b89318c25b35dc92f917166e5f3a22bf76 (patch)
tree8da5832e8dc9e9c420153c54461313ff7b409926 /src
parente01ef99340af26feccb0ba769e6ac12bf4b8d3cf (diff)
parent0557ef2299e1528037c53ff70c8e7fcfec816438 (diff)
Merge GH #2029 Support a DTLS client reconnecting from same source port
Diffstat (limited to 'src')
-rw-r--r--src/build-data/version.txt2
-rw-r--r--src/lib/tls/tls_channel.cpp83
-rw-r--r--src/lib/tls/tls_channel.h19
-rw-r--r--src/lib/tls/tls_client.cpp13
-rw-r--r--src/lib/tls/tls_client.h3
-rw-r--r--src/lib/tls/tls_handshake_io.cpp15
-rw-r--r--src/lib/tls/tls_handshake_io.h6
-rw-r--r--src/lib/tls/tls_policy.cpp1
-rw-r--r--src/lib/tls/tls_policy.h6
-rw-r--r--src/lib/tls/tls_record.cpp16
-rw-r--r--src/lib/tls/tls_record.h8
-rw-r--r--src/lib/tls/tls_seq_numbers.h33
-rw-r--r--src/lib/tls/tls_server.cpp34
-rw-r--r--src/lib/tls/tls_server.h6
-rw-r--r--src/tests/unit_tls.cpp235
15 files changed, 423 insertions, 57 deletions
diff --git a/src/build-data/version.txt b/src/build-data/version.txt
index 9170528e3..7998b8b47 100644
--- a/src/build-data/version.txt
+++ b/src/build-data/version.txt
@@ -2,7 +2,7 @@
release_major = 2
release_minor = 12
release_patch = 0
-release_so_abi_rev = 11
+release_so_abi_rev = 12
# These are set by the distribution script
release_vc_rev = None
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index 2fefc40b7..366bef9c3 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -27,8 +27,10 @@ Channel::Channel(Callbacks& callbacks,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
const Policy& policy,
+ bool is_server,
bool is_datagram,
size_t reserved_io_buffer_size) :
+ m_is_server(is_server),
m_is_datagram(is_datagram),
m_callbacks(callbacks),
m_session_manager(session_manager),
@@ -46,8 +48,10 @@ Channel::Channel(output_fn out,
Session_Manager& session_manager,
RandomNumberGenerator& rng,
const Policy& policy,
+ bool is_server,
bool is_datagram,
size_t io_buf_sz) :
+ m_is_server(is_server),
m_is_datagram(is_datagram),
m_compat_callbacks(new Compat_Callbacks(
/*
@@ -83,6 +87,21 @@ void Channel::reset_state()
m_read_cipher_states.clear();
}
+void Channel::reset_active_association_state()
+ {
+ // This operation only makes sense for DTLS
+ BOTAN_ASSERT_NOMSG(m_is_datagram);
+ m_active_state.reset();
+ m_read_cipher_states.clear();
+ m_write_cipher_states.clear();
+
+ m_write_cipher_states[0] = nullptr;
+ m_read_cipher_states[0] = nullptr;
+
+ if(m_sequence_numbers)
+ m_sequence_numbers->reset();
+ }
+
Channel::~Channel()
{
// So unique_ptr destructors run correctly
@@ -271,7 +290,14 @@ bool Channel::is_closed() const
* received a connection. This case is detectable by also lacking
* m_sequence_numbers
*/
- return (m_sequence_numbers != nullptr);
+ if(m_is_server)
+ {
+ return (m_sequence_numbers != nullptr);
+ }
+ else
+ {
+ return true;
+ }
}
void Channel::activate_session()
@@ -301,12 +327,16 @@ size_t Channel::received_data(const std::vector<uint8_t>& buf)
size_t Channel::received_data(const uint8_t input[], size_t input_size)
{
+ const bool allow_epoch0_restart = m_is_datagram && m_is_server && policy().allow_dtls_epoch0_restart();
+
try
{
while(!is_closed() && input_size)
{
size_t consumed = 0;
+ auto get_epoch = [this](uint16_t epoch) { return read_cipher_state_epoch(epoch); };
+
const Record_Header record =
read_record(m_is_datagram,
m_readbuf,
@@ -315,7 +345,8 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
consumed,
m_record_buf,
m_sequence_numbers.get(),
- [this](uint16_t epoch) { return read_cipher_state_epoch(epoch); });
+ get_epoch,
+ allow_epoch0_restart);
const size_t needed = record.needed();
@@ -335,45 +366,56 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
// Ignore invalid records in DTLS
if(m_is_datagram && record.type() == NO_RECORD)
+ {
return 0;
+ }
if(m_record_buf.size() > MAX_PLAINTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
"TLS plaintext record is larger than allowed maximum");
+
+ const bool epoch0_restart = m_is_datagram && record.epoch() == 0 && active_state();
+ BOTAN_ASSERT_IMPLICATION(epoch0_restart, allow_epoch0_restart, "Allowed state");
+
+ const bool initial_record = epoch0_restart || (!pending_state() && !active_state());
+
if(record.type() != ALERT)
{
- if(auto pending = pending_state())
+ if(initial_record)
{
- if(pending->server_hello() != nullptr && record.version() != pending->version())
+ // For initial records just check for basic sanity
+ if(record.version().major_version() != 3 &&
+ record.version().major_version() != 0xFE)
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
- "Received unexpected record version");
+ "Received unexpected record version in initial record");
}
}
- else if(auto active = active_state())
+ else if(auto pending = pending_state())
{
- if(record.version() != active->version())
+ if(pending->server_hello() != nullptr && record.version() != pending->version())
{
- throw TLS_Exception(Alert::PROTOCOL_VERSION,
- "Received unexpected record version");
+ if(record.version() != pending->version())
+ {
+ throw TLS_Exception(Alert::PROTOCOL_VERSION,
+ "Received unexpected record version");
+ }
}
}
- else
+ else if(auto active = active_state())
{
- // For initial records just check for basic sanity
- if(record.version().major_version() != 3 &&
- record.version().major_version() != 0xFE)
+ if(record.version() != active->version())
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
- "Received unexpected record version in initial record");
+ "Received unexpected record version");
}
}
}
if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC)
{
- process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version());
+ process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version(), epoch0_restart);
}
else if(record.type() == APPLICATION_DATA)
{
@@ -418,12 +460,13 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record,
uint64_t record_sequence,
Record_Type record_type,
- Protocol_Version record_version)
+ Protocol_Version record_version,
+ bool epoch0_restart)
{
if(!m_pending_state)
{
// No pending handshake, possibly new:
- if(record_version.is_datagram_protocol())
+ if(record_version.is_datagram_protocol() && !epoch0_restart)
{
if(m_sequence_numbers)
{
@@ -475,7 +518,10 @@ void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record,
break;
process_handshake_msg(active_state(), *pending,
- msg.first, msg.second);
+ msg.first, msg.second, epoch0_restart);
+
+ if(!m_pending_state)
+ break;
}
}
}
@@ -512,7 +558,6 @@ void Channel::process_alert(const secure_vector<uint8_t>& record)
}
}
-
void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch,
uint8_t record_type, const uint8_t input[], size_t length)
{
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 2a2b74332..8f977932b 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -66,6 +66,7 @@ class BOTAN_PUBLIC_API(2,0) Channel
Session_Manager& session_manager,
RandomNumberGenerator& rng,
const Policy& policy,
+ bool is_server,
bool is_datagram,
size_t io_buf_sz = IO_BUF_DEFAULT_SIZE);
@@ -83,6 +84,7 @@ class BOTAN_PUBLIC_API(2,0) Channel
Session_Manager& session_manager,
RandomNumberGenerator& rng,
const Policy& policy,
+ bool is_server,
bool is_datagram,
size_t io_buf_sz = IO_BUF_DEFAULT_SIZE);
@@ -160,7 +162,6 @@ class BOTAN_PUBLIC_API(2,0) Channel
*/
bool is_closed() const;
-
/**
* @return certificate chain of the peer (may be empty)
*/
@@ -205,7 +206,8 @@ class BOTAN_PUBLIC_API(2,0) Channel
virtual void process_handshake_msg(const Handshake_State* active_state,
Handshake_State& pending_state,
Handshake_Type type,
- const std::vector<uint8_t>& contents) = 0;
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart) = 0;
virtual void initiate_handshake(Handshake_State& state,
bool force_full_renegotiation) = 0;
@@ -242,6 +244,9 @@ class BOTAN_PUBLIC_API(2,0) Channel
bool save_session(const Session& session);
Callbacks& callbacks() const { return m_callbacks; }
+
+ void reset_active_association_state();
+
private:
void init(size_t io_buf_sze);
@@ -256,14 +261,14 @@ class BOTAN_PUBLIC_API(2,0) Channel
void write_record(Connection_Cipher_State* cipher_state,
uint16_t epoch, uint8_t type, const uint8_t input[], size_t length);
+ void reset_state();
+
Connection_Sequence_Numbers& sequence_numbers() const;
std::shared_ptr<Connection_Cipher_State> read_cipher_state_epoch(uint16_t epoch) const;
std::shared_ptr<Connection_Cipher_State> write_cipher_state_epoch(uint16_t epoch) const;
- void reset_state();
-
const Handshake_State* active_state() const { return m_active_state.get(); }
const Handshake_State* pending_state() const { return m_pending_state.get(); }
@@ -272,13 +277,15 @@ class BOTAN_PUBLIC_API(2,0) Channel
void process_handshake_ccs(const secure_vector<uint8_t>& record,
uint64_t record_sequence,
Record_Type record_type,
- Protocol_Version record_version);
+ Protocol_Version record_version,
+ bool epoch0_restart);
void process_application_data(uint64_t req_no, const secure_vector<uint8_t>& record);
void process_alert(const secure_vector<uint8_t>& record);
- bool m_is_datagram;
+ const bool m_is_server;
+ const bool m_is_datagram;
/* callbacks */
std::unique_ptr<Compat_Callbacks> m_compat_callbacks;
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index 10bd34226..440dfb6c2 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -70,8 +70,8 @@ Client::Client(Callbacks& callbacks,
const Protocol_Version& offer_version,
const std::vector<std::string>& next_protos,
size_t io_buf_sz) :
- Channel(callbacks, session_manager, rng, policy, offer_version.is_datagram_protocol(),
- io_buf_sz),
+ Channel(callbacks, session_manager, rng, policy,
+ false, offer_version.is_datagram_protocol(), io_buf_sz),
m_creds(creds),
m_info(info)
{
@@ -91,7 +91,7 @@ Client::Client(output_fn data_output_fn,
const std::vector<std::string>& next_protos,
size_t io_buf_sz) :
Channel(data_output_fn, proc_cb, recv_alert_cb, hs_cb, Channel::handshake_msg_cb(),
- session_manager, rng, policy, offer_version.is_datagram_protocol(), io_buf_sz),
+ session_manager, rng, policy, false, offer_version.is_datagram_protocol(), io_buf_sz),
m_creds(creds),
m_info(info)
{
@@ -111,7 +111,7 @@ Client::Client(output_fn data_output_fn,
const Protocol_Version& offer_version,
const std::vector<std::string>& next_protos) :
Channel(data_output_fn, proc_cb, recv_alert_cb, hs_cb, hs_msg_cb,
- session_manager, rng, policy, offer_version.is_datagram_protocol()),
+ session_manager, rng, policy, false, offer_version.is_datagram_protocol()),
m_creds(creds),
m_info(info)
{
@@ -227,8 +227,11 @@ void Client::send_client_hello(Handshake_State& state_base,
void Client::process_handshake_msg(const Handshake_State* active_state,
Handshake_State& state_base,
Handshake_Type type,
- const std::vector<uint8_t>& contents)
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart)
{
+ BOTAN_ASSERT_NOMSG(epoch0_restart == false); // only happens on server side
+
Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base);
if(type == HELLO_REQUEST && active_state)
diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h
index 005370e78..0e08b4595 100644
--- a/src/lib/tls/tls_client.h
+++ b/src/lib/tls/tls_client.h
@@ -152,7 +152,8 @@ class BOTAN_PUBLIC_API(2,0) Client final : public Channel
void process_handshake_msg(const Handshake_State* active_state,
Handshake_State& pending_state,
Handshake_Type type,
- const std::vector<uint8_t>& contents) override;
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart) override;
Handshake_State* new_handshake_state(Handshake_IO* io) override;
diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp
index 62f3bebb8..3f3e672de 100644
--- a/src/lib/tls/tls_handshake_io.cpp
+++ b/src/lib/tls/tls_handshake_io.cpp
@@ -113,6 +113,11 @@ Stream_Handshake_IO::format(const std::vector<uint8_t>& msg,
return send_buf;
}
+std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/)
+ {
+ throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS");
+ }
+
std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg)
{
const std::vector<uint8_t> msg_bits = msg.serialize();
@@ -261,7 +266,9 @@ Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
auto i = m_messages.find(m_in_message_seq);
if(i == m_messages.end() || !i->second.complete())
+ {
return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>());
+ }
m_in_message_seq += 1;
@@ -379,11 +386,15 @@ Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg,
return format_w_seq(msg, type, m_in_message_seq - 1);
}
+std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg)
+ {
+ return this->send_under_epoch(msg, m_seqs.current_write_epoch());
+ }
+
std::vector<uint8_t>
-Datagram_Handshake_IO::send(const Handshake_Message& msg)
+Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch)
{
const std::vector<uint8_t> msg_bits = msg.serialize();
- const uint16_t epoch = m_seqs.current_write_epoch();
const Handshake_Type msg_type = msg.type();
if(msg_type == HANDSHAKE_CCS)
diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h
index 66579459d..1c128726d 100644
--- a/src/lib/tls/tls_handshake_io.h
+++ b/src/lib/tls/tls_handshake_io.h
@@ -33,6 +33,8 @@ class Handshake_IO
virtual std::vector<uint8_t> send(const Handshake_Message& msg) = 0;
+ virtual std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) = 0;
+
virtual bool timeout_check() = 0;
virtual std::vector<uint8_t> format(
@@ -75,6 +77,8 @@ class Stream_Handshake_IO final : public Handshake_IO
std::vector<uint8_t> send(const Handshake_Message& msg) override;
+ std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override;
+
std::vector<uint8_t> format(
const std::vector<uint8_t>& handshake_msg,
Handshake_Type handshake_type) const override;
@@ -116,6 +120,8 @@ class Datagram_Handshake_IO final : public Handshake_IO
std::vector<uint8_t> send(const Handshake_Message& msg) override;
+ std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override;
+
std::vector<uint8_t> format(
const std::vector<uint8_t>& handshake_msg,
Handshake_Type handshake_type) const override;
diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp
index 58ba73ade..0e627fdea 100644
--- a/src/lib/tls/tls_policy.cpp
+++ b/src/lib/tls/tls_policy.cpp
@@ -336,6 +336,7 @@ bool Policy::only_resume_with_exact_version() const { return true; }
bool Policy::require_client_certificate_authentication() const { return false; }
bool Policy::request_client_certificate_authentication() const { return require_client_certificate_authentication(); }
bool Policy::abort_connection_on_undesired_renegotiation() const { return false; }
+bool Policy::allow_dtls_epoch0_restart() const { return false; }
size_t Policy::maximum_certificate_chain_size() const { return 0; }
diff --git a/src/lib/tls/tls_policy.h b/src/lib/tls/tls_policy.h
index 3a5be83d9..b076d5f9d 100644
--- a/src/lib/tls/tls_policy.h
+++ b/src/lib/tls/tls_policy.h
@@ -292,6 +292,12 @@ class BOTAN_PUBLIC_API(2,0) Policy
virtual bool request_client_certificate_authentication() const;
/**
+ * If true, then allow a DTLS client to restart a connection to the
+ * same server association as described in section 4.2.8 of the DTLS RFC
+ */
+ virtual bool allow_dtls_epoch0_restart() const;
+
+ /**
* Return allowed ciphersuites, in order of preference
*/
virtual std::vector<uint16_t> ciphersuite_list(Protocol_Version version,
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 3304b70eb..71f942bc4 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -399,7 +399,8 @@ Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf,
size_t& consumed,
secure_vector<uint8_t>& recbuf,
Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate)
+ get_cipherstate_fn get_cipherstate,
+ bool allow_epoch0_restart)
{
if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete?
{
@@ -442,12 +443,12 @@ Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf,
const Record_Type type = static_cast<Record_Type>(readbuf[0]);
- uint16_t epoch = 0;
-
const uint64_t sequence = load_be<uint64_t>(&readbuf[3], 0);
- epoch = (sequence >> 48);
+ const uint16_t epoch = (sequence >> 48);
+
+ const bool already_seen = sequence_numbers && sequence_numbers->already_seen(sequence);
- if(sequence_numbers && sequence_numbers->already_seen(sequence))
+ if(already_seen && !(epoch == 0 && allow_epoch0_restart))
{
readbuf.clear();
return Record_Header(0);
@@ -499,11 +500,12 @@ Record_Header read_record(bool is_datagram,
size_t& consumed,
secure_vector<uint8_t>& recbuf,
Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate)
+ get_cipherstate_fn get_cipherstate,
+ bool allow_epoch0_restart)
{
if(is_datagram)
return read_dtls_record(readbuf, input, input_len, consumed,
- recbuf, sequence_numbers, get_cipherstate);
+ recbuf, sequence_numbers, get_cipherstate, allow_epoch0_restart);
else
return read_tls_record(readbuf, input, input_len, consumed,
recbuf, sequence_numbers, get_cipherstate);
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index 3e3475c03..779954439 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -110,6 +110,11 @@ class Record_Header final
return m_sequence;
}
+ uint16_t epoch() const
+ {
+ return static_cast<uint16_t>(sequence() >> 48);
+ }
+
Record_Type type() const
{
BOTAN_ASSERT_NOMSG(m_needed == 0);
@@ -157,7 +162,8 @@ Record_Header read_record(bool is_datagram,
size_t& consumed,
secure_vector<uint8_t>& record_buf,
Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate);
+ get_cipherstate_fn get_cipherstate,
+ bool allow_epoch0_restart);
}
diff --git a/src/lib/tls/tls_seq_numbers.h b/src/lib/tls/tls_seq_numbers.h
index 34f949baa..64e2d0589 100644
--- a/src/lib/tls/tls_seq_numbers.h
+++ b/src/lib/tls/tls_seq_numbers.h
@@ -31,11 +31,23 @@ class Connection_Sequence_Numbers
virtual bool already_seen(uint64_t seq) const = 0;
virtual void read_accept(uint64_t seq) = 0;
+
+ virtual void reset() = 0;
};
class Stream_Sequence_Numbers final : public Connection_Sequence_Numbers
{
public:
+ Stream_Sequence_Numbers() { Stream_Sequence_Numbers::reset(); }
+
+ void reset() override
+ {
+ m_write_seq_no = 0;
+ m_read_seq_no = 0;
+ m_read_epoch = 0;
+ m_write_epoch = 0;
+ }
+
void new_read_cipher_state() override { m_read_seq_no = 0; m_read_epoch++; }
void new_write_cipher_state() override { m_write_seq_no = 0; m_write_epoch++; }
@@ -47,17 +59,28 @@ class Stream_Sequence_Numbers final : public Connection_Sequence_Numbers
bool already_seen(uint64_t) const override { return false; }
void read_accept(uint64_t) override { m_read_seq_no++; }
+
private:
- uint64_t m_write_seq_no = 0;
- uint64_t m_read_seq_no = 0;
- uint16_t m_read_epoch = 0;
- uint16_t m_write_epoch = 0;
+ uint64_t m_write_seq_no;
+ uint64_t m_read_seq_no;
+ uint16_t m_read_epoch;
+ uint16_t m_write_epoch;
};
class Datagram_Sequence_Numbers final : public Connection_Sequence_Numbers
{
public:
- Datagram_Sequence_Numbers() { m_write_seqs[0] = 0; }
+ Datagram_Sequence_Numbers() { Datagram_Sequence_Numbers::reset(); }
+
+ void reset() override
+ {
+ m_write_seqs.clear();
+ m_write_seqs[0] = 0;
+ m_write_epoch = 0;
+ m_read_epoch = 0;
+ m_window_highest = 0;
+ m_window_bits = 0;
+ }
void new_read_cipher_state() override { m_read_epoch++; }
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 76941fd11..33d45b852 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -303,7 +303,7 @@ Server::Server(Callbacks& callbacks,
bool is_datagram,
size_t io_buf_sz) :
Channel(callbacks, session_manager, rng, policy,
- is_datagram, io_buf_sz),
+ true, is_datagram, io_buf_sz),
m_creds(creds)
{
}
@@ -321,7 +321,7 @@ Server::Server(output_fn output,
size_t io_buf_sz) :
Channel(output, got_data_cb, recv_alert_cb, hs_cb,
Channel::handshake_msg_cb(), session_manager,
- rng, policy, is_datagram, io_buf_sz),
+ rng, policy, true, is_datagram, io_buf_sz),
m_creds(creds),
m_choose_next_protocol(next_proto)
{
@@ -339,7 +339,7 @@ Server::Server(output_fn output,
next_protocol_fn next_proto,
bool is_datagram) :
Channel(output, got_data_cb, recv_alert_cb, hs_cb, hs_msg_cb,
- session_manager, rng, policy, is_datagram),
+ session_manager, rng, policy, true, is_datagram),
m_creds(creds),
m_choose_next_protocol(next_proto)
{
@@ -471,9 +471,12 @@ Protocol_Version select_version(const Botan::TLS::Policy& policy,
*/
void Server::process_client_hello_msg(const Handshake_State* active_state,
Server_Handshake_State& pending_state,
- const std::vector<uint8_t>& contents)
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart)
{
- const bool initial_handshake = !active_state;
+ BOTAN_ASSERT_IMPLICATION(epoch0_restart, active_state != nullptr, "Can't restart with a dead connection");
+
+ const bool initial_handshake = epoch0_restart || !active_state;
if(initial_handshake == false && policy().allow_client_initiated_renegotiation() == false)
{
@@ -547,12 +550,26 @@ void Server::process_client_hello_msg(const Handshake_State* active_state,
if(pending_state.client_hello()->cookie() != verify.cookie())
{
- pending_state.handshake_io().send(verify);
+ if(epoch0_restart)
+ pending_state.handshake_io().send_under_epoch(verify, 0);
+ else
+ pending_state.handshake_io().send(verify);
+
pending_state.client_hello(nullptr);
pending_state.set_expected_next(CLIENT_HELLO);
return;
}
}
+ else if(epoch0_restart)
+ {
+ throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Reuse of DTLS association requires DTLS cookie secret be set");
+ }
+ }
+
+ if(epoch0_restart)
+ {
+ // If we reached here then we were able to verify the cookie
+ reset_active_association_state();
}
secure_renegotiation_check(pending_state.client_hello());
@@ -749,7 +766,8 @@ void Server::process_finished_msg(Server_Handshake_State& pending_state,
void Server::process_handshake_msg(const Handshake_State* active_state,
Handshake_State& state_base,
Handshake_Type type,
- const std::vector<uint8_t>& contents)
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart)
{
Server_Handshake_State& state = dynamic_cast<Server_Handshake_State&>(state_base);
state.confirm_transition_to(type);
@@ -769,7 +787,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
switch(type)
{
case CLIENT_HELLO:
- return this->process_client_hello_msg(active_state, state, contents);
+ return this->process_client_hello_msg(active_state, state, contents, epoch0_restart);
case CERTIFICATE:
return this->process_certificate_msg(state, contents);
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index e6536934a..c601e8c6e 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -122,11 +122,13 @@ class BOTAN_PUBLIC_API(2,0) Server final : public Channel
void process_handshake_msg(const Handshake_State* active_state,
Handshake_State& pending_state,
Handshake_Type type,
- const std::vector<uint8_t>& contents) override;
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart) override;
void process_client_hello_msg(const Handshake_State* active_state,
Server_Handshake_State& pending_state,
- const std::vector<uint8_t>& contents);
+ const std::vector<uint8_t>& contents,
+ bool epoch0_restart);
void process_certificate_msg(Server_Handshake_State& pending_state,
const std::vector<uint8_t>& contents);
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp
index c3355b118..c6114b010 100644
--- a/src/tests/unit_tls.cpp
+++ b/src/tests/unit_tls.cpp
@@ -1032,6 +1032,241 @@ class TLS_Unit_Tests final : public Test
BOTAN_REGISTER_TEST("tls", TLS_Unit_Tests);
+class DTLS_Reconnection_Test : public Test
+ {
+ public:
+ std::vector<Test::Result> run() override
+ {
+ class Test_Callbacks : public Botan::TLS::Callbacks
+ {
+ public:
+ Test_Callbacks(Test::Result& results,
+ std::vector<uint8_t>& outbound,
+ std::vector<uint8_t>& recv_buf) :
+ m_results(results),
+ m_outbound(outbound),
+ m_recv(recv_buf)
+ {}
+
+ void tls_emit_data(const uint8_t bits[], size_t len) override
+ {
+ m_outbound.insert(m_outbound.end(), bits, bits + len);
+ }
+
+ void tls_record_received(uint64_t /*seq*/, const uint8_t bits[], size_t len) override
+ {
+ m_recv.insert(m_recv.end(), bits, bits + len);
+ }
+
+ void tls_alert(Botan::TLS::Alert /*alert*/) override
+ {
+ // ignore
+ }
+
+ bool tls_session_established(const Botan::TLS::Session& /*session*/) override
+ {
+ m_results.test_success("Established a session");
+ return true;
+ }
+
+ private:
+ Test::Result& m_results;
+ std::vector<uint8_t>& m_outbound;
+ std::vector<uint8_t>& m_recv;
+ };
+
+ class Credentials_PSK : public Botan::Credentials_Manager
+ {
+ public:
+ Botan::SymmetricKey psk(const std::string& type,
+ const std::string& context,
+ const std::string&) override
+ {
+ if(type == "tls-server" && context == "session-ticket")
+ {
+ return Botan::SymmetricKey("AABBCCDDEEFF012345678012345678");
+ }
+
+ if(type == "tls-server" && context == "dtls-cookie-secret")
+ {
+ return Botan::SymmetricKey("4AEA5EAD279CADEB537A594DA0E9DE3A");
+ }
+
+ if(context == "localhost" && type == "tls-client")
+ {
+ return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD");
+ }
+
+ if(context == "localhost" && type == "tls-server")
+ {
+ return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD");
+ }
+
+ throw Test_Error("No PSK set for " + type + "/" + context);
+ }
+ };
+
+ class Datagram_PSK_Policy : public Botan::TLS::Policy
+ {
+ public:
+ std::vector<std::string> allowed_macs() const override
+ { return std::vector<std::string>({"AEAD"}); }
+
+ std::vector<std::string> allowed_key_exchange_methods() const override
+ { return {"PSK"}; }
+
+ bool allow_tls10() const override { return false; }
+ bool allow_tls11() const override { return false; }
+ bool allow_tls12() const override { return false; }
+ bool allow_dtls10() const override { return false; }
+ bool allow_dtls12() const override { return true; }
+
+ bool allow_dtls_epoch0_restart() const override { return true; }
+ };
+
+ Test::Result result("DTLS reconnection");
+
+ Datagram_PSK_Policy server_policy;
+ Datagram_PSK_Policy client_policy;
+ Credentials_PSK creds;
+ Botan::TLS::Session_Manager_In_Memory server_sessions(rng());
+ //Botan::TLS::Session_Manager_In_Memory client_sessions(rng());
+ Botan::TLS::Session_Manager_Noop client_sessions;
+
+ std::vector<uint8_t> s2c, server_recv;
+ Test_Callbacks server_callbacks(result, s2c, server_recv);
+ Botan::TLS::Server server(server_callbacks, server_sessions, creds, server_policy, rng(), true);
+
+ std::vector<uint8_t> c1_c2s, client1_recv;
+ Test_Callbacks client1_callbacks(result, c1_c2s, client1_recv);
+ Botan::TLS::Client client1(client1_callbacks, client_sessions, creds, client_policy, rng(),
+ Botan::TLS::Server_Information("localhost"),
+ Botan::TLS::Protocol_Version::latest_dtls_version());
+
+ bool c1_to_server_sent = false;
+ bool server_to_c1_sent = false;
+
+ const std::vector<uint8_t> c1_to_server_magic(16, 0xC1);
+ const std::vector<uint8_t> server_to_c1_magic(16, 0x42);
+
+ size_t c1_rounds = 0;
+ for(;;)
+ {
+ c1_rounds++;
+
+ if(c1_rounds > 64)
+ {
+ result.test_failure("Still spinning in client1 loop after 64 rounds");
+ return {result};
+ }
+
+ if(c1_c2s.size() > 0)
+ {
+ std::vector<uint8_t> input;
+ std::swap(c1_c2s, input);
+ server.received_data(input.data(), input.size());
+ continue;
+ }
+
+ if(s2c.size() > 0)
+ {
+ std::vector<uint8_t> input;
+ std::swap(s2c, input);
+ client1.received_data(input.data(), input.size());
+ continue;
+ }
+
+ if(!c1_to_server_sent && client1.is_active())
+ {
+ client1.send(c1_to_server_magic);
+ c1_to_server_sent = true;
+ }
+
+ if(!server_to_c1_sent && server.is_active())
+ {
+ server.send(server_to_c1_magic);
+ }
+
+ if(server_recv.size() > 0 && client1_recv.size() > 0)
+ {
+ result.test_eq("Expected message from client1", server_recv, c1_to_server_magic);
+ result.test_eq("Expected message to client1", client1_recv, server_to_c1_magic);
+ break;
+ }
+ }
+
+ // Now client1 "goes away" (goes silent) and new client
+ // connects to same server context (ie due to reuse of client source port)
+ // See RFC 6347 section 4.2.8
+
+ server_recv.clear();
+ s2c.clear();
+
+ std::vector<uint8_t> c2_c2s, client2_recv;
+ Test_Callbacks client2_callbacks(result, c2_c2s, client2_recv);
+ Botan::TLS::Client client2(client2_callbacks, client_sessions, creds, client_policy, rng(),
+ Botan::TLS::Server_Information("localhost"),
+ Botan::TLS::Protocol_Version::latest_dtls_version());
+
+ bool c2_to_server_sent = false;
+ bool server_to_c2_sent = false;
+
+ const std::vector<uint8_t> c2_to_server_magic(16, 0xC2);
+ const std::vector<uint8_t> server_to_c2_magic(16, 0x66);
+
+ size_t c2_rounds = 0;
+
+ for(;;)
+ {
+ c2_rounds++;
+
+ if(c2_rounds > 64)
+ {
+ result.test_failure("Still spinning in client2 loop after 64 rounds");
+ return {result};
+ }
+
+ if(c2_c2s.size() > 0)
+ {
+ std::vector<uint8_t> input;
+ std::swap(c2_c2s, input);
+ server.received_data(input.data(), input.size());
+ continue;
+ }
+
+ if(s2c.size() > 0)
+ {
+ std::vector<uint8_t> input;
+ std::swap(s2c, input);
+ client2.received_data(input.data(), input.size());
+ continue;
+ }
+
+ if(!c2_to_server_sent && client2.is_active())
+ {
+ client2.send(c2_to_server_magic);
+ c2_to_server_sent = true;
+ }
+
+ if(!server_to_c2_sent && server.is_active())
+ {
+ server.send(server_to_c2_magic);
+ }
+
+ if(server_recv.size() > 0 && client2_recv.size() > 0)
+ {
+ result.test_eq("Expected message from client2", server_recv, c2_to_server_magic);
+ result.test_eq("Expected message to client2", client2_recv, server_to_c2_magic);
+ break;
+ }
+ }
+
+ return {result};
+ }
+ };
+
+BOTAN_REGISTER_TEST("tls_dtls_reconnect", DTLS_Reconnection_Test);
+
#endif
}