diff options
author | Jack Lloyd <[email protected]> | 2015-10-25 22:25:40 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2015-10-25 22:25:40 -0400 |
commit | b2da74ca508745f00bb3d6b35cbe34d5031e27e7 (patch) | |
tree | 032fafd34f178af3b66877d52897f2e14359adaf /src | |
parent | 2d078053b1ac7c1e2316892d8634c386288ee159 (diff) |
TLS improvements
Use constant time operations when checking CBC padding in TLS decryption
Fix a bug in decoding ClientHellos that prevented DTLS rehandshakes
from working: on decode the session id and hello cookie would be
swapped, causing confusion between client and server.
Various changes in the service of finding the above DTLS bug that
should have been done before now anyway - better control of handshake
timeouts (via TLS::Policy), better reporting of handshake state in the
case of an error, and finally expose the facility for per-message
application callbacks.
Diffstat (limited to 'src')
-rw-r--r-- | src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp | 3 | ||||
-rw-r--r-- | src/lib/rng/rng.h | 8 | ||||
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 38 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 17 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 53 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 97 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 16 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_msg.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.cpp | 110 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.h | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_magic.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.cpp | 13 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 85 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 58 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 14 | ||||
-rw-r--r-- | src/lib/utils/ct_utils.h | 36 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 551 |
20 files changed, 855 insertions, 290 deletions
diff --git a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp index 6b3bce0aa..5ff288db2 100644 --- a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp +++ b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp @@ -28,8 +28,7 @@ secure_vector<byte> EME_PKCS1v15::pad(const byte in[], size_t inlen, out[0] = 0x02; for(size_t j = 1; j != olen - inlen - 1; ++j) - while(out[j] == 0) - out[j] = rng.next_byte(); + out[j] = rng.next_nonzero_byte(); buffer_insert(out, olen - inlen, in, inlen); return out; diff --git a/src/lib/rng/rng.h b/src/lib/rng/rng.h index 6ee67f66f..261880d5d 100644 --- a/src/lib/rng/rng.h +++ b/src/lib/rng/rng.h @@ -75,6 +75,14 @@ class BOTAN_DLL RandomNumberGenerator */ byte next_byte() { return get_random<byte>(); } + byte next_nonzero_byte() + { + byte b = next_byte(); + while(b == 0) + b = next_byte(); + return b; + } + /** * Check whether this RNG is seeded. * @return true if this RNG was already seeded, false otherwise. diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 82ba6f4f6..77bdc5cf5 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -1,6 +1,6 @@ /* * TLS Hello Request and Client Hello Messages -* (C) 2004-2011 Jack Lloyd +* (C) 2004-2011,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -210,11 +210,11 @@ Client_Hello::Client_Hello(const std::vector<byte>& buf) m_random = reader.get_fixed<byte>(32); + m_session_id = reader.get_range<byte>(1, 0, 32); + if(m_version.is_datagram_protocol()) m_hello_cookie = reader.get_range<byte>(1, 0, 255); - m_session_id = reader.get_range<byte>(1, 0, 32); - m_suites = reader.get_range_vector<u16bit>(2, 1, 32767); m_comp_methods = reader.get_range_vector<byte>(1, 1, 255); diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index e2b1aad9d..5dfcec34e 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -23,17 +23,21 @@ Channel::Channel(output_fn output_fn, data_cb data_cb, alert_cb alert_cb, handshake_cb handshake_cb, + handshake_msg_cb handshake_msg_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + const Policy& policy, bool is_datagram, size_t reserved_io_buffer_size) : m_is_datagram(is_datagram), - m_handshake_cb(handshake_cb), m_data_cb(data_cb), m_alert_cb(alert_cb), m_output_fn(output_fn), - m_rng(rng), - m_session_manager(session_manager) + m_handshake_cb(handshake_cb), + m_handshake_msg_cb(handshake_msg_cb), + m_session_manager(session_manager), + m_policy(policy), + m_rng(rng) { /* epoch 0 is plaintext, thus null cipher state */ m_write_cipher_states[0] = nullptr; @@ -66,20 +70,16 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_epoch(u16bit epoch) const { auto i = m_read_cipher_states.find(epoch); - - BOTAN_ASSERT(i != m_read_cipher_states.end(), - "Have a cipher state for the specified epoch"); - + if(i == m_read_cipher_states.end()) + throw Internal_Error("TLS::Channel No read cipherstate for epoch " + std::to_string(epoch)); return i->second; } std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_epoch(u16bit epoch) const { auto i = m_write_cipher_states.find(epoch); - - BOTAN_ASSERT(i != m_write_cipher_states.end(), - "Have a cipher state for the specified epoch"); - + if(i == m_write_cipher_states.end()) + throw Internal_Error("TLS::Channel No write cipherstate for epoch " + std::to_string(epoch)); return i->second; } @@ -120,17 +120,17 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) std::unique_ptr<Handshake_IO> io; if(version.is_datagram_protocol()) { - // default MTU is IPv6 min MTU minus UDP/IP headers (TODO: make configurable) - const u16bit mtu = 1280 - 40 - 8; - io.reset(new Datagram_Handshake_IO( std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3), sequence_numbers(), - mtu)); + m_policy.dtls_default_mtu(), + m_policy.dtls_initial_timeout(), + m_policy.dtls_maximum_timeout())); } else - io.reset(new Stream_Handshake_IO( - std::bind(&Channel::send_record, this, _1, _2))); + { + io.reset(new Stream_Handshake_IO(std::bind(&Channel::send_record, this, _1, _2))); + } m_pending_state.reset(new_handshake_state(io.release())); @@ -333,12 +333,13 @@ size_t Channel::received_data(const byte input[], size_t input_size) if(record.size() > max_fragment_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, - "Plaintext record is too large"); + "TLS input record is larger than allowed maximum"); if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) { + // No pending handshake, possibly new: if(record_version.is_datagram_protocol()) { if(m_sequence_numbers) @@ -374,6 +375,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) } } + // May have been created in above conditional if(m_pending_state) { m_pending_state->handshake_io().add_record(unlock(record), diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 4e6874a16..9ef2d17c4 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -24,6 +24,7 @@ namespace TLS { class Connection_Cipher_State; class Connection_Sequence_Numbers; class Handshake_State; +class Handshake_Message; /** * Generic interface for TLS endpoint @@ -35,15 +36,18 @@ class BOTAN_DLL Channel typedef std::function<void (const byte[], size_t)> data_cb; typedef std::function<void (Alert, const byte[], size_t)> alert_cb; typedef std::function<bool (const Session&)> handshake_cb; + typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; Channel(output_fn out, data_cb app_data_cb, alert_cb alert_cb, handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + const Policy& policy, bool is_datagram, - size_t reserved_io_buffer_size); + size_t io_buf_sz = 16*1024); Channel(const Channel&) = delete; @@ -196,6 +200,8 @@ class BOTAN_DLL Channel Handshake_State& create_handshake_state(Protocol_Version version); + void inspect_handshake_message(const Handshake_Message& msg); + void activate_session(); void change_cipher_spec_reader(Connection_Side side); @@ -214,8 +220,11 @@ class BOTAN_DLL Channel Session_Manager& session_manager() { return m_session_manager; } + const Policy& policy() const { return m_policy; } + bool save_session(const Session& session) const { return m_handshake_cb(session); } + handshake_msg_cb get_handshake_msg_cb() const { return m_handshake_msg_cb; } private: size_t maximum_fragment_size() const; @@ -245,14 +254,16 @@ class BOTAN_DLL Channel bool m_is_datagram; /* callbacks */ - handshake_cb m_handshake_cb; data_cb m_data_cb; alert_cb m_alert_cb; output_fn m_output_fn; + handshake_cb m_handshake_cb; + handshake_msg_cb m_handshake_msg_cb; /* external state */ - RandomNumberGenerator& m_rng; Session_Manager& m_session_manager; + const Policy& m_policy; + RandomNumberGenerator& m_rng; /* sequence number state */ std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 9306092ce..82630b7fa 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -1,6 +1,6 @@ /* * TLS Client -* (C) 2004-2011,2012 Jack Lloyd +* (C) 2004-2011,2012,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -23,8 +23,7 @@ class Client_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : - Handshake_State(io, cb) {} + Client_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} const Public_Key& get_server_public_Key() const { @@ -55,9 +54,31 @@ Client::Client(output_fn output_fn, const Protocol_Version offer_version, const std::vector<std::string>& next_protos, size_t io_buf_sz) : - Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, - offer_version.is_datagram_protocol(), io_buf_sz), - m_policy(policy), + Channel(output_fn, proc_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), + session_manager, rng, policy, offer_version.is_datagram_protocol(), io_buf_sz), + m_creds(creds), + m_info(info) + { + const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); + + Handshake_State& state = create_handshake_state(offer_version); + send_client_hello(state, false, offer_version, srp_identifier, next_protos); + } + +Client::Client(output_fn output_fn, + data_cb proc_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& info, + const Protocol_Version offer_version, + const std::vector<std::string>& next_protos) : + Channel(output_fn, proc_cb, alert_cb, handshake_cb, hs_msg_cb, + session_manager, rng, policy, offer_version.is_datagram_protocol()), m_creds(creds), m_info(info) { @@ -69,7 +90,7 @@ Client::Client(output_fn output_fn, Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - return new Client_Handshake_State(io); // , m_hs_msg_cb); + return new Client_Handshake_State(io, get_handshake_msg_cb()); } std::vector<X509_Certificate> @@ -111,7 +132,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.client_hello(new Client_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_client_hello(), session_info, @@ -128,7 +149,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.handshake_io(), state.hash(), version, - m_policy, + policy(), rng(), secure_renegotiation_data_for_client_hello(), next_protocols, @@ -157,9 +178,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.client_hello()) return; - if(m_policy.allow_server_initiated_renegotiation()) + if(policy().allow_server_initiated_renegotiation()) { - if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false) + if(!secure_renegotiation_supported() && policy().allow_insecure_renegotiation() == false) send_warning_alert(Alert::NO_RENEGOTIATION); else this->initiate_handshake(state, false); @@ -263,7 +284,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.server_hello()->supports_session_ticket()) state.set_expected_next(NEW_SESSION_TICKET); else + { state.set_expected_next(HANDSHAKE_CCS); + } } else { @@ -282,7 +305,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, "Server replied with later version than in hello"); } - if(!m_policy.acceptable_protocol_version(state.version())) + if(!policy().acceptable_protocol_version(state.version())) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Server version " + state.version().to_string() + @@ -406,7 +429,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(state.handshake_io(), state, - m_policy, + policy(), m_creds, state.server_public_key.get(), m_info.hostname(), @@ -426,7 +449,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.client_verify( new Certificate_Verify(state.handshake_io(), state, - m_policy, + policy(), rng(), private_key) ); @@ -477,7 +500,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, const std::vector<byte>& session_ticket = state.session_ticket(); if(session_id.empty() && !session_ticket.empty()) - session_id = make_hello_random(rng(), m_policy); + session_id = make_hello_random(rng(), policy()); Session session_info( session_id, diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index e4e0dc363..b835c013e 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -67,6 +67,20 @@ class BOTAN_DLL Client : public Channel size_t reserved_io_buffer_size = 16*1024 ); + Client(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), + const std::vector<std::string>& next_protocols = {} + ); + const std::string& application_protocol() const { return m_application_protocol; } private: std::vector<X509_Certificate> @@ -88,7 +102,6 @@ class BOTAN_DLL Client : public Channel Handshake_State* new_handshake_state(Handshake_IO* io) override; - const Policy& m_policy; Credentials_Manager& m_creds; const Server_Information m_info; std::string m_application_protocol; diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index 6286eab08..f39c9f84e 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -1,6 +1,6 @@ /* * TLS Handshake IO -* (C) 2012,2014 Jack Lloyd +* (C) 2012,2014,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -33,6 +33,24 @@ void store_be24(byte out[3], size_t val) out[2] = get_byte<u32bit>(3, val); } +u64bit steady_clock_ms() + { + return std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()).count(); + } + +size_t split_for_mtu(size_t mtu, size_t msg_size) + { + const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers + + const size_t parts = (msg_size + mtu) / mtu; + + if(parts + DTLS_HEADERS_SIZE > mtu) + return parts + 1; + + return parts; + } + } Protocol_Version Stream_Handshake_IO::initial_record_version() const @@ -123,41 +141,15 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const return Protocol_Version::DTLS_V10; } -namespace { - -// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 -const u64bit INITIAL_TIMEOUT = 1*1000; -const u64bit MAXIMUM_TIMEOUT = 60*1000; - -u64bit steady_clock_ms() +void Datagram_Handshake_IO::retransmit_last_flight() { - return std::chrono::duration_cast<std::chrono::milliseconds>( - std::chrono::steady_clock::now().time_since_epoch()).count(); + const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2); + retransmit_flight(flight_idx); } -} - -bool Datagram_Handshake_IO::timeout_check() +void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx) { - if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) - { - /* - If we haven't written anything yet obviously no timeout. - Also no timeout possible if we are mid-flight, - */ - return false; - } - - const u64bit ms_since_write = steady_clock_ms() - m_last_write; - - if(ms_since_write < m_next_timeout) - return false; - - std::vector<u16bit> flight; - if(m_flights.size() == 1) - flight = m_flights.at(0); // lost initial client hello - else - flight = m_flights.at(m_flights.size() - 2); + const std::vector<u16bit>& flight = m_flights.at(flight_idx); BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit"); @@ -177,8 +169,27 @@ bool Datagram_Handshake_IO::timeout_check() send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits); epoch = msg.epoch; } + } + +bool Datagram_Handshake_IO::timeout_check() + { + if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) + { + /* + If we haven't written anything yet obviously no timeout. + Also no timeout possible if we are mid-flight, + */ + return false; + } + + const u64bit ms_since_write = steady_clock_ms() - m_last_write; + + if(ms_since_write < m_next_timeout) + return false; - m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT); + retransmit_last_flight(); + + m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout); return true; } @@ -251,7 +262,6 @@ Datagram_Handshake_IO::get_next_record(bool expecting_ccs) if(m_ccs_epochs.count(current_epoch)) return std::make_pair(HANDSHAKE_CCS, std::vector<byte>()); } - return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); } @@ -376,21 +386,6 @@ Datagram_Handshake_IO::format(const std::vector<byte>& msg, return format_w_seq(msg, type, m_in_message_seq - 1); } -namespace { - -size_t split_for_mtu(size_t mtu, size_t msg_size) - { - const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers - - const size_t parts = (msg_size + mtu) / mtu; - - if(parts + DTLS_HEADERS_SIZE > mtu) - return parts + 1; - - return parts; - } - -} std::vector<byte> Datagram_Handshake_IO::send(const Handshake_Message& msg) @@ -411,7 +406,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) m_out_message_seq += 1; m_last_write = steady_clock_ms(); - m_next_timeout = INITIAL_TIMEOUT; + m_next_timeout = m_initial_timeout; return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); } @@ -425,7 +420,9 @@ std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq, format_w_seq(msg_bits, msg_type, msg_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) + { m_send_hs(epoch, HANDSHAKE, no_fragment); + } else { const size_t parts = split_for_mtu(m_mtu, msg_bits.size()); diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 00074a744..a1c1c5ce3 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -100,8 +100,14 @@ class Datagram_Handshake_IO : public Handshake_IO Datagram_Handshake_IO(writer_fn writer, class Connection_Sequence_Numbers& seq, - u16bit mtu) : - m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {} + u16bit mtu, u64bit initial_timeout_ms, u64bit max_timeout_ms) : + m_seqs(seq), + m_flights(1), + m_initial_timeout(initial_timeout_ms), + m_max_timeout(max_timeout_ms), + m_send_hs(writer), + m_mtu(mtu) + {} Protocol_Version initial_record_version() const override; @@ -120,6 +126,9 @@ class Datagram_Handshake_IO : public Handshake_IO std::pair<Handshake_Type, std::vector<byte>> get_next_record(bool expecting_ccs) override; private: + void retransmit_flight(size_t flight); + void retransmit_last_flight(); + std::vector<byte> format_fragment( const byte fragment[], size_t fragment_len, @@ -183,6 +192,9 @@ class Datagram_Handshake_IO : public Handshake_IO std::vector<std::vector<u16bit>> m_flights; std::map<u16bit, Message_Info> m_flight_data; + u64bit m_initial_timeout = 0; + u64bit m_max_timeout = 0; + u64bit m_last_write = 0; u64bit m_next_timeout = 0; diff --git a/src/lib/tls/tls_handshake_msg.h b/src/lib/tls/tls_handshake_msg.h index 6937d4f2c..7e527abf4 100644 --- a/src/lib/tls/tls_handshake_msg.h +++ b/src/lib/tls/tls_handshake_msg.h @@ -22,6 +22,8 @@ namespace TLS { class BOTAN_DLL Handshake_Message { public: + std::string type_string() const; + virtual Handshake_Type type() const = 0; virtual std::vector<byte> serialize() const = 0; diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index cbbca3a0d..f885d3b08 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -1,6 +1,6 @@ /* * TLS Handshaking -* (C) 2004-2006,2011,2012 Jack Lloyd +* (C) 2004-2006,2011,2012,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -13,6 +13,67 @@ namespace Botan { namespace TLS { +std::string Handshake_Message::type_string() const + { + return handshake_type_to_string(type()); + } + +const char* handshake_type_to_string(Handshake_Type type) + { + switch(type) + { + case HELLO_VERIFY_REQUEST: + return "hello_verify_request"; + + case HELLO_REQUEST: + return "hello_request"; + + case CLIENT_HELLO: + return "client_hello"; + + case SERVER_HELLO: + return "server_hello"; + + case CERTIFICATE: + return "certificate"; + + case CERTIFICATE_URL: + return "certificate_url"; + + case CERTIFICATE_STATUS: + return "certificate_status"; + + case SERVER_KEX: + return "server_key_exchange"; + + case CERTIFICATE_REQUEST: + return "certificate_request"; + + case SERVER_HELLO_DONE: + return "server_hello_done"; + + case CERTIFICATE_VERIFY: + return "certificate_verify"; + + case CLIENT_KEX: + return "client_key_exchange"; + + case NEW_SESSION_TICKET: + return "new_session_ticket"; + + case HANDSHAKE_CCS: + return "change_cipher_spec"; + + case FINISHED: + return "finished"; + + case HANDSHAKE_NONE: + return "invalid"; + } + + throw Internal_Error("Unknown TLS handshake message type " + std::to_string(type)); + } + namespace { u32bit bitmask_for_handshake_type(Handshake_Type type) @@ -25,9 +86,6 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) case HELLO_REQUEST: return (1 << 1); - /* - * Same code point for both client hello styles - */ case CLIENT_HELLO: return (1 << 2); @@ -75,12 +133,48 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) throw Internal_Error("Unknown handshake type " + std::to_string(type)); } +std::string handshake_mask_to_string(u32bit mask) + { + const Handshake_Type types[] = { + HELLO_VERIFY_REQUEST, + HELLO_REQUEST, + CLIENT_HELLO, + CERTIFICATE, + CERTIFICATE_URL, + CERTIFICATE_STATUS, + SERVER_KEX, + CERTIFICATE_REQUEST, + SERVER_HELLO_DONE, + CERTIFICATE_VERIFY, + CLIENT_KEX, + NEW_SESSION_TICKET, + HANDSHAKE_CCS, + FINISHED + }; + + std::ostringstream o; + bool empty = true; + + for(auto&& t : types) + { + if(mask & bitmask_for_handshake_type(t)) + { + if(!empty) + o << ","; + o << handshake_type_to_string(t); + empty = false; + } + } + + return o.str(); + } + } /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) : +Handshake_State::Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : m_msg_callback(cb), m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) @@ -196,10 +290,10 @@ void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) const bool ok = (m_hand_expecting_mask & mask); // overlap? if(!ok) - throw Unexpected_Message("Unexpected state transition in handshake, got " + + throw Unexpected_Message("Unexpected state transition in handshake, got type " + std::to_string(handshake_msg) + - " expected " + std::to_string(m_hand_expecting_mask) + - " received " + std::to_string(m_hand_received_mask)); + " expected " + handshake_mask_to_string(m_hand_expecting_mask) + + " received " + handshake_mask_to_string(m_hand_received_mask)); /* We don't know what to expect next, so force a call to set_expected_next; if it doesn't happen, the next transition diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 3b60178b4..6260b090f 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -45,9 +45,9 @@ class Finished; class Handshake_State { public: - typedef std::function<void (const Handshake_Message&)> hs_msg_cb; + typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; - Handshake_State(Handshake_IO* io, hs_msg_cb cb); + Handshake_State(Handshake_IO* io, handshake_msg_cb cb); virtual ~Handshake_State(); @@ -170,7 +170,7 @@ class Handshake_State private: - hs_msg_cb m_msg_callback; + handshake_msg_cb m_msg_callback; std::unique_ptr<Handshake_IO> m_handshake_io; diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h index 882e59158..6db908b08 100644 --- a/src/lib/tls/tls_magic.h +++ b/src/lib/tls/tls_magic.h @@ -57,6 +57,8 @@ enum Handshake_Type { HANDSHAKE_NONE = 255 // Null value }; +const char* handshake_type_to_string(Handshake_Type t); + enum Compression_Method { NO_COMPRESSION = 0x00, DEFLATE_COMPRESSION = 0x01 diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp index f50cf1f3e..d8dd2c828 100644 --- a/src/lib/tls/tls_policy.cpp +++ b/src/lib/tls/tls_policy.cpp @@ -20,9 +20,9 @@ std::vector<std::string> Policy::allowed_ciphers() const return { //"AES-256/OCB(12)", //"AES-128/OCB(12)", - "ChaCha20Poly1305", "AES-256/GCM", "AES-128/GCM", + "ChaCha20Poly1305", "AES-256/CCM", "AES-128/CCM", "AES-256/CCM(8)", @@ -35,7 +35,6 @@ std::vector<std::string> Policy::allowed_ciphers() const //"Camellia-128", //"SEED" //"3DES", - //"RC4", }; } @@ -175,6 +174,16 @@ bool Policy::include_time_in_hello_random() const { return true; } bool Policy::hide_unknown_users() const { return false; } bool Policy::server_uses_own_ciphersuite_preferences() const { return true; } +// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 +size_t Policy::dtls_initial_timeout() const { return 1*1000; } +size_t Policy::dtls_maximum_timeout() const { return 60*1000; } + +size_t Policy::dtls_default_mtu() const + { + // default MTU is IPv6 min MTU minus UDP/IP headers + return 1280 - 40 - 8; + } + std::vector<u16bit> Policy::srtp_profiles() const { return std::vector<u16bit>(); diff --git a/src/lib/tls/tls_policy.h b/src/lib/tls/tls_policy.h index 581d04bcd..c3f8f1ee2 100644 --- a/src/lib/tls/tls_policy.h +++ b/src/lib/tls/tls_policy.h @@ -13,6 +13,7 @@ #include <botan/x509cert.h> #include <botan/dl_group.h> #include <vector> +#include <sstream> namespace Botan { @@ -173,6 +174,12 @@ class BOTAN_DLL Policy virtual std::vector<u16bit> ciphersuite_list(Protocol_Version version, bool have_srp) const; + virtual size_t dtls_default_mtu() const; + + virtual size_t dtls_initial_timeout() const; + + virtual size_t dtls_maximum_timeout() const; + virtual void print(std::ostream& o) const; virtual ~Policy() {} @@ -299,6 +306,14 @@ class BOTAN_DLL Text_Policy : public Policy return r; } + void set(const std::string& k, const std::string& v) { m_kv[k] = v; } + + Text_Policy(const std::string& s) + { + std::istringstream iss(s); + m_kv = read_cfg(iss); + } + Text_Policy(std::istream& in) { m_kv = read_cfg(in); diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 71542de16..e38b26547 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -12,6 +12,7 @@ #include <botan/internal/tls_seq_numbers.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/rounding.h> +#include <botan/internal/ct_utils.h> #include <botan/rng.h> namespace Botan { @@ -284,31 +285,26 @@ size_t fill_buffer_to(secure_vector<byte>& readbuf, * * Returning 0 in the error case should ensure the MAC check will fail. * This approach is suggested in section 6.2.3.2 of RFC 5246. -* -* Also returns 0 if block_size == 0, so can be safely called with a -* stream cipher in use. -* -* @fixme This should run in constant time */ -size_t tls_padding_check(const byte record[], size_t record_len) +u16bit tls_padding_check(const byte record[], size_t record_len) { - const size_t padding_length = record[(record_len-1)]; - - if(padding_length >= record_len) - return 0; - /* * TLS v1.0 and up require all the padding bytes be the same value * and allows up to 255 bytes. */ - const size_t pad_start = record_len - padding_length - 1; - volatile size_t cmp = 0; + const byte pad_byte = record[(record_len-1)]; - for(size_t i = 0; i != padding_length; ++i) - cmp += record[pad_start + i] ^ padding_length; + byte pad_invalid = 0; + for(size_t i = 0; i != record_len; ++i) + { + const size_t left = record_len - i - 2; + const byte delim_mask = CT::is_less<u16bit>(left, pad_byte) & 0xFF; + pad_invalid |= (delim_mask & (record[i] ^ pad_byte)); + } - return cmp ? 0 : padding_length + 1; + u16bit pad_invalid_mask = CT::expand_mask<u16bit>(pad_invalid); + return CT::select<u16bit>(pad_invalid_mask, 0, pad_byte + 1); } void cbc_decrypt_record(byte record_contents[], size_t record_len, @@ -375,38 +371,39 @@ void decrypt_record(secure_vector<byte>& output, else { // GenericBlockCipher case + BlockCipher* bc = cs.block_cipher(); + BOTAN_ASSERT(bc != nullptr, "No cipher state set but needed to decrypt"); - volatile bool padding_bad = false; - size_t pad_size = 0; + const size_t mac_size = cs.mac_size(); + const size_t iv_size = cs.iv_size(); - if(BlockCipher* bc = cs.block_cipher()) - { - cbc_decrypt_record(record_contents, record_len, cs, *bc); + // This early exit does not leak info because all the values are public + if((record_len < mac_size + iv_size) || (record_len % cs.block_size() != 0)) + throw Decoding_Error("Record sent with invalid length"); - pad_size = tls_padding_check(record_contents, record_len); + CT::poison(record_contents, record_len); - padding_bad = (pad_size == 0); - } - else - { - throw Internal_Error("No cipher state set but needed to decrypt"); - } + cbc_decrypt_record(record_contents, record_len, cs, *bc); - const size_t mac_size = cs.mac_size(); - const size_t iv_size = cs.iv_size(); + // 0 if padding was invalid, otherwise 1 + padding_bytes + u16bit pad_size = tls_padding_check(record_contents, record_len); - const size_t mac_pad_iv_size = mac_size + pad_size + iv_size; + // This mask is zero if there is not enough room in the packet + const u16bit size_ok_mask = CT::is_less<u16bit>(mac_size + pad_size + iv_size, record_len); + pad_size &= size_ok_mask; - if(record_len < mac_pad_iv_size) - throw Decoding_Error("Record sent with invalid length"); + CT::unpoison(record_contents, record_len); - const byte* plaintext_block = &record_contents[iv_size]; - const u16bit plaintext_length = record_len - mac_pad_iv_size; + /* + This is unpoisoned sooner than it should. The pad_size leaks to plaintext_length and + then to the timing channel in the MAC computation described in the Lucky 13 paper. + */ + CT::unpoison(pad_size); - cs.mac()->update( - cs.format_ad(record_sequence, record_type, record_version, plaintext_length) - ); + const byte* plaintext_block = &record_contents[iv_size]; + const u16bit plaintext_length = record_len - mac_size - iv_size - pad_size; + cs.mac()->update(cs.format_ad(record_sequence, record_type, record_version, plaintext_length)); cs.mac()->update(plaintext_block, plaintext_length); std::vector<byte> mac_buf(mac_size); @@ -414,12 +411,16 @@ void decrypt_record(secure_vector<byte>& output, const size_t mac_offset = record_len - (mac_size + pad_size); - const bool mac_bad = !same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size); + const bool mac_ok = same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size); - if(mac_bad || padding_bad) - throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); + const u16bit ok_mask = size_ok_mask & CT::expand_mask<u16bit>(mac_ok) & CT::expand_mask<u16bit>(pad_size); - output.assign(plaintext_block, plaintext_block + plaintext_length); + CT::unpoison(ok_mask); + + if(ok_mask) + output.assign(plaintext_block, plaintext_block + plaintext_length); + else + throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); } } diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 330135e63..774827346 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -21,8 +21,7 @@ class Server_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : - Handshake_State(io, cb) {} + Server_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} // Used by the server only, in case of RSA key exchange. Not owned Private_Key* server_rsa_kex_key = nullptr; @@ -215,9 +214,26 @@ Server::Server(output_fn output, next_protocol_fn next_proto, bool is_datagram, size_t io_buf_sz) : - Channel(output, data_cb, alert_cb, handshake_cb, - session_manager, rng, is_datagram, io_buf_sz), - m_policy(policy), + Channel(output, data_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), + session_manager, rng, policy, is_datagram, io_buf_sz), + m_creds(creds), + m_choose_next_protocol(next_proto) + { + } + +Server::Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + next_protocol_fn next_proto, + bool is_datagram) : + Channel(output, data_cb, alert_cb, handshake_cb, hs_msg_cb, + session_manager, rng, policy, is_datagram), m_creds(creds), m_choose_next_protocol(next_proto) { @@ -225,7 +241,9 @@ Server::Server(output_fn output, Handshake_State* Server::new_handshake_state(Handshake_IO* io) { - std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io)); + std::unique_ptr<Handshake_State> state( + new Server_Handshake_State(io, get_handshake_msg_cb())); + state->set_expected_next(CLIENT_HELLO); return state.release(); } @@ -278,7 +296,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { const bool initial_handshake = !active_state; - if(!m_policy.allow_insecure_renegotiation() && + if(!policy().allow_insecure_renegotiation() && !(initial_handshake || secure_renegotiation_supported())) { send_warning_alert(Alert::NO_RENEGOTIATION); @@ -292,7 +310,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, Protocol_Version negotiated_version; const Protocol_Version latest_supported = - m_policy.latest_supported_version(client_version.is_datagram_protocol()); + policy().latest_supported_version(client_version.is_datagram_protocol()); if((initial_handshake && client_version.known_version()) || (!initial_handshake && client_version == active_state->version())) @@ -334,7 +352,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, negotiated_version = latest_supported; } - if(!m_policy.acceptable_protocol_version(negotiated_version)) + if(!policy().acceptable_protocol_version(negotiated_version)) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client version " + negotiated_version.to_string() + @@ -359,7 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, session_manager(), m_creds, state.client_hello(), - std::chrono::seconds(m_policy.session_ticket_lifetime())); + std::chrono::seconds(policy().session_ticket_lifetime())); bool have_session_ticket_key = false; @@ -387,7 +405,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), @@ -423,7 +441,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} @@ -469,14 +487,14 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), - make_hello_random(rng(), m_policy), // new session ID + make_hello_random(rng(), policy()), // new session ID state.version(), - choose_ciphersuite(m_policy, state.version(), m_creds, cert_chains, state.client_hello()), - choose_compression(m_policy, state.client_hello()->compression_methods()), + choose_ciphersuite(policy(), state.version(), m_creds, cert_chains, state.client_hello()), + choose_compression(policy(), state.client_hello()->compression_methods()), have_session_ticket_key, m_next_protocol) ); @@ -516,7 +534,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, else { state.server_kex(new Server_Key_Exchange(state.handshake_io(), - state, m_policy, + state, policy(), m_creds, rng(), private_key)); } @@ -534,7 +552,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { state.cert_req( new Certificate_Req(state.handshake_io(), state.hash(), - m_policy, client_auth_CAs, state.version())); + policy(), client_auth_CAs, state.version())); state.set_expected_next(CERTIFICATE); } @@ -565,7 +583,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(contents, state, state.server_rsa_kex_key, - m_creds, m_policy, rng()) + m_creds, policy(), rng()) ); state.compute_session_keys(); @@ -649,7 +667,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index 4f2a11ba4..ffe1111bc 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -40,6 +40,19 @@ class BOTAN_DLL Server : public Channel size_t reserved_io_buffer_size = 16*1024 ); + Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + next_protocol_fn next_proto = next_protocol_fn(), + bool is_datagram = false + ); + /** * Return the protocol notification set by the client (using the * NPN extension) for this connection, if any. This value is not @@ -62,7 +75,6 @@ class BOTAN_DLL Server : public Channel Handshake_State* new_handshake_state(Handshake_IO* io) override; - const Policy& m_policy; Credentials_Manager& m_creds; next_protocol_fn m_choose_next_protocol; diff --git a/src/lib/utils/ct_utils.h b/src/lib/utils/ct_utils.h index 52a3bc388..2ea07b382 100644 --- a/src/lib/utils/ct_utils.h +++ b/src/lib/utils/ct_utils.h @@ -51,6 +51,12 @@ inline void unpoison(T* p, size_t n) #endif } +template<typename T> +inline void unpoison(T& p) + { + unpoison(&p, 1); + } + /* * T should be an unsigned machine integer type * Expand to a mask used for other operations @@ -90,6 +96,16 @@ inline T is_equal(T x, T y) } template<typename T> +inline T is_less(T x, T y) + { + /* + This expands to a constant time sequence with GCC 5.2.0 on x86-64 + but something more complicated may be needed for portable const time. + */ + return expand_mask<T>(x < y); + } + +template<typename T> inline void conditional_copy_mem(T value, T* to, const T* from0, @@ -102,6 +118,26 @@ inline void conditional_copy_mem(T value, to[i] = CT::select(mask, from0[i], from1[i]); } +template<typename T> +inline T expand_top_bit(T a) + { + return expand_mask<T>(a >> (sizeof(T)*8-1)); + } + +template<typename T> +inline T max(T a, T b) + { + const T a_larger = b - a; // negative if a is larger + return select(expand_top_bit(a), a, b); + } + +template<typename T> +inline T min(T a, T b) + { + const T a_larger = b - a; // negative if a is larger + return select(expand_top_bit(b), b, a); + } + } } diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index 116eb2cdf..d4abef119 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -11,6 +11,7 @@ #include <botan/tls_server.h> #include <botan/tls_client.h> +#include <botan/tls_handshake_msg.h> #include <botan/pkcs10.h> #include <botan/x509self.h> #include <botan/rsa.h> @@ -21,6 +22,7 @@ #include <iostream> #include <vector> #include <memory> +#include <thread> using namespace Botan; @@ -155,158 +157,441 @@ Credentials_Manager* create_creds() return new Credentials_Manager_Test(server_cert, ca_cert, server_key); } -size_t basic_test_handshake(RandomNumberGenerator& rng, - TLS::Protocol_Version offer_version, - Credentials_Manager& creds, - TLS::Policy& policy) +std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q) { - TLS::Session_Manager_In_Memory server_sessions(rng); - TLS::Session_Manager_In_Memory client_sessions(rng); - - std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data; + return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); }; + } - auto handshake_complete = [&](const TLS::Session& session) -> bool +void print_alert(TLS::Alert alert, const byte[], size_t) { - if(session.version() != offer_version) - std::cout << "Offered " << offer_version.to_string() - << " got " << session.version().to_string() << std::endl; - return true; + //std::cout << "Alert " << alert.type_string() << std::endl; }; - auto print_alert = [&](TLS::Alert alert, const byte[], size_t) +void mutate(std::vector<byte>& v, RandomNumberGenerator& rng) { - if(alert.is_valid()) - std::cout << "Server recvd alert " << alert.type_string() << std::endl; - }; + if(v.empty()) + return; - auto save_server_data = [&](const byte buf[], size_t sz) - { - c2s_data.insert(c2s_data.end(), buf, buf+sz); - }; + size_t voff = rng.get_random<size_t>() % v.size(); + v[voff] ^= rng.next_nonzero_byte(); + } - auto save_client_data = [&](const byte buf[], size_t sz) +size_t test_tls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) { - s2c_data.insert(s2c_data.end(), buf, buf+sz); - }; + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); - auto next_protocol_chooser = [&](std::vector<std::string> protos) { - if(protos.size() != 2) - std::cout << "Bad protocol size" << std::endl; - if(protos[0] != "test/1" || protos[1] != "test/2") - std::cout << "Bad protocol values" << std::endl; - return "test/3"; - }; - const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; - - TLS::Server server([&](const byte buf[], size_t sz) - { s2c_q.insert(s2c_q.end(), buf, buf+sz); }, - save_server_data, - print_alert, - handshake_complete, - server_sessions, - creds, - policy, - rng, - next_protocol_chooser, - offer_version.is_datagram_protocol()); - - TLS::Client client([&](const byte buf[], size_t sz) - { c2s_q.insert(c2s_q.end(), buf, buf+sz); }, - save_client_data, - print_alert, - handshake_complete, - client_sessions, - creds, - policy, - rng, - TLS::Server_Information("server.example.com"), - offer_version, - protocols_offered); - - while(true) + for(size_t r = 1; r <= 4; ++r) { - if(client.is_closed() && server.is_closed()) - break; + //std::cout << offer_version.to_string() << " r " << r << "\n"; - if(client.is_active()) - client.send("1"); - if(server.is_active()) - { - if(server.next_protocol() != "test/3") - std::cout << "Wrong protocol " << server.next_protocol() << std::endl; - server.send("2"); - } + bool handshake_done = false; - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<byte> input; - std::swap(c2s_q, input); + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - try - { - server.received_data(input.data(), input.size()); - } - catch(std::exception& e) - { - std::cout << "Server error - " << e.what() << std::endl; - return 1; - } + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; - input.clear(); - std::swap(s2c_q, input); + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; try { - client.received_data(input.data(), input.size()); + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + false); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) + { + ++rounds; + + if(rounds > 25) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send in several records + client.send(client_sent); + } + + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server recv " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client recv " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } } catch(std::exception& e) { - std::cout << "Client error - " << e.what() << std::endl; + std::cout << e.what() << "\n"; return 1; } + } - if(c2s_data.size()) - { - if(c2s_data[0] != '1') - { - std::cout << "Error" << std::endl; - return 1; - } - } + return 0; + } + +size_t test_dtls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) + { + BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version"); + + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); + + for(size_t r = 1; r <= 2; ++r) + { + //std::cout << offer_version.to_string() << " round " << r << "\n"; + + bool handshake_done = false; + + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - if(s2c_data.size()) + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; + + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; + + try { - if(s2c_data[0] != '2') + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + true); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) { - std::cout << "Error" << std::endl; - return 1; - } - } + // TODO: client and server should be in different threads + std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2)); + ++rounds; - if(s2c_data.size() && c2s_data.size()) - { - SymmetricKey client_key = client.key_material_export("label", "context", 32); - SymmetricKey server_key = server.key_material_export("label", "context", 32); + if(rounds > 100) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 and send random chunks: + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send multiple parts + //std::cout << "Sending " << client_sent.size() << " bytes to server\n"; + client.send(client_sent); + } - if(client_key != server_key) - return 1; + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + //std::cout << "Sending " << server_sent.size() << " bytes to client\n"; + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server got " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client got " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } - server.close(); - client.close(); + // If we corrupted a DTLS application message, resend it: + if(client.is_active() && corrupt_client_data && server_recv.empty()) + client.send(client_sent); + if(server.is_active() && corrupt_server_data && client_recv.empty()) + server.send(server_sent); + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } + } + catch(std::exception& e) + { + std::cout << e.what() << "\n"; + return 1; } } return 0; } -class Test_Policy : public TLS::Policy +class Test_Policy : public TLS::Text_Policy { public: + Test_Policy() : Text_Policy("") {} bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; } bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; } + + size_t dtls_initial_timeout() const override { return 1; } + size_t dtls_maximum_timeout() const override { return 8; } }; } @@ -315,17 +600,43 @@ size_t test_tls() { size_t errors = 0; - Test_Policy default_policy; auto& rng = test_rng(); std::unique_ptr<Credentials_Manager> basic_creds(create_creds()); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy); - - test_report("TLS", 5, errors); + Test_Policy policy; + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "RSA"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "DH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + policy.set("key_exchange_methods", "ECDH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "AES-128"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "ChaCha20Poly1305"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + test_report("TLS", 22, errors); return errors; } |