diff options
-rw-r--r-- | doc/manual/tls.rst | 6 | ||||
-rw-r--r-- | doc/news.rst | 2 | ||||
-rw-r--r-- | src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp | 3 | ||||
-rw-r--r-- | src/lib/rng/rng.h | 8 | ||||
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 38 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 17 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 53 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 97 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.h | 16 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_msg.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.cpp | 110 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.h | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_magic.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.cpp | 13 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.h | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 85 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 58 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 14 | ||||
-rw-r--r-- | src/lib/utils/ct_utils.h | 36 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 551 |
22 files changed, 861 insertions, 292 deletions
diff --git a/doc/manual/tls.rst b/doc/manual/tls.rst index 554846c25..331bb56bb 100644 --- a/doc/manual/tls.rst +++ b/doc/manual/tls.rst @@ -517,7 +517,9 @@ policy settings from a file. authentication, sending data in cleartext) are also not supported by the implementation and cannot be negotiated. - Default value: "ChaCha20Poly1305", "AES-256/GCM", "AES-128/GCM", + Values without an explicit mode use old-style CBC with HMAC encryption. + + Default value: "AES-256/GCM", "AES-128/GCM", "ChaCha20Poly1305", "AES-256/CCM", "AES-128/CCM", "AES-256/CCM-8", "AES-128/CCM-8", "AES-256", "AES-128" @@ -570,7 +572,7 @@ policy settings from a file. Default: "ECDSA", "RSA", "DSA" - Also allowed: "" (meaning anonymous) + Also allowed (disabled by default): "" (meaning anonymous) .. cpp:function:: std::vector<std::string> allowed_ecc_curves() const diff --git a/doc/news.rst b/doc/news.rst index 72ab4ad9f..61df06d83 100644 --- a/doc/news.rst +++ b/doc/news.rst @@ -35,6 +35,8 @@ Version 1.11.22, Not Yet Released deriving the next value by squaring the previous ones. The reinitializion interval can be controlled by the build.h parameter BOTAN_BLINDING_REINIT_INTERVAL. +* A bug decoding DTLS client hellos prevented session resumption for suceeding. + * DL_Group now prohibits creating a group smaller than 1024 bits. * Add System_RNG type. Previously the global system RNG was only accessible via diff --git a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp index 6b3bce0aa..5ff288db2 100644 --- a/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp +++ b/src/lib/pk_pad/eme_pkcs1/eme_pkcs.cpp @@ -28,8 +28,7 @@ secure_vector<byte> EME_PKCS1v15::pad(const byte in[], size_t inlen, out[0] = 0x02; for(size_t j = 1; j != olen - inlen - 1; ++j) - while(out[j] == 0) - out[j] = rng.next_byte(); + out[j] = rng.next_nonzero_byte(); buffer_insert(out, olen - inlen, in, inlen); return out; diff --git a/src/lib/rng/rng.h b/src/lib/rng/rng.h index 6ee67f66f..261880d5d 100644 --- a/src/lib/rng/rng.h +++ b/src/lib/rng/rng.h @@ -75,6 +75,14 @@ class BOTAN_DLL RandomNumberGenerator */ byte next_byte() { return get_random<byte>(); } + byte next_nonzero_byte() + { + byte b = next_byte(); + while(b == 0) + b = next_byte(); + return b; + } + /** * Check whether this RNG is seeded. * @return true if this RNG was already seeded, false otherwise. diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 82ba6f4f6..77bdc5cf5 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -1,6 +1,6 @@ /* * TLS Hello Request and Client Hello Messages -* (C) 2004-2011 Jack Lloyd +* (C) 2004-2011,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -210,11 +210,11 @@ Client_Hello::Client_Hello(const std::vector<byte>& buf) m_random = reader.get_fixed<byte>(32); + m_session_id = reader.get_range<byte>(1, 0, 32); + if(m_version.is_datagram_protocol()) m_hello_cookie = reader.get_range<byte>(1, 0, 255); - m_session_id = reader.get_range<byte>(1, 0, 32); - m_suites = reader.get_range_vector<u16bit>(2, 1, 32767); m_comp_methods = reader.get_range_vector<byte>(1, 1, 255); diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index e2b1aad9d..5dfcec34e 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -23,17 +23,21 @@ Channel::Channel(output_fn output_fn, data_cb data_cb, alert_cb alert_cb, handshake_cb handshake_cb, + handshake_msg_cb handshake_msg_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + const Policy& policy, bool is_datagram, size_t reserved_io_buffer_size) : m_is_datagram(is_datagram), - m_handshake_cb(handshake_cb), m_data_cb(data_cb), m_alert_cb(alert_cb), m_output_fn(output_fn), - m_rng(rng), - m_session_manager(session_manager) + m_handshake_cb(handshake_cb), + m_handshake_msg_cb(handshake_msg_cb), + m_session_manager(session_manager), + m_policy(policy), + m_rng(rng) { /* epoch 0 is plaintext, thus null cipher state */ m_write_cipher_states[0] = nullptr; @@ -66,20 +70,16 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_epoch(u16bit epoch) const { auto i = m_read_cipher_states.find(epoch); - - BOTAN_ASSERT(i != m_read_cipher_states.end(), - "Have a cipher state for the specified epoch"); - + if(i == m_read_cipher_states.end()) + throw Internal_Error("TLS::Channel No read cipherstate for epoch " + std::to_string(epoch)); return i->second; } std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_epoch(u16bit epoch) const { auto i = m_write_cipher_states.find(epoch); - - BOTAN_ASSERT(i != m_write_cipher_states.end(), - "Have a cipher state for the specified epoch"); - + if(i == m_write_cipher_states.end()) + throw Internal_Error("TLS::Channel No write cipherstate for epoch " + std::to_string(epoch)); return i->second; } @@ -120,17 +120,17 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version) std::unique_ptr<Handshake_IO> io; if(version.is_datagram_protocol()) { - // default MTU is IPv6 min MTU minus UDP/IP headers (TODO: make configurable) - const u16bit mtu = 1280 - 40 - 8; - io.reset(new Datagram_Handshake_IO( std::bind(&Channel::send_record_under_epoch, this, _1, _2, _3), sequence_numbers(), - mtu)); + m_policy.dtls_default_mtu(), + m_policy.dtls_initial_timeout(), + m_policy.dtls_maximum_timeout())); } else - io.reset(new Stream_Handshake_IO( - std::bind(&Channel::send_record, this, _1, _2))); + { + io.reset(new Stream_Handshake_IO(std::bind(&Channel::send_record, this, _1, _2))); + } m_pending_state.reset(new_handshake_state(io.release())); @@ -333,12 +333,13 @@ size_t Channel::received_data(const byte input[], size_t input_size) if(record.size() > max_fragment_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, - "Plaintext record is too large"); + "TLS input record is larger than allowed maximum"); if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) { + // No pending handshake, possibly new: if(record_version.is_datagram_protocol()) { if(m_sequence_numbers) @@ -374,6 +375,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) } } + // May have been created in above conditional if(m_pending_state) { m_pending_state->handshake_io().add_record(unlock(record), diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 4e6874a16..9ef2d17c4 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -24,6 +24,7 @@ namespace TLS { class Connection_Cipher_State; class Connection_Sequence_Numbers; class Handshake_State; +class Handshake_Message; /** * Generic interface for TLS endpoint @@ -35,15 +36,18 @@ class BOTAN_DLL Channel typedef std::function<void (const byte[], size_t)> data_cb; typedef std::function<void (Alert, const byte[], size_t)> alert_cb; typedef std::function<bool (const Session&)> handshake_cb; + typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; Channel(output_fn out, data_cb app_data_cb, alert_cb alert_cb, handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + const Policy& policy, bool is_datagram, - size_t reserved_io_buffer_size); + size_t io_buf_sz = 16*1024); Channel(const Channel&) = delete; @@ -196,6 +200,8 @@ class BOTAN_DLL Channel Handshake_State& create_handshake_state(Protocol_Version version); + void inspect_handshake_message(const Handshake_Message& msg); + void activate_session(); void change_cipher_spec_reader(Connection_Side side); @@ -214,8 +220,11 @@ class BOTAN_DLL Channel Session_Manager& session_manager() { return m_session_manager; } + const Policy& policy() const { return m_policy; } + bool save_session(const Session& session) const { return m_handshake_cb(session); } + handshake_msg_cb get_handshake_msg_cb() const { return m_handshake_msg_cb; } private: size_t maximum_fragment_size() const; @@ -245,14 +254,16 @@ class BOTAN_DLL Channel bool m_is_datagram; /* callbacks */ - handshake_cb m_handshake_cb; data_cb m_data_cb; alert_cb m_alert_cb; output_fn m_output_fn; + handshake_cb m_handshake_cb; + handshake_msg_cb m_handshake_msg_cb; /* external state */ - RandomNumberGenerator& m_rng; Session_Manager& m_session_manager; + const Policy& m_policy; + RandomNumberGenerator& m_rng; /* sequence number state */ std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 9306092ce..82630b7fa 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -1,6 +1,6 @@ /* * TLS Client -* (C) 2004-2011,2012 Jack Lloyd +* (C) 2004-2011,2012,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -23,8 +23,7 @@ class Client_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Client_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : - Handshake_State(io, cb) {} + Client_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} const Public_Key& get_server_public_Key() const { @@ -55,9 +54,31 @@ Client::Client(output_fn output_fn, const Protocol_Version offer_version, const std::vector<std::string>& next_protos, size_t io_buf_sz) : - Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, - offer_version.is_datagram_protocol(), io_buf_sz), - m_policy(policy), + Channel(output_fn, proc_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), + session_manager, rng, policy, offer_version.is_datagram_protocol(), io_buf_sz), + m_creds(creds), + m_info(info) + { + const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); + + Handshake_State& state = create_handshake_state(offer_version); + send_client_hello(state, false, offer_version, srp_identifier, next_protos); + } + +Client::Client(output_fn output_fn, + data_cb proc_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& info, + const Protocol_Version offer_version, + const std::vector<std::string>& next_protos) : + Channel(output_fn, proc_cb, alert_cb, handshake_cb, hs_msg_cb, + session_manager, rng, policy, offer_version.is_datagram_protocol()), m_creds(creds), m_info(info) { @@ -69,7 +90,7 @@ Client::Client(output_fn output_fn, Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - return new Client_Handshake_State(io); // , m_hs_msg_cb); + return new Client_Handshake_State(io, get_handshake_msg_cb()); } std::vector<X509_Certificate> @@ -111,7 +132,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.client_hello(new Client_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_client_hello(), session_info, @@ -128,7 +149,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.handshake_io(), state.hash(), version, - m_policy, + policy(), rng(), secure_renegotiation_data_for_client_hello(), next_protocols, @@ -157,9 +178,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.client_hello()) return; - if(m_policy.allow_server_initiated_renegotiation()) + if(policy().allow_server_initiated_renegotiation()) { - if(!secure_renegotiation_supported() && m_policy.allow_insecure_renegotiation() == false) + if(!secure_renegotiation_supported() && policy().allow_insecure_renegotiation() == false) send_warning_alert(Alert::NO_RENEGOTIATION); else this->initiate_handshake(state, false); @@ -263,7 +284,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(state.server_hello()->supports_session_ticket()) state.set_expected_next(NEW_SESSION_TICKET); else + { state.set_expected_next(HANDSHAKE_CCS); + } } else { @@ -282,7 +305,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, "Server replied with later version than in hello"); } - if(!m_policy.acceptable_protocol_version(state.version())) + if(!policy().acceptable_protocol_version(state.version())) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Server version " + state.version().to_string() + @@ -406,7 +429,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(state.handshake_io(), state, - m_policy, + policy(), m_creds, state.server_public_key.get(), m_info.hostname(), @@ -426,7 +449,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.client_verify( new Certificate_Verify(state.handshake_io(), state, - m_policy, + policy(), rng(), private_key) ); @@ -477,7 +500,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, const std::vector<byte>& session_ticket = state.session_ticket(); if(session_id.empty() && !session_ticket.empty()) - session_id = make_hello_random(rng(), m_policy); + session_id = make_hello_random(rng(), policy()); Session session_info( session_id, diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index e4e0dc363..b835c013e 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -67,6 +67,20 @@ class BOTAN_DLL Client : public Channel size_t reserved_io_buffer_size = 16*1024 ); + Client(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), + const std::vector<std::string>& next_protocols = {} + ); + const std::string& application_protocol() const { return m_application_protocol; } private: std::vector<X509_Certificate> @@ -88,7 +102,6 @@ class BOTAN_DLL Client : public Channel Handshake_State* new_handshake_state(Handshake_IO* io) override; - const Policy& m_policy; Credentials_Manager& m_creds; const Server_Information m_info; std::string m_application_protocol; diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index 6286eab08..f39c9f84e 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -1,6 +1,6 @@ /* * TLS Handshake IO -* (C) 2012,2014 Jack Lloyd +* (C) 2012,2014,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -33,6 +33,24 @@ void store_be24(byte out[3], size_t val) out[2] = get_byte<u32bit>(3, val); } +u64bit steady_clock_ms() + { + return std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()).count(); + } + +size_t split_for_mtu(size_t mtu, size_t msg_size) + { + const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers + + const size_t parts = (msg_size + mtu) / mtu; + + if(parts + DTLS_HEADERS_SIZE > mtu) + return parts + 1; + + return parts; + } + } Protocol_Version Stream_Handshake_IO::initial_record_version() const @@ -123,41 +141,15 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const return Protocol_Version::DTLS_V10; } -namespace { - -// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 -const u64bit INITIAL_TIMEOUT = 1*1000; -const u64bit MAXIMUM_TIMEOUT = 60*1000; - -u64bit steady_clock_ms() +void Datagram_Handshake_IO::retransmit_last_flight() { - return std::chrono::duration_cast<std::chrono::milliseconds>( - std::chrono::steady_clock::now().time_since_epoch()).count(); + const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2); + retransmit_flight(flight_idx); } -} - -bool Datagram_Handshake_IO::timeout_check() +void Datagram_Handshake_IO::retransmit_flight(size_t flight_idx) { - if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) - { - /* - If we haven't written anything yet obviously no timeout. - Also no timeout possible if we are mid-flight, - */ - return false; - } - - const u64bit ms_since_write = steady_clock_ms() - m_last_write; - - if(ms_since_write < m_next_timeout) - return false; - - std::vector<u16bit> flight; - if(m_flights.size() == 1) - flight = m_flights.at(0); // lost initial client hello - else - flight = m_flights.at(m_flights.size() - 2); + const std::vector<u16bit>& flight = m_flights.at(flight_idx); BOTAN_ASSERT(flight.size() > 0, "Nonempty flight to retransmit"); @@ -177,8 +169,27 @@ bool Datagram_Handshake_IO::timeout_check() send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits); epoch = msg.epoch; } + } + +bool Datagram_Handshake_IO::timeout_check() + { + if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty())) + { + /* + If we haven't written anything yet obviously no timeout. + Also no timeout possible if we are mid-flight, + */ + return false; + } + + const u64bit ms_since_write = steady_clock_ms() - m_last_write; + + if(ms_since_write < m_next_timeout) + return false; - m_next_timeout = std::min(2 * m_next_timeout, MAXIMUM_TIMEOUT); + retransmit_last_flight(); + + m_next_timeout = std::min(2 * m_next_timeout, m_max_timeout); return true; } @@ -251,7 +262,6 @@ Datagram_Handshake_IO::get_next_record(bool expecting_ccs) if(m_ccs_epochs.count(current_epoch)) return std::make_pair(HANDSHAKE_CCS, std::vector<byte>()); } - return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); } @@ -376,21 +386,6 @@ Datagram_Handshake_IO::format(const std::vector<byte>& msg, return format_w_seq(msg, type, m_in_message_seq - 1); } -namespace { - -size_t split_for_mtu(size_t mtu, size_t msg_size) - { - const size_t DTLS_HEADERS_SIZE = 25; // DTLS record+handshake headers - - const size_t parts = (msg_size + mtu) / mtu; - - if(parts + DTLS_HEADERS_SIZE > mtu) - return parts + 1; - - return parts; - } - -} std::vector<byte> Datagram_Handshake_IO::send(const Handshake_Message& msg) @@ -411,7 +406,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) m_out_message_seq += 1; m_last_write = steady_clock_ms(); - m_next_timeout = INITIAL_TIMEOUT; + m_next_timeout = m_initial_timeout; return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits); } @@ -425,7 +420,9 @@ std::vector<byte> Datagram_Handshake_IO::send_message(u16bit msg_seq, format_w_seq(msg_bits, msg_type, msg_seq); if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) + { m_send_hs(epoch, HANDSHAKE, no_fragment); + } else { const size_t parts = split_for_mtu(m_mtu, msg_bits.size()); diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h index 00074a744..a1c1c5ce3 100644 --- a/src/lib/tls/tls_handshake_io.h +++ b/src/lib/tls/tls_handshake_io.h @@ -100,8 +100,14 @@ class Datagram_Handshake_IO : public Handshake_IO Datagram_Handshake_IO(writer_fn writer, class Connection_Sequence_Numbers& seq, - u16bit mtu) : - m_seqs(seq), m_flights(1), m_send_hs(writer), m_mtu(mtu) {} + u16bit mtu, u64bit initial_timeout_ms, u64bit max_timeout_ms) : + m_seqs(seq), + m_flights(1), + m_initial_timeout(initial_timeout_ms), + m_max_timeout(max_timeout_ms), + m_send_hs(writer), + m_mtu(mtu) + {} Protocol_Version initial_record_version() const override; @@ -120,6 +126,9 @@ class Datagram_Handshake_IO : public Handshake_IO std::pair<Handshake_Type, std::vector<byte>> get_next_record(bool expecting_ccs) override; private: + void retransmit_flight(size_t flight); + void retransmit_last_flight(); + std::vector<byte> format_fragment( const byte fragment[], size_t fragment_len, @@ -183,6 +192,9 @@ class Datagram_Handshake_IO : public Handshake_IO std::vector<std::vector<u16bit>> m_flights; std::map<u16bit, Message_Info> m_flight_data; + u64bit m_initial_timeout = 0; + u64bit m_max_timeout = 0; + u64bit m_last_write = 0; u64bit m_next_timeout = 0; diff --git a/src/lib/tls/tls_handshake_msg.h b/src/lib/tls/tls_handshake_msg.h index 6937d4f2c..7e527abf4 100644 --- a/src/lib/tls/tls_handshake_msg.h +++ b/src/lib/tls/tls_handshake_msg.h @@ -22,6 +22,8 @@ namespace TLS { class BOTAN_DLL Handshake_Message { public: + std::string type_string() const; + virtual Handshake_Type type() const = 0; virtual std::vector<byte> serialize() const = 0; diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index cbbca3a0d..f885d3b08 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -1,6 +1,6 @@ /* * TLS Handshaking -* (C) 2004-2006,2011,2012 Jack Lloyd +* (C) 2004-2006,2011,2012,2015 Jack Lloyd * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -13,6 +13,67 @@ namespace Botan { namespace TLS { +std::string Handshake_Message::type_string() const + { + return handshake_type_to_string(type()); + } + +const char* handshake_type_to_string(Handshake_Type type) + { + switch(type) + { + case HELLO_VERIFY_REQUEST: + return "hello_verify_request"; + + case HELLO_REQUEST: + return "hello_request"; + + case CLIENT_HELLO: + return "client_hello"; + + case SERVER_HELLO: + return "server_hello"; + + case CERTIFICATE: + return "certificate"; + + case CERTIFICATE_URL: + return "certificate_url"; + + case CERTIFICATE_STATUS: + return "certificate_status"; + + case SERVER_KEX: + return "server_key_exchange"; + + case CERTIFICATE_REQUEST: + return "certificate_request"; + + case SERVER_HELLO_DONE: + return "server_hello_done"; + + case CERTIFICATE_VERIFY: + return "certificate_verify"; + + case CLIENT_KEX: + return "client_key_exchange"; + + case NEW_SESSION_TICKET: + return "new_session_ticket"; + + case HANDSHAKE_CCS: + return "change_cipher_spec"; + + case FINISHED: + return "finished"; + + case HANDSHAKE_NONE: + return "invalid"; + } + + throw Internal_Error("Unknown TLS handshake message type " + std::to_string(type)); + } + namespace { u32bit bitmask_for_handshake_type(Handshake_Type type) @@ -25,9 +86,6 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) case HELLO_REQUEST: return (1 << 1); - /* - * Same code point for both client hello styles - */ case CLIENT_HELLO: return (1 << 2); @@ -75,12 +133,48 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) throw Internal_Error("Unknown handshake type " + std::to_string(type)); } +std::string handshake_mask_to_string(u32bit mask) + { + const Handshake_Type types[] = { + HELLO_VERIFY_REQUEST, + HELLO_REQUEST, + CLIENT_HELLO, + CERTIFICATE, + CERTIFICATE_URL, + CERTIFICATE_STATUS, + SERVER_KEX, + CERTIFICATE_REQUEST, + SERVER_HELLO_DONE, + CERTIFICATE_VERIFY, + CLIENT_KEX, + NEW_SESSION_TICKET, + HANDSHAKE_CCS, + FINISHED + }; + + std::ostringstream o; + bool empty = true; + + for(auto&& t : types) + { + if(mask & bitmask_for_handshake_type(t)) + { + if(!empty) + o << ","; + o << handshake_type_to_string(t); + empty = false; + } + } + + return o.str(); + } + } /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io, hs_msg_cb cb) : +Handshake_State::Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : m_msg_callback(cb), m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) @@ -196,10 +290,10 @@ void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) const bool ok = (m_hand_expecting_mask & mask); // overlap? if(!ok) - throw Unexpected_Message("Unexpected state transition in handshake, got " + + throw Unexpected_Message("Unexpected state transition in handshake, got type " + std::to_string(handshake_msg) + - " expected " + std::to_string(m_hand_expecting_mask) + - " received " + std::to_string(m_hand_received_mask)); + " expected " + handshake_mask_to_string(m_hand_expecting_mask) + + " received " + handshake_mask_to_string(m_hand_received_mask)); /* We don't know what to expect next, so force a call to set_expected_next; if it doesn't happen, the next transition diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 3b60178b4..6260b090f 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -45,9 +45,9 @@ class Finished; class Handshake_State { public: - typedef std::function<void (const Handshake_Message&)> hs_msg_cb; + typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; - Handshake_State(Handshake_IO* io, hs_msg_cb cb); + Handshake_State(Handshake_IO* io, handshake_msg_cb cb); virtual ~Handshake_State(); @@ -170,7 +170,7 @@ class Handshake_State private: - hs_msg_cb m_msg_callback; + handshake_msg_cb m_msg_callback; std::unique_ptr<Handshake_IO> m_handshake_io; diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h index 882e59158..6db908b08 100644 --- a/src/lib/tls/tls_magic.h +++ b/src/lib/tls/tls_magic.h @@ -57,6 +57,8 @@ enum Handshake_Type { HANDSHAKE_NONE = 255 // Null value }; +const char* handshake_type_to_string(Handshake_Type t); + enum Compression_Method { NO_COMPRESSION = 0x00, DEFLATE_COMPRESSION = 0x01 diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp index f50cf1f3e..d8dd2c828 100644 --- a/src/lib/tls/tls_policy.cpp +++ b/src/lib/tls/tls_policy.cpp @@ -20,9 +20,9 @@ std::vector<std::string> Policy::allowed_ciphers() const return { //"AES-256/OCB(12)", //"AES-128/OCB(12)", - "ChaCha20Poly1305", "AES-256/GCM", "AES-128/GCM", + "ChaCha20Poly1305", "AES-256/CCM", "AES-128/CCM", "AES-256/CCM(8)", @@ -35,7 +35,6 @@ std::vector<std::string> Policy::allowed_ciphers() const //"Camellia-128", //"SEED" //"3DES", - //"RC4", }; } @@ -175,6 +174,16 @@ bool Policy::include_time_in_hello_random() const { return true; } bool Policy::hide_unknown_users() const { return false; } bool Policy::server_uses_own_ciphersuite_preferences() const { return true; } +// 1 second initial timeout, 60 second max - see RFC 6347 sec 4.2.4.1 +size_t Policy::dtls_initial_timeout() const { return 1*1000; } +size_t Policy::dtls_maximum_timeout() const { return 60*1000; } + +size_t Policy::dtls_default_mtu() const + { + // default MTU is IPv6 min MTU minus UDP/IP headers + return 1280 - 40 - 8; + } + std::vector<u16bit> Policy::srtp_profiles() const { return std::vector<u16bit>(); diff --git a/src/lib/tls/tls_policy.h b/src/lib/tls/tls_policy.h index 581d04bcd..c3f8f1ee2 100644 --- a/src/lib/tls/tls_policy.h +++ b/src/lib/tls/tls_policy.h @@ -13,6 +13,7 @@ #include <botan/x509cert.h> #include <botan/dl_group.h> #include <vector> +#include <sstream> namespace Botan { @@ -173,6 +174,12 @@ class BOTAN_DLL Policy virtual std::vector<u16bit> ciphersuite_list(Protocol_Version version, bool have_srp) const; + virtual size_t dtls_default_mtu() const; + + virtual size_t dtls_initial_timeout() const; + + virtual size_t dtls_maximum_timeout() const; + virtual void print(std::ostream& o) const; virtual ~Policy() {} @@ -299,6 +306,14 @@ class BOTAN_DLL Text_Policy : public Policy return r; } + void set(const std::string& k, const std::string& v) { m_kv[k] = v; } + + Text_Policy(const std::string& s) + { + std::istringstream iss(s); + m_kv = read_cfg(iss); + } + Text_Policy(std::istream& in) { m_kv = read_cfg(in); diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 71542de16..e38b26547 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -12,6 +12,7 @@ #include <botan/internal/tls_seq_numbers.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/rounding.h> +#include <botan/internal/ct_utils.h> #include <botan/rng.h> namespace Botan { @@ -284,31 +285,26 @@ size_t fill_buffer_to(secure_vector<byte>& readbuf, * * Returning 0 in the error case should ensure the MAC check will fail. * This approach is suggested in section 6.2.3.2 of RFC 5246. -* -* Also returns 0 if block_size == 0, so can be safely called with a -* stream cipher in use. -* -* @fixme This should run in constant time */ -size_t tls_padding_check(const byte record[], size_t record_len) +u16bit tls_padding_check(const byte record[], size_t record_len) { - const size_t padding_length = record[(record_len-1)]; - - if(padding_length >= record_len) - return 0; - /* * TLS v1.0 and up require all the padding bytes be the same value * and allows up to 255 bytes. */ - const size_t pad_start = record_len - padding_length - 1; - volatile size_t cmp = 0; + const byte pad_byte = record[(record_len-1)]; - for(size_t i = 0; i != padding_length; ++i) - cmp += record[pad_start + i] ^ padding_length; + byte pad_invalid = 0; + for(size_t i = 0; i != record_len; ++i) + { + const size_t left = record_len - i - 2; + const byte delim_mask = CT::is_less<u16bit>(left, pad_byte) & 0xFF; + pad_invalid |= (delim_mask & (record[i] ^ pad_byte)); + } - return cmp ? 0 : padding_length + 1; + u16bit pad_invalid_mask = CT::expand_mask<u16bit>(pad_invalid); + return CT::select<u16bit>(pad_invalid_mask, 0, pad_byte + 1); } void cbc_decrypt_record(byte record_contents[], size_t record_len, @@ -375,38 +371,39 @@ void decrypt_record(secure_vector<byte>& output, else { // GenericBlockCipher case + BlockCipher* bc = cs.block_cipher(); + BOTAN_ASSERT(bc != nullptr, "No cipher state set but needed to decrypt"); - volatile bool padding_bad = false; - size_t pad_size = 0; + const size_t mac_size = cs.mac_size(); + const size_t iv_size = cs.iv_size(); - if(BlockCipher* bc = cs.block_cipher()) - { - cbc_decrypt_record(record_contents, record_len, cs, *bc); + // This early exit does not leak info because all the values are public + if((record_len < mac_size + iv_size) || (record_len % cs.block_size() != 0)) + throw Decoding_Error("Record sent with invalid length"); - pad_size = tls_padding_check(record_contents, record_len); + CT::poison(record_contents, record_len); - padding_bad = (pad_size == 0); - } - else - { - throw Internal_Error("No cipher state set but needed to decrypt"); - } + cbc_decrypt_record(record_contents, record_len, cs, *bc); - const size_t mac_size = cs.mac_size(); - const size_t iv_size = cs.iv_size(); + // 0 if padding was invalid, otherwise 1 + padding_bytes + u16bit pad_size = tls_padding_check(record_contents, record_len); - const size_t mac_pad_iv_size = mac_size + pad_size + iv_size; + // This mask is zero if there is not enough room in the packet + const u16bit size_ok_mask = CT::is_less<u16bit>(mac_size + pad_size + iv_size, record_len); + pad_size &= size_ok_mask; - if(record_len < mac_pad_iv_size) - throw Decoding_Error("Record sent with invalid length"); + CT::unpoison(record_contents, record_len); - const byte* plaintext_block = &record_contents[iv_size]; - const u16bit plaintext_length = record_len - mac_pad_iv_size; + /* + This is unpoisoned sooner than it should. The pad_size leaks to plaintext_length and + then to the timing channel in the MAC computation described in the Lucky 13 paper. + */ + CT::unpoison(pad_size); - cs.mac()->update( - cs.format_ad(record_sequence, record_type, record_version, plaintext_length) - ); + const byte* plaintext_block = &record_contents[iv_size]; + const u16bit plaintext_length = record_len - mac_size - iv_size - pad_size; + cs.mac()->update(cs.format_ad(record_sequence, record_type, record_version, plaintext_length)); cs.mac()->update(plaintext_block, plaintext_length); std::vector<byte> mac_buf(mac_size); @@ -414,12 +411,16 @@ void decrypt_record(secure_vector<byte>& output, const size_t mac_offset = record_len - (mac_size + pad_size); - const bool mac_bad = !same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size); + const bool mac_ok = same_mem(&record_contents[mac_offset], mac_buf.data(), mac_size); - if(mac_bad || padding_bad) - throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); + const u16bit ok_mask = size_ok_mask & CT::expand_mask<u16bit>(mac_ok) & CT::expand_mask<u16bit>(pad_size); - output.assign(plaintext_block, plaintext_block + plaintext_length); + CT::unpoison(ok_mask); + + if(ok_mask) + output.assign(plaintext_block, plaintext_block + plaintext_length); + else + throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); } } diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 330135e63..774827346 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -21,8 +21,7 @@ class Server_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : - Handshake_State(io, cb) {} + Server_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} // Used by the server only, in case of RSA key exchange. Not owned Private_Key* server_rsa_kex_key = nullptr; @@ -215,9 +214,26 @@ Server::Server(output_fn output, next_protocol_fn next_proto, bool is_datagram, size_t io_buf_sz) : - Channel(output, data_cb, alert_cb, handshake_cb, - session_manager, rng, is_datagram, io_buf_sz), - m_policy(policy), + Channel(output, data_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), + session_manager, rng, policy, is_datagram, io_buf_sz), + m_creds(creds), + m_choose_next_protocol(next_proto) + { + } + +Server::Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + next_protocol_fn next_proto, + bool is_datagram) : + Channel(output, data_cb, alert_cb, handshake_cb, hs_msg_cb, + session_manager, rng, policy, is_datagram), m_creds(creds), m_choose_next_protocol(next_proto) { @@ -225,7 +241,9 @@ Server::Server(output_fn output, Handshake_State* Server::new_handshake_state(Handshake_IO* io) { - std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io)); + std::unique_ptr<Handshake_State> state( + new Server_Handshake_State(io, get_handshake_msg_cb())); + state->set_expected_next(CLIENT_HELLO); return state.release(); } @@ -278,7 +296,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { const bool initial_handshake = !active_state; - if(!m_policy.allow_insecure_renegotiation() && + if(!policy().allow_insecure_renegotiation() && !(initial_handshake || secure_renegotiation_supported())) { send_warning_alert(Alert::NO_RENEGOTIATION); @@ -292,7 +310,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, Protocol_Version negotiated_version; const Protocol_Version latest_supported = - m_policy.latest_supported_version(client_version.is_datagram_protocol()); + policy().latest_supported_version(client_version.is_datagram_protocol()); if((initial_handshake && client_version.known_version()) || (!initial_handshake && client_version == active_state->version())) @@ -334,7 +352,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, negotiated_version = latest_supported; } - if(!m_policy.acceptable_protocol_version(negotiated_version)) + if(!policy().acceptable_protocol_version(negotiated_version)) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client version " + negotiated_version.to_string() + @@ -359,7 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, session_manager(), m_creds, state.client_hello(), - std::chrono::seconds(m_policy.session_ticket_lifetime())); + std::chrono::seconds(policy().session_ticket_lifetime())); bool have_session_ticket_key = false; @@ -387,7 +405,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), @@ -423,7 +441,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} @@ -469,14 +487,14 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), - make_hello_random(rng(), m_policy), // new session ID + make_hello_random(rng(), policy()), // new session ID state.version(), - choose_ciphersuite(m_policy, state.version(), m_creds, cert_chains, state.client_hello()), - choose_compression(m_policy, state.client_hello()->compression_methods()), + choose_ciphersuite(policy(), state.version(), m_creds, cert_chains, state.client_hello()), + choose_compression(policy(), state.client_hello()->compression_methods()), have_session_ticket_key, m_next_protocol) ); @@ -516,7 +534,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, else { state.server_kex(new Server_Key_Exchange(state.handshake_io(), - state, m_policy, + state, policy(), m_creds, rng(), private_key)); } @@ -534,7 +552,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { state.cert_req( new Certificate_Req(state.handshake_io(), state.hash(), - m_policy, client_auth_CAs, state.version())); + policy(), client_auth_CAs, state.version())); state.set_expected_next(CERTIFICATE); } @@ -565,7 +583,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(contents, state, state.server_rsa_kex_key, - m_creds, m_policy, rng()) + m_creds, policy(), rng()) ); state.compute_session_keys(); @@ -649,7 +667,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index 4f2a11ba4..ffe1111bc 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -40,6 +40,19 @@ class BOTAN_DLL Server : public Channel size_t reserved_io_buffer_size = 16*1024 ); + Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + next_protocol_fn next_proto = next_protocol_fn(), + bool is_datagram = false + ); + /** * Return the protocol notification set by the client (using the * NPN extension) for this connection, if any. This value is not @@ -62,7 +75,6 @@ class BOTAN_DLL Server : public Channel Handshake_State* new_handshake_state(Handshake_IO* io) override; - const Policy& m_policy; Credentials_Manager& m_creds; next_protocol_fn m_choose_next_protocol; diff --git a/src/lib/utils/ct_utils.h b/src/lib/utils/ct_utils.h index 52a3bc388..2ea07b382 100644 --- a/src/lib/utils/ct_utils.h +++ b/src/lib/utils/ct_utils.h @@ -51,6 +51,12 @@ inline void unpoison(T* p, size_t n) #endif } +template<typename T> +inline void unpoison(T& p) + { + unpoison(&p, 1); + } + /* * T should be an unsigned machine integer type * Expand to a mask used for other operations @@ -90,6 +96,16 @@ inline T is_equal(T x, T y) } template<typename T> +inline T is_less(T x, T y) + { + /* + This expands to a constant time sequence with GCC 5.2.0 on x86-64 + but something more complicated may be needed for portable const time. + */ + return expand_mask<T>(x < y); + } + +template<typename T> inline void conditional_copy_mem(T value, T* to, const T* from0, @@ -102,6 +118,26 @@ inline void conditional_copy_mem(T value, to[i] = CT::select(mask, from0[i], from1[i]); } +template<typename T> +inline T expand_top_bit(T a) + { + return expand_mask<T>(a >> (sizeof(T)*8-1)); + } + +template<typename T> +inline T max(T a, T b) + { + const T a_larger = b - a; // negative if a is larger + return select(expand_top_bit(a), a, b); + } + +template<typename T> +inline T min(T a, T b) + { + const T a_larger = b - a; // negative if a is larger + return select(expand_top_bit(b), b, a); + } + } } diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index 116eb2cdf..d4abef119 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -11,6 +11,7 @@ #include <botan/tls_server.h> #include <botan/tls_client.h> +#include <botan/tls_handshake_msg.h> #include <botan/pkcs10.h> #include <botan/x509self.h> #include <botan/rsa.h> @@ -21,6 +22,7 @@ #include <iostream> #include <vector> #include <memory> +#include <thread> using namespace Botan; @@ -155,158 +157,441 @@ Credentials_Manager* create_creds() return new Credentials_Manager_Test(server_cert, ca_cert, server_key); } -size_t basic_test_handshake(RandomNumberGenerator& rng, - TLS::Protocol_Version offer_version, - Credentials_Manager& creds, - TLS::Policy& policy) +std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q) { - TLS::Session_Manager_In_Memory server_sessions(rng); - TLS::Session_Manager_In_Memory client_sessions(rng); - - std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data; + return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); }; + } - auto handshake_complete = [&](const TLS::Session& session) -> bool +void print_alert(TLS::Alert alert, const byte[], size_t) { - if(session.version() != offer_version) - std::cout << "Offered " << offer_version.to_string() - << " got " << session.version().to_string() << std::endl; - return true; + //std::cout << "Alert " << alert.type_string() << std::endl; }; - auto print_alert = [&](TLS::Alert alert, const byte[], size_t) +void mutate(std::vector<byte>& v, RandomNumberGenerator& rng) { - if(alert.is_valid()) - std::cout << "Server recvd alert " << alert.type_string() << std::endl; - }; + if(v.empty()) + return; - auto save_server_data = [&](const byte buf[], size_t sz) - { - c2s_data.insert(c2s_data.end(), buf, buf+sz); - }; + size_t voff = rng.get_random<size_t>() % v.size(); + v[voff] ^= rng.next_nonzero_byte(); + } - auto save_client_data = [&](const byte buf[], size_t sz) +size_t test_tls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) { - s2c_data.insert(s2c_data.end(), buf, buf+sz); - }; + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); - auto next_protocol_chooser = [&](std::vector<std::string> protos) { - if(protos.size() != 2) - std::cout << "Bad protocol size" << std::endl; - if(protos[0] != "test/1" || protos[1] != "test/2") - std::cout << "Bad protocol values" << std::endl; - return "test/3"; - }; - const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; - - TLS::Server server([&](const byte buf[], size_t sz) - { s2c_q.insert(s2c_q.end(), buf, buf+sz); }, - save_server_data, - print_alert, - handshake_complete, - server_sessions, - creds, - policy, - rng, - next_protocol_chooser, - offer_version.is_datagram_protocol()); - - TLS::Client client([&](const byte buf[], size_t sz) - { c2s_q.insert(c2s_q.end(), buf, buf+sz); }, - save_client_data, - print_alert, - handshake_complete, - client_sessions, - creds, - policy, - rng, - TLS::Server_Information("server.example.com"), - offer_version, - protocols_offered); - - while(true) + for(size_t r = 1; r <= 4; ++r) { - if(client.is_closed() && server.is_closed()) - break; + //std::cout << offer_version.to_string() << " r " << r << "\n"; - if(client.is_active()) - client.send("1"); - if(server.is_active()) - { - if(server.next_protocol() != "test/3") - std::cout << "Wrong protocol " << server.next_protocol() << std::endl; - server.send("2"); - } + bool handshake_done = false; - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<byte> input; - std::swap(c2s_q, input); + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - try - { - server.received_data(input.data(), input.size()); - } - catch(std::exception& e) - { - std::cout << "Server error - " << e.what() << std::endl; - return 1; - } + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; - input.clear(); - std::swap(s2c_q, input); + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; try { - client.received_data(input.data(), input.size()); + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + false); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) + { + ++rounds; + + if(rounds > 25) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send in several records + client.send(client_sent); + } + + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server recv " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client recv " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } } catch(std::exception& e) { - std::cout << "Client error - " << e.what() << std::endl; + std::cout << e.what() << "\n"; return 1; } + } - if(c2s_data.size()) - { - if(c2s_data[0] != '1') - { - std::cout << "Error" << std::endl; - return 1; - } - } + return 0; + } + +size_t test_dtls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) + { + BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version"); + + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); + + for(size_t r = 1; r <= 2; ++r) + { + //std::cout << offer_version.to_string() << " round " << r << "\n"; + + bool handshake_done = false; + + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - if(s2c_data.size()) + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; + + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; + + try { - if(s2c_data[0] != '2') + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + true); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) { - std::cout << "Error" << std::endl; - return 1; - } - } + // TODO: client and server should be in different threads + std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2)); + ++rounds; - if(s2c_data.size() && c2s_data.size()) - { - SymmetricKey client_key = client.key_material_export("label", "context", 32); - SymmetricKey server_key = server.key_material_export("label", "context", 32); + if(rounds > 100) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 and send random chunks: + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send multiple parts + //std::cout << "Sending " << client_sent.size() << " bytes to server\n"; + client.send(client_sent); + } - if(client_key != server_key) - return 1; + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + //std::cout << "Sending " << server_sent.size() << " bytes to client\n"; + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server got " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client got " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } - server.close(); - client.close(); + // If we corrupted a DTLS application message, resend it: + if(client.is_active() && corrupt_client_data && server_recv.empty()) + client.send(client_sent); + if(server.is_active() && corrupt_server_data && client_recv.empty()) + server.send(server_sent); + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } + } + catch(std::exception& e) + { + std::cout << e.what() << "\n"; + return 1; } } return 0; } -class Test_Policy : public TLS::Policy +class Test_Policy : public TLS::Text_Policy { public: + Test_Policy() : Text_Policy("") {} bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; } bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; } + + size_t dtls_initial_timeout() const override { return 1; } + size_t dtls_maximum_timeout() const override { return 8; } }; } @@ -315,17 +600,43 @@ size_t test_tls() { size_t errors = 0; - Test_Policy default_policy; auto& rng = test_rng(); std::unique_ptr<Credentials_Manager> basic_creds(create_creds()); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy); - - test_report("TLS", 5, errors); + Test_Policy policy; + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "RSA"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "DH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + policy.set("key_exchange_methods", "ECDH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "AES-128"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "ChaCha20Poly1305"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + test_report("TLS", 22, errors); return errors; } |