diff options
author | Jack Lloyd <[email protected]> | 2019-07-14 19:30:47 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2019-07-14 19:30:47 -0400 |
commit | 67ce92b89318c25b35dc92f917166e5f3a22bf76 (patch) | |
tree | 8da5832e8dc9e9c420153c54461313ff7b409926 /src | |
parent | e01ef99340af26feccb0ba769e6ac12bf4b8d3cf (diff) | |
parent | 0557ef2299e1528037c53ff70c8e7fcfec816438 (diff) |
Merge GH #2029 Support a DTLS client reconnecting from same source port
Diffstat (limited to 'src')
-rw-r--r-- | src/build-data/version.txt | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 83 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 19 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 13 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.cpp | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.h | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 16 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 8 | ||||
-rw-r--r-- | src/lib/tls/tls_seq_numbers.h | 33 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 34 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 6 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 235 |
15 files changed, 423 insertions, 57 deletions
diff --git a/src/build-data/version.txt b/src/build-data/version.txt index 9170528e3..7998b8b47 100644 --- a/src/build-data/version.txt +++ b/src/build-data/version.txt @@ -2,7 +2,7 @@ release_major = 2 release_minor = 12 release_patch = 0 -release_so_abi_rev = 11 +release_so_abi_rev = 12 # These are set by the distribution script release_vc_rev = None diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 2fefc40b7..366bef9c3 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -27,8 +27,10 @@ Channel::Channel(Callbacks& callbacks, Session_Manager& session_manager, RandomNumberGenerator& rng, const Policy& policy, + bool is_server, bool is_datagram, size_t reserved_io_buffer_size) : + m_is_server(is_server), m_is_datagram(is_datagram), m_callbacks(callbacks), m_session_manager(session_manager), @@ -46,8 +48,10 @@ Channel::Channel(output_fn out, Session_Manager& session_manager, RandomNumberGenerator& rng, const Policy& policy, + bool is_server, bool is_datagram, size_t io_buf_sz) : + m_is_server(is_server), m_is_datagram(is_datagram), m_compat_callbacks(new Compat_Callbacks( /* @@ -83,6 +87,21 @@ void Channel::reset_state() m_read_cipher_states.clear(); } +void Channel::reset_active_association_state() + { + // This operation only makes sense for DTLS + BOTAN_ASSERT_NOMSG(m_is_datagram); + m_active_state.reset(); + m_read_cipher_states.clear(); + m_write_cipher_states.clear(); + + m_write_cipher_states[0] = nullptr; + m_read_cipher_states[0] = nullptr; + + if(m_sequence_numbers) + m_sequence_numbers->reset(); + } + Channel::~Channel() { // So unique_ptr destructors run correctly @@ -271,7 +290,14 @@ bool Channel::is_closed() const * received a connection. This case is detectable by also lacking * m_sequence_numbers */ - return (m_sequence_numbers != nullptr); + if(m_is_server) + { + return (m_sequence_numbers != nullptr); + } + else + { + return true; + } } void Channel::activate_session() @@ -301,12 +327,16 @@ size_t Channel::received_data(const std::vector<uint8_t>& buf) size_t Channel::received_data(const uint8_t input[], size_t input_size) { + const bool allow_epoch0_restart = m_is_datagram && m_is_server && policy().allow_dtls_epoch0_restart(); + try { while(!is_closed() && input_size) { size_t consumed = 0; + auto get_epoch = [this](uint16_t epoch) { return read_cipher_state_epoch(epoch); }; + const Record_Header record = read_record(m_is_datagram, m_readbuf, @@ -315,7 +345,8 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) consumed, m_record_buf, m_sequence_numbers.get(), - [this](uint16_t epoch) { return read_cipher_state_epoch(epoch); }); + get_epoch, + allow_epoch0_restart); const size_t needed = record.needed(); @@ -335,45 +366,56 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) // Ignore invalid records in DTLS if(m_is_datagram && record.type() == NO_RECORD) + { return 0; + } if(m_record_buf.size() > MAX_PLAINTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, "TLS plaintext record is larger than allowed maximum"); + + const bool epoch0_restart = m_is_datagram && record.epoch() == 0 && active_state(); + BOTAN_ASSERT_IMPLICATION(epoch0_restart, allow_epoch0_restart, "Allowed state"); + + const bool initial_record = epoch0_restart || (!pending_state() && !active_state()); + if(record.type() != ALERT) { - if(auto pending = pending_state()) + if(initial_record) { - if(pending->server_hello() != nullptr && record.version() != pending->version()) + // For initial records just check for basic sanity + if(record.version().major_version() != 3 && + record.version().major_version() != 0xFE) { throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Received unexpected record version"); + "Received unexpected record version in initial record"); } } - else if(auto active = active_state()) + else if(auto pending = pending_state()) { - if(record.version() != active->version()) + if(pending->server_hello() != nullptr && record.version() != pending->version()) { - throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Received unexpected record version"); + if(record.version() != pending->version()) + { + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Received unexpected record version"); + } } } - else + else if(auto active = active_state()) { - // For initial records just check for basic sanity - if(record.version().major_version() != 3 && - record.version().major_version() != 0xFE) + if(record.version() != active->version()) { throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Received unexpected record version in initial record"); + "Received unexpected record version"); } } } if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC) { - process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version()); + process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version(), epoch0_restart); } else if(record.type() == APPLICATION_DATA) { @@ -418,12 +460,13 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record, uint64_t record_sequence, Record_Type record_type, - Protocol_Version record_version) + Protocol_Version record_version, + bool epoch0_restart) { if(!m_pending_state) { // No pending handshake, possibly new: - if(record_version.is_datagram_protocol()) + if(record_version.is_datagram_protocol() && !epoch0_restart) { if(m_sequence_numbers) { @@ -475,7 +518,10 @@ void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record, break; process_handshake_msg(active_state(), *pending, - msg.first, msg.second); + msg.first, msg.second, epoch0_restart); + + if(!m_pending_state) + break; } } } @@ -512,7 +558,6 @@ void Channel::process_alert(const secure_vector<uint8_t>& record) } } - void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch, uint8_t record_type, const uint8_t input[], size_t length) { diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 2a2b74332..8f977932b 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -66,6 +66,7 @@ class BOTAN_PUBLIC_API(2,0) Channel Session_Manager& session_manager, RandomNumberGenerator& rng, const Policy& policy, + bool is_server, bool is_datagram, size_t io_buf_sz = IO_BUF_DEFAULT_SIZE); @@ -83,6 +84,7 @@ class BOTAN_PUBLIC_API(2,0) Channel Session_Manager& session_manager, RandomNumberGenerator& rng, const Policy& policy, + bool is_server, bool is_datagram, size_t io_buf_sz = IO_BUF_DEFAULT_SIZE); @@ -160,7 +162,6 @@ class BOTAN_PUBLIC_API(2,0) Channel */ bool is_closed() const; - /** * @return certificate chain of the peer (may be empty) */ @@ -205,7 +206,8 @@ class BOTAN_PUBLIC_API(2,0) Channel virtual void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, Handshake_Type type, - const std::vector<uint8_t>& contents) = 0; + const std::vector<uint8_t>& contents, + bool epoch0_restart) = 0; virtual void initiate_handshake(Handshake_State& state, bool force_full_renegotiation) = 0; @@ -242,6 +244,9 @@ class BOTAN_PUBLIC_API(2,0) Channel bool save_session(const Session& session); Callbacks& callbacks() const { return m_callbacks; } + + void reset_active_association_state(); + private: void init(size_t io_buf_sze); @@ -256,14 +261,14 @@ class BOTAN_PUBLIC_API(2,0) Channel void write_record(Connection_Cipher_State* cipher_state, uint16_t epoch, uint8_t type, const uint8_t input[], size_t length); + void reset_state(); + Connection_Sequence_Numbers& sequence_numbers() const; std::shared_ptr<Connection_Cipher_State> read_cipher_state_epoch(uint16_t epoch) const; std::shared_ptr<Connection_Cipher_State> write_cipher_state_epoch(uint16_t epoch) const; - void reset_state(); - const Handshake_State* active_state() const { return m_active_state.get(); } const Handshake_State* pending_state() const { return m_pending_state.get(); } @@ -272,13 +277,15 @@ class BOTAN_PUBLIC_API(2,0) Channel void process_handshake_ccs(const secure_vector<uint8_t>& record, uint64_t record_sequence, Record_Type record_type, - Protocol_Version record_version); + Protocol_Version record_version, + bool epoch0_restart); void process_application_data(uint64_t req_no, const secure_vector<uint8_t>& record); void process_alert(const secure_vector<uint8_t>& record); - bool m_is_datagram; + const bool m_is_server; + const bool m_is_datagram; /* callbacks */ std::unique_ptr<Compat_Callbacks> m_compat_callbacks; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 10bd34226..440dfb6c2 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -70,8 +70,8 @@ Client::Client(Callbacks& callbacks, const Protocol_Version& offer_version, const std::vector<std::string>& next_protos, size_t io_buf_sz) : - Channel(callbacks, session_manager, rng, policy, offer_version.is_datagram_protocol(), - io_buf_sz), + Channel(callbacks, session_manager, rng, policy, + false, offer_version.is_datagram_protocol(), io_buf_sz), m_creds(creds), m_info(info) { @@ -91,7 +91,7 @@ Client::Client(output_fn data_output_fn, const std::vector<std::string>& next_protos, size_t io_buf_sz) : Channel(data_output_fn, proc_cb, recv_alert_cb, hs_cb, Channel::handshake_msg_cb(), - session_manager, rng, policy, offer_version.is_datagram_protocol(), io_buf_sz), + session_manager, rng, policy, false, offer_version.is_datagram_protocol(), io_buf_sz), m_creds(creds), m_info(info) { @@ -111,7 +111,7 @@ Client::Client(output_fn data_output_fn, const Protocol_Version& offer_version, const std::vector<std::string>& next_protos) : Channel(data_output_fn, proc_cb, recv_alert_cb, hs_cb, hs_msg_cb, - session_manager, rng, policy, offer_version.is_datagram_protocol()), + session_manager, rng, policy, false, offer_version.is_datagram_protocol()), m_creds(creds), m_info(info) { @@ -227,8 +227,11 @@ void Client::send_client_hello(Handshake_State& state_base, void Client::process_handshake_msg(const Handshake_State* active_state, Handshake_State& state_base, Handshake_Type type, - const std::vector<uint8_t>& contents) + const std::vector<uint8_t>& contents, + bool epoch0_restart) { + BOTAN_ASSERT_NOMSG(epoch0_restart == false); // only happens on server side + Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); if(type == HELLO_REQUEST && active_state) diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index 005370e78..0e08b4595 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -152,7 +152,8 @@ class BOTAN_PUBLIC_API(2,0) Client final : public Channel void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, Handshake_Type type, - const std::vector<uint8_t>& contents) override; + const std::vector<uint8_t>& contents, + bool epoch0_restart) override; Handshake_State* new_handshake_state(Handshake_IO* io) override; diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index 62f3bebb8..3f3e672de 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -113,6 +113,11 @@ Stream_Handshake_IO::format(const std::vector<uint8_t>& msg, return send_buf; } +std::vector<uint8_t> Stream_Handshake_IO::send_under_epoch(const Handshake_Message& /*msg*/, uint16_t /*epoch*/) + { + throw Invalid_State("Not possible to send under arbitrary epoch with stream based TLS"); + } + std::vector<uint8_t> Stream_Handshake_IO::send(const Handshake_Message& msg) { const std::vector<uint8_t> msg_bits = msg.serialize(); @@ -261,7 +266,9 @@ Datagram_Handshake_IO::get_next_record(bool expecting_ccs) auto i = m_messages.find(m_in_message_seq); if(i == m_messages.end() || !i->second.complete()) + { return std::make_pair(HANDSHAKE_NONE, std::vector<uint8_t>()); + } m_in_message_seq += 1; @@ -379,11 +386,15 @@ Datagram_Handshake_IO::format(const std::vector<uint8_t>& msg, return format_w_seq(msg, type, m_in_message_seq - 1); } +std::vector<uint8_t> Datagram_Handshake_IO::send(const Handshake_Message& msg) + { + return this->send_under_epoch(msg, m_seqs.current_write_epoch()); + } + std::vector<uint8_t> -Datagram_Handshake_IO::send(const Handshake_Message& msg) +Datagram_Handshake_IO::send_under_epoch(const Handshake_Message& msg, uint16_t epoch) { const std::vector<uint8_t> msg_bits = msg.serialize(); - const uint16_t epoch = m_seqs.current_write_epoch(); const Handshake_Type msg_type = msg.type(); if(msg_type == HANDSHAKE_CCS) diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 66579459d..1c128726d 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -33,6 +33,8 @@ class Handshake_IO virtual std::vector<uint8_t> send(const Handshake_Message& msg) = 0; + virtual std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) = 0; + virtual bool timeout_check() = 0; virtual std::vector<uint8_t> format( @@ -75,6 +77,8 @@ class Stream_Handshake_IO final : public Handshake_IO std::vector<uint8_t> send(const Handshake_Message& msg) override; + std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override; + std::vector<uint8_t> format( const std::vector<uint8_t>& handshake_msg, Handshake_Type handshake_type) const override; @@ -116,6 +120,8 @@ class Datagram_Handshake_IO final : public Handshake_IO std::vector<uint8_t> send(const Handshake_Message& msg) override; + std::vector<uint8_t> send_under_epoch(const Handshake_Message& msg, uint16_t epoch) override; + std::vector<uint8_t> format( const std::vector<uint8_t>& handshake_msg, Handshake_Type handshake_type) const override; diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp index 58ba73ade..0e627fdea 100644 --- a/src/lib/tls/tls_policy.cpp +++ b/src/lib/tls/tls_policy.cpp @@ -336,6 +336,7 @@ bool Policy::only_resume_with_exact_version() const { return true; } bool Policy::require_client_certificate_authentication() const { return false; } bool Policy::request_client_certificate_authentication() const { return require_client_certificate_authentication(); } bool Policy::abort_connection_on_undesired_renegotiation() const { return false; } +bool Policy::allow_dtls_epoch0_restart() const { return false; } size_t Policy::maximum_certificate_chain_size() const { return 0; } diff --git a/src/lib/tls/tls_policy.h b/src/lib/tls/tls_policy.h index 3a5be83d9..b076d5f9d 100644 --- a/src/lib/tls/tls_policy.h +++ b/src/lib/tls/tls_policy.h @@ -292,6 +292,12 @@ class BOTAN_PUBLIC_API(2,0) Policy virtual bool request_client_certificate_authentication() const; /** + * If true, then allow a DTLS client to restart a connection to the + * same server association as described in section 4.2.8 of the DTLS RFC + */ + virtual bool allow_dtls_epoch0_restart() const; + + /** * Return allowed ciphersuites, in order of preference */ virtual std::vector<uint16_t> ciphersuite_list(Protocol_Version version, diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 3304b70eb..71f942bc4 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -399,7 +399,8 @@ Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf, size_t& consumed, secure_vector<uint8_t>& recbuf, Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate) + get_cipherstate_fn get_cipherstate, + bool allow_epoch0_restart) { if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? { @@ -442,12 +443,12 @@ Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf, const Record_Type type = static_cast<Record_Type>(readbuf[0]); - uint16_t epoch = 0; - const uint64_t sequence = load_be<uint64_t>(&readbuf[3], 0); - epoch = (sequence >> 48); + const uint16_t epoch = (sequence >> 48); + + const bool already_seen = sequence_numbers && sequence_numbers->already_seen(sequence); - if(sequence_numbers && sequence_numbers->already_seen(sequence)) + if(already_seen && !(epoch == 0 && allow_epoch0_restart)) { readbuf.clear(); return Record_Header(0); @@ -499,11 +500,12 @@ Record_Header read_record(bool is_datagram, size_t& consumed, secure_vector<uint8_t>& recbuf, Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate) + get_cipherstate_fn get_cipherstate, + bool allow_epoch0_restart) { if(is_datagram) return read_dtls_record(readbuf, input, input_len, consumed, - recbuf, sequence_numbers, get_cipherstate); + recbuf, sequence_numbers, get_cipherstate, allow_epoch0_restart); else return read_tls_record(readbuf, input, input_len, consumed, recbuf, sequence_numbers, get_cipherstate); diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index 3e3475c03..779954439 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -110,6 +110,11 @@ class Record_Header final return m_sequence; } + uint16_t epoch() const + { + return static_cast<uint16_t>(sequence() >> 48); + } + Record_Type type() const { BOTAN_ASSERT_NOMSG(m_needed == 0); @@ -157,7 +162,8 @@ Record_Header read_record(bool is_datagram, size_t& consumed, secure_vector<uint8_t>& record_buf, Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate); + get_cipherstate_fn get_cipherstate, + bool allow_epoch0_restart); } diff --git a/src/lib/tls/tls_seq_numbers.h b/src/lib/tls/tls_seq_numbers.h index 34f949baa..64e2d0589 100644 --- a/src/lib/tls/tls_seq_numbers.h +++ b/src/lib/tls/tls_seq_numbers.h @@ -31,11 +31,23 @@ class Connection_Sequence_Numbers virtual bool already_seen(uint64_t seq) const = 0; virtual void read_accept(uint64_t seq) = 0; + + virtual void reset() = 0; }; class Stream_Sequence_Numbers final : public Connection_Sequence_Numbers { public: + Stream_Sequence_Numbers() { Stream_Sequence_Numbers::reset(); } + + void reset() override + { + m_write_seq_no = 0; + m_read_seq_no = 0; + m_read_epoch = 0; + m_write_epoch = 0; + } + void new_read_cipher_state() override { m_read_seq_no = 0; m_read_epoch++; } void new_write_cipher_state() override { m_write_seq_no = 0; m_write_epoch++; } @@ -47,17 +59,28 @@ class Stream_Sequence_Numbers final : public Connection_Sequence_Numbers bool already_seen(uint64_t) const override { return false; } void read_accept(uint64_t) override { m_read_seq_no++; } + private: - uint64_t m_write_seq_no = 0; - uint64_t m_read_seq_no = 0; - uint16_t m_read_epoch = 0; - uint16_t m_write_epoch = 0; + uint64_t m_write_seq_no; + uint64_t m_read_seq_no; + uint16_t m_read_epoch; + uint16_t m_write_epoch; }; class Datagram_Sequence_Numbers final : public Connection_Sequence_Numbers { public: - Datagram_Sequence_Numbers() { m_write_seqs[0] = 0; } + Datagram_Sequence_Numbers() { Datagram_Sequence_Numbers::reset(); } + + void reset() override + { + m_write_seqs.clear(); + m_write_seqs[0] = 0; + m_write_epoch = 0; + m_read_epoch = 0; + m_window_highest = 0; + m_window_bits = 0; + } void new_read_cipher_state() override { m_read_epoch++; } diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 76941fd11..33d45b852 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -303,7 +303,7 @@ Server::Server(Callbacks& callbacks, bool is_datagram, size_t io_buf_sz) : Channel(callbacks, session_manager, rng, policy, - is_datagram, io_buf_sz), + true, is_datagram, io_buf_sz), m_creds(creds) { } @@ -321,7 +321,7 @@ Server::Server(output_fn output, size_t io_buf_sz) : Channel(output, got_data_cb, recv_alert_cb, hs_cb, Channel::handshake_msg_cb(), session_manager, - rng, policy, is_datagram, io_buf_sz), + rng, policy, true, is_datagram, io_buf_sz), m_creds(creds), m_choose_next_protocol(next_proto) { @@ -339,7 +339,7 @@ Server::Server(output_fn output, next_protocol_fn next_proto, bool is_datagram) : Channel(output, got_data_cb, recv_alert_cb, hs_cb, hs_msg_cb, - session_manager, rng, policy, is_datagram), + session_manager, rng, policy, true, is_datagram), m_creds(creds), m_choose_next_protocol(next_proto) { @@ -471,9 +471,12 @@ Protocol_Version select_version(const Botan::TLS::Policy& policy, */ void Server::process_client_hello_msg(const Handshake_State* active_state, Server_Handshake_State& pending_state, - const std::vector<uint8_t>& contents) + const std::vector<uint8_t>& contents, + bool epoch0_restart) { - const bool initial_handshake = !active_state; + BOTAN_ASSERT_IMPLICATION(epoch0_restart, active_state != nullptr, "Can't restart with a dead connection"); + + const bool initial_handshake = epoch0_restart || !active_state; if(initial_handshake == false && policy().allow_client_initiated_renegotiation() == false) { @@ -547,12 +550,26 @@ void Server::process_client_hello_msg(const Handshake_State* active_state, if(pending_state.client_hello()->cookie() != verify.cookie()) { - pending_state.handshake_io().send(verify); + if(epoch0_restart) + pending_state.handshake_io().send_under_epoch(verify, 0); + else + pending_state.handshake_io().send(verify); + pending_state.client_hello(nullptr); pending_state.set_expected_next(CLIENT_HELLO); return; } } + else if(epoch0_restart) + { + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Reuse of DTLS association requires DTLS cookie secret be set"); + } + } + + if(epoch0_restart) + { + // If we reached here then we were able to verify the cookie + reset_active_association_state(); } secure_renegotiation_check(pending_state.client_hello()); @@ -749,7 +766,8 @@ void Server::process_finished_msg(Server_Handshake_State& pending_state, void Server::process_handshake_msg(const Handshake_State* active_state, Handshake_State& state_base, Handshake_Type type, - const std::vector<uint8_t>& contents) + const std::vector<uint8_t>& contents, + bool epoch0_restart) { Server_Handshake_State& state = dynamic_cast<Server_Handshake_State&>(state_base); state.confirm_transition_to(type); @@ -769,7 +787,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, switch(type) { case CLIENT_HELLO: - return this->process_client_hello_msg(active_state, state, contents); + return this->process_client_hello_msg(active_state, state, contents, epoch0_restart); case CERTIFICATE: return this->process_certificate_msg(state, contents); diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index e6536934a..c601e8c6e 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -122,11 +122,13 @@ class BOTAN_PUBLIC_API(2,0) Server final : public Channel void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, Handshake_Type type, - const std::vector<uint8_t>& contents) override; + const std::vector<uint8_t>& contents, + bool epoch0_restart) override; void process_client_hello_msg(const Handshake_State* active_state, Server_Handshake_State& pending_state, - const std::vector<uint8_t>& contents); + const std::vector<uint8_t>& contents, + bool epoch0_restart); void process_certificate_msg(Server_Handshake_State& pending_state, const std::vector<uint8_t>& contents); diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index c3355b118..c6114b010 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -1032,6 +1032,241 @@ class TLS_Unit_Tests final : public Test BOTAN_REGISTER_TEST("tls", TLS_Unit_Tests); +class DTLS_Reconnection_Test : public Test + { + public: + std::vector<Test::Result> run() override + { + class Test_Callbacks : public Botan::TLS::Callbacks + { + public: + Test_Callbacks(Test::Result& results, + std::vector<uint8_t>& outbound, + std::vector<uint8_t>& recv_buf) : + m_results(results), + m_outbound(outbound), + m_recv(recv_buf) + {} + + void tls_emit_data(const uint8_t bits[], size_t len) override + { + m_outbound.insert(m_outbound.end(), bits, bits + len); + } + + void tls_record_received(uint64_t /*seq*/, const uint8_t bits[], size_t len) override + { + m_recv.insert(m_recv.end(), bits, bits + len); + } + + void tls_alert(Botan::TLS::Alert /*alert*/) override + { + // ignore + } + + bool tls_session_established(const Botan::TLS::Session& /*session*/) override + { + m_results.test_success("Established a session"); + return true; + } + + private: + Test::Result& m_results; + std::vector<uint8_t>& m_outbound; + std::vector<uint8_t>& m_recv; + }; + + class Credentials_PSK : public Botan::Credentials_Manager + { + public: + Botan::SymmetricKey psk(const std::string& type, + const std::string& context, + const std::string&) override + { + if(type == "tls-server" && context == "session-ticket") + { + return Botan::SymmetricKey("AABBCCDDEEFF012345678012345678"); + } + + if(type == "tls-server" && context == "dtls-cookie-secret") + { + return Botan::SymmetricKey("4AEA5EAD279CADEB537A594DA0E9DE3A"); + } + + if(context == "localhost" && type == "tls-client") + { + return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD"); + } + + if(context == "localhost" && type == "tls-server") + { + return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD"); + } + + throw Test_Error("No PSK set for " + type + "/" + context); + } + }; + + class Datagram_PSK_Policy : public Botan::TLS::Policy + { + public: + std::vector<std::string> allowed_macs() const override + { return std::vector<std::string>({"AEAD"}); } + + std::vector<std::string> allowed_key_exchange_methods() const override + { return {"PSK"}; } + + bool allow_tls10() const override { return false; } + bool allow_tls11() const override { return false; } + bool allow_tls12() const override { return false; } + bool allow_dtls10() const override { return false; } + bool allow_dtls12() const override { return true; } + + bool allow_dtls_epoch0_restart() const override { return true; } + }; + + Test::Result result("DTLS reconnection"); + + Datagram_PSK_Policy server_policy; + Datagram_PSK_Policy client_policy; + Credentials_PSK creds; + Botan::TLS::Session_Manager_In_Memory server_sessions(rng()); + //Botan::TLS::Session_Manager_In_Memory client_sessions(rng()); + Botan::TLS::Session_Manager_Noop client_sessions; + + std::vector<uint8_t> s2c, server_recv; + Test_Callbacks server_callbacks(result, s2c, server_recv); + Botan::TLS::Server server(server_callbacks, server_sessions, creds, server_policy, rng(), true); + + std::vector<uint8_t> c1_c2s, client1_recv; + Test_Callbacks client1_callbacks(result, c1_c2s, client1_recv); + Botan::TLS::Client client1(client1_callbacks, client_sessions, creds, client_policy, rng(), + Botan::TLS::Server_Information("localhost"), + Botan::TLS::Protocol_Version::latest_dtls_version()); + + bool c1_to_server_sent = false; + bool server_to_c1_sent = false; + + const std::vector<uint8_t> c1_to_server_magic(16, 0xC1); + const std::vector<uint8_t> server_to_c1_magic(16, 0x42); + + size_t c1_rounds = 0; + for(;;) + { + c1_rounds++; + + if(c1_rounds > 64) + { + result.test_failure("Still spinning in client1 loop after 64 rounds"); + return {result}; + } + + if(c1_c2s.size() > 0) + { + std::vector<uint8_t> input; + std::swap(c1_c2s, input); + server.received_data(input.data(), input.size()); + continue; + } + + if(s2c.size() > 0) + { + std::vector<uint8_t> input; + std::swap(s2c, input); + client1.received_data(input.data(), input.size()); + continue; + } + + if(!c1_to_server_sent && client1.is_active()) + { + client1.send(c1_to_server_magic); + c1_to_server_sent = true; + } + + if(!server_to_c1_sent && server.is_active()) + { + server.send(server_to_c1_magic); + } + + if(server_recv.size() > 0 && client1_recv.size() > 0) + { + result.test_eq("Expected message from client1", server_recv, c1_to_server_magic); + result.test_eq("Expected message to client1", client1_recv, server_to_c1_magic); + break; + } + } + + // Now client1 "goes away" (goes silent) and new client + // connects to same server context (ie due to reuse of client source port) + // See RFC 6347 section 4.2.8 + + server_recv.clear(); + s2c.clear(); + + std::vector<uint8_t> c2_c2s, client2_recv; + Test_Callbacks client2_callbacks(result, c2_c2s, client2_recv); + Botan::TLS::Client client2(client2_callbacks, client_sessions, creds, client_policy, rng(), + Botan::TLS::Server_Information("localhost"), + Botan::TLS::Protocol_Version::latest_dtls_version()); + + bool c2_to_server_sent = false; + bool server_to_c2_sent = false; + + const std::vector<uint8_t> c2_to_server_magic(16, 0xC2); + const std::vector<uint8_t> server_to_c2_magic(16, 0x66); + + size_t c2_rounds = 0; + + for(;;) + { + c2_rounds++; + + if(c2_rounds > 64) + { + result.test_failure("Still spinning in client2 loop after 64 rounds"); + return {result}; + } + + if(c2_c2s.size() > 0) + { + std::vector<uint8_t> input; + std::swap(c2_c2s, input); + server.received_data(input.data(), input.size()); + continue; + } + + if(s2c.size() > 0) + { + std::vector<uint8_t> input; + std::swap(s2c, input); + client2.received_data(input.data(), input.size()); + continue; + } + + if(!c2_to_server_sent && client2.is_active()) + { + client2.send(c2_to_server_magic); + c2_to_server_sent = true; + } + + if(!server_to_c2_sent && server.is_active()) + { + server.send(server_to_c2_magic); + } + + if(server_recv.size() > 0 && client2_recv.size() > 0) + { + result.test_eq("Expected message from client2", server_recv, c2_to_server_magic); + result.test_eq("Expected message to client2", client2_recv, server_to_c2_magic); + break; + } + } + + return {result}; + } + }; + +BOTAN_REGISTER_TEST("tls_dtls_reconnect", DTLS_Reconnection_Test); + #endif } |