diff options
author | Jack Lloyd <[email protected]> | 2018-01-27 13:38:04 -0500 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2018-01-27 13:38:04 -0500 |
commit | e5cf7992ff53c3fbe4beb106d3fd80b8845957b7 (patch) | |
tree | a1732c98087a76ccc2bfc40f4b0ce846902632a5 | |
parent | cfe57137e5957b84b6b749db8d9f02c3ee1f8c1e (diff) | |
parent | 7f7feb41880d87ea170633b47f5dede30ea528de (diff) |
Merge GH #1394 Add ability to use custom extensions, control which extensions are used
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 7 | ||||
-rw-r--r-- | src/lib/tls/msg_server_hello.cpp | 7 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.cpp | 8 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.h | 34 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.cpp | 30 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.h | 48 | ||||
-rw-r--r-- | src/lib/tls/tls_messages.h | 11 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 10 | ||||
-rw-r--r-- | src/tests/data/tls/client_hello.vec | 6 | ||||
-rw-r--r-- | src/tests/data/tls/server_hello.vec | 6 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 528 |
13 files changed, 427 insertions, 276 deletions
diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index eeeaf8c71..68753fa26 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -10,6 +10,7 @@ #include <botan/tls_messages.h> #include <botan/tls_alert.h> #include <botan/tls_exceptn.h> +#include <botan/tls_callbacks.h> #include <botan/rng.h> #include <botan/hash.h> @@ -81,6 +82,7 @@ std::vector<uint8_t> Hello_Request::serialize() const Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello::Settings& client_settings, @@ -140,6 +142,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(), policy.allowed_signature_methods())); + cb.tls_modify_extensions(m_extensions, CLIENT); + if(policy.send_fallback_scsv(client_settings.protocol_version())) m_suites.push_back(TLS_FALLBACK_SCSV); @@ -152,6 +156,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Session& session, @@ -201,6 +206,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, if(reneg_info.empty() && !next_protocols.empty()) m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols)); + cb.tls_modify_extensions(m_extensions, CLIENT); + hash.update(io.send(*this)); } diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index 5e290eb68..2d5a185f0 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -9,6 +9,7 @@ #include <botan/tls_messages.h> #include <botan/tls_extensions.h> +#include <botan/tls_callbacks.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_io.h> @@ -23,6 +24,7 @@ namespace TLS { Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello& client_hello, @@ -83,6 +85,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, } } + cb.tls_modify_extensions(m_extensions, SERVER); + hash.update(io.send(*this)); } @@ -90,6 +94,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello& client_hello, @@ -130,6 +135,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, if(!next_protocol.empty() && client_hello.supports_alpn()) m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol)); + cb.tls_modify_extensions(m_extensions, SERVER); + hash.update(io.send(*this)); } diff --git a/src/lib/tls/tls_callbacks.cpp b/src/lib/tls/tls_callbacks.cpp index b8f38589e..7a64291c8 100644 --- a/src/lib/tls/tls_callbacks.cpp +++ b/src/lib/tls/tls_callbacks.cpp @@ -32,6 +32,14 @@ std::string TLS::Callbacks::tls_server_choose_app_protocol(const std::vector<std return ""; } +void TLS::Callbacks::tls_modify_extensions(Extensions&, Connection_Side) + { + } + +void TLS::Callbacks::tls_examine_extensions(const Extensions&, Connection_Side) + { + } + void TLS::Callbacks::tls_verify_cert_chain( const std::vector<X509_Certificate>& cert_chain, const std::vector<std::shared_ptr<const OCSP::Response>>& ocsp_responses, diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h index 4437a222a..dd6ad2d4b 100644 --- a/src/lib/tls/tls_callbacks.h +++ b/src/lib/tls/tls_callbacks.h @@ -30,6 +30,7 @@ namespace TLS { class Handshake_Message; class Policy; +class Extensions; /** * Encapsulates the callbacks that a TLS channel will make which are due to @@ -250,6 +251,39 @@ class BOTAN_PUBLIC_API(2,0) Callbacks virtual std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos); /** + * Optional callback: examine/modify Extensions before sending. + * + * Both client and server will call this callback on the Extensions object + * before serializing it in the client/server hellos. This allows an + * application to modify which extensions are sent during the + * handshake. + * + * Default implementation does nothing. + * + * @param extn the extensions + * @param which_side will be CLIENT or SERVER which is the current + * applications role in the exchange. + */ + virtual void tls_modify_extensions(Extensions& extn, Connection_Side which_side); + + /** + * Optional callback: examine peer extensions. + * + * Both client and server will call this callback with the Extensions + * object after receiving it from the peer. This allows examining the + * Extensions, for example to implement a custom extension. It also allows + * an application to require that a particular extension be implemented; + * throw an exception from this function to abort the handshake. + * + * Default implementation does nothing. + * + * @param extn the extensions + * @param which_side will be CLIENT if these are are the clients extensions (ie we are + * the server) or SERVER if these are the server extensions (we are the client). + */ + virtual void tls_examine_extensions(const Extensions& extn, Connection_Side which_side); + + /** * Optional callback: error logging. (not currently called) * @param err An error message related to this connection. */ diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index c88b6a7db..5f84481ac 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -169,6 +169,7 @@ void Client::send_client_hello(Handshake_State& state_base, new Client_Hello(state.handshake_io(), state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_client_hello(), session_info, @@ -188,6 +189,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.handshake_io(), state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_client_hello(), client_settings, @@ -294,6 +296,8 @@ void Client::process_handshake_msg(const Handshake_State* active_state, "Server replied with DTLS-SRTP alg we did not send"); } + callbacks().tls_examine_extensions(state.server_hello()->extensions(), SERVER); + state.set_version(state.server_hello()->version()); m_application_protocol = state.server_hello()->next_protocol(); diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp index d521f6bf8..522cf4a4f 100644 --- a/src/lib/tls/tls_extensions.cpp +++ b/src/lib/tls/tls_extensions.cpp @@ -59,7 +59,8 @@ Extension* make_extension(TLS_Data_Reader& reader, uint16_t code, uint16_t size) return new Session_Ticket(reader, size); } - return nullptr; // not known + return new Unknown_Extension(static_cast<Handshake_Extension_Type>(code), + reader, size); } } @@ -82,10 +83,7 @@ void Extensions::deserialize(TLS_Data_Reader& reader) extension_code, extension_size); - if(extn) - this->add(extn); - else // unknown/unhandled extension - reader.discard_next(extension_size); + this->add(extn); } } } @@ -124,6 +122,15 @@ std::vector<uint8_t> Extensions::serialize() const return buf; } +bool Extensions::remove_extension(Handshake_Extension_Type typ) + { + auto i = m_extensions.find(typ); + if(i == m_extensions.end()) + return false; + m_extensions.erase(i); + return true; + } + std::set<Handshake_Extension_Type> Extensions::extension_types() const { std::set<Handshake_Extension_Type> offers; @@ -132,6 +139,19 @@ std::set<Handshake_Extension_Type> Extensions::extension_types() const return offers; } +Unknown_Extension::Unknown_Extension(Handshake_Extension_Type type, + TLS_Data_Reader& reader, + uint16_t extension_size) : + m_type(type), + m_value(reader.get_fixed<uint8_t>(extension_size)) + { + } + +std::vector<uint8_t> Unknown_Extension::serialize() const + { + throw Invalid_State("Cannot encode an unknown TLS extension"); + } + Server_Name_Indicator::Server_Name_Indicator(TLS_Data_Reader& reader, uint16_t extension_size) { diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h index 221d8b46f..5ba3c0b8e 100644 --- a/src/lib/tls/tls_extensions.h +++ b/src/lib/tls/tls_extensions.h @@ -432,6 +432,30 @@ class Certificate_Status_Request final : public Extension }; /** +* Unknown extensions are deserialized as this type +*/ +class BOTAN_UNSTABLE_API Unknown_Extension final : public Extension + { + public: + Unknown_Extension(Handshake_Extension_Type type, + TLS_Data_Reader& reader, + uint16_t extension_size); + + std::vector<uint8_t> serialize() const override; // always fails + + const std::vector<uint8_t>& value() { return m_value; } + + bool empty() const override { return false; } + + Handshake_Extension_Type type() const override { return m_type; } + + private: + Handshake_Extension_Type m_type; + std::vector<uint8_t> m_value; + + }; + +/** * Represents a block of extensions in a hello message */ class BOTAN_UNSTABLE_API Extensions final @@ -442,13 +466,7 @@ class BOTAN_UNSTABLE_API Extensions final template<typename T> T* get() const { - Handshake_Extension_Type type = T::static_type(); - - auto i = m_extensions.find(type); - - if(i != m_extensions.end()) - return dynamic_cast<T*>(i->second.get()); - return nullptr; + return dynamic_cast<T*>(get(T::static_type())); } template<typename T> @@ -462,10 +480,26 @@ class BOTAN_UNSTABLE_API Extensions final m_extensions[extn->type()].reset(extn); } + Extension* get(Handshake_Extension_Type type) const + { + auto i = m_extensions.find(type); + + if(i != m_extensions.end()) + return i->second.get(); + return nullptr; + } + std::vector<uint8_t> serialize() const; void deserialize(TLS_Data_Reader& reader); + /** + * Remvoe an extension from this extensions object, if it exists. + * Returns true if the extension existed (and thus is now removed), + * otherwise false (the extension wasn't set in the first place). + */ + bool remove_extension(Handshake_Extension_Type typ); + Extensions() = default; explicit Extensions(TLS_Data_Reader& reader) { deserialize(reader); } diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h index 35ec3c83c..75e65fa7f 100644 --- a/src/lib/tls/tls_messages.h +++ b/src/lib/tls/tls_messages.h @@ -38,9 +38,10 @@ namespace TLS { class Session; class Handshake_IO; class Handshake_State; +class Callbacks; std::vector<uint8_t> make_hello_random(RandomNumberGenerator& rng, - const Policy& policy); + const Policy& policy); /** * DTLS Hello Verify Request @@ -145,9 +146,12 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message std::set<Handshake_Extension_Type> extension_types() const { return m_extensions.extension_types(); } + const Extensions& extensions() const { return m_extensions; } + Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello::Settings& client_settings, @@ -156,6 +160,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Session& resumed_session, @@ -274,6 +279,8 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message std::set<Handshake_Extension_Type> extension_types() const { return m_extensions.extension_types(); } + const Extensions& extensions() const { return m_extensions; } + bool prefers_compressed_ec_points() const { if(auto ecc_formats = m_extensions.get<Supported_Point_Formats>()) @@ -286,6 +293,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& secure_reneg_info, const Client_Hello& client_hello, @@ -295,6 +303,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& secure_reneg_info, const Client_Hello& client_hello, diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 2d2fb769b..38c5cf2ca 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -460,6 +460,8 @@ void Server::process_client_hello_msg(const Handshake_State* active_state, pending_state.set_version(negotiated_version); + callbacks().tls_examine_extensions(pending_state.client_hello()->extensions(), CLIENT); + Session session_info; const bool resuming = pending_state.allow_session_resumption() && @@ -703,6 +705,7 @@ void Server::session_resume(Server_Handshake_State& pending_state, pending_state.handshake_io(), pending_state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_server_hello(), *pending_state.client_hello(), @@ -794,6 +797,7 @@ void Server::session_create(Server_Handshake_State& pending_state, pending_state.handshake_io(), pending_state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_server_hello(), *pending_state.client_hello(), diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index eb6e710e1..7c5d9668f 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -96,12 +96,20 @@ class BOTAN_PUBLIC_API(2,0) Server final : public Channel /** * Return the protocol notification set by the client (using the - * NPN extension) for this connection, if any. This value is not + * ALPN extension) for this connection, if any. This value is not * tied to the session and a later renegotiation of the same * session can choose a new protocol. */ std::string next_protocol() const { return m_next_protocol; } + /** + * Return the protocol notification set by the client (using the + * ALPN extension) for this connection, if any. This value is not + * tied to the session and a later renegotiation of the same + * session can choose a new protocol. + */ + std::string application_protocol() const { return m_next_protocol; } + private: std::vector<X509_Certificate> get_peer_cert_chain(const Handshake_State& state) const override; diff --git a/src/tests/data/tls/client_hello.vec b/src/tests/data/tls/client_hello.vec index 827f2ea4d..afd8e83c1 100644 --- a/src/tests/data/tls/client_hello.vec +++ b/src/tests/data/tls/client_hello.vec @@ -13,13 +13,13 @@ Exception = # with extensions: point formats, ec curves, session ticket, signature algorithms, heartbeat (point formats and heartbeat not supported, empty renegotiation generated) Buffer = 0303871e18983024eaee1be8ae6607d5ecad941d33fd7fc1d8554a9e1fbfda8d30880000aac030c02cc028c024c014c00a00a500a300a1009f006b006a0069006800390038003700360088008700860085c032c02ec02ac026c00fc005009d003d00350084c02fc02bc027c023c013c00900a400a200a0009e00670040003f003e0033003200310030009a0099009800970045004400430042c031c02dc029c025c00ec004009c003c002f00960041c011c007c00cc00200050004c012c008001600130010000dc00dc003000a00ff01000055000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000f000101 Protocol = 0303 -AdditionalData = 000A000B000D0023FF01 +AdditionalData = 000A000B000D000F0023FF01 Exception = # with extensions: point formats, ec curves, session ticket, signature algorithms, heartbeat, Encrypt-then-MAC, Extended Master Secret (point formats and heartbeat not supported, empty renegotiation generated) Buffer = 0303e00da23523058b5dc9c445d97b2bb6315b019e97838ac4f16c23b2cb031b6a490000e2c0afc0adc030c02cc028c024c014c00ac0a3c09f00a500a300a1009f006b006a006900680039003800370036cca9cca8c077c073ccaa00c400c300c200c10088008700860085c032c02ec02ac026c00fc005c079c075c0a1c09d009d003d003500c00084c0aec0acc02fc02bc027c023c013c009c0a2c09e00a400a200a0009e00670040003f003e0033003200310030c076c07200be00bd00bc00bb009a0099009800970045004400430042c031c02dc029c025c00ec004c078c074c0a0c09c009c003c002f00ba009600410007c012c008001600130010000dc00dc003000a00ff0100005f000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d00220020060106020603050105020503040104020403030103020303020102020203eded000f0001010016000000170000 Protocol = 0303 -AdditionalData = 000A000B000D001600170023FF01 +AdditionalData = 000A000B000D000F001600170023FF01 Exception = # empty @@ -65,4 +65,4 @@ Exception = Invalid argument Decoding error: Invalid ClientHello: Expected 255 b #invalid length of the heartbeat extension Buffer = 0303871e18983024eaee1be8ae6607d5ecad941d33fd7fc1d8554a9e1fbfda8d30880000aac030c02cc028c024c014c00a00a500a300a1009f006b006a0069006800390038003700360088008700860085c032c02ec02ac026c00fc005009d003d00350084c02fc02bc027c023c013c00900a400a200a0009e00670040003f003e0033003200310030009a0099009800970045004400430042c031c02dc029c025c00ec004009c003c002f00960041c011c007c00cc00200050004c012c008001600130010000dc00dc003000a00ff01000055000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000f000201 Protocol = 0303 -Exception = Invalid argument Decoding error: Invalid ClientHello: Expected 2 bytes remaining, only 1 left
\ No newline at end of file +Exception = Invalid argument Decoding error: Invalid ClientHello: Expected 2 bytes remaining, only 1 left diff --git a/src/tests/data/tls/server_hello.vec b/src/tests/data/tls/server_hello.vec index f3bf889cb..c4daed84e 100644 --- a/src/tests/data/tls/server_hello.vec +++ b/src/tests/data/tls/server_hello.vec @@ -9,14 +9,14 @@ Buffer = 0303ffea0bcfba564a4ce177c6a444b0ebdff5629b277293c618c1125f231e8628dd00c030000016ff01000100000b00040300010200230000000f000101 Protocol = 0303 Ciphersuite = C030 -AdditionalData = 000B0023FF01 +AdditionalData = 000B000F0023FF01 Exception = # correct, with session ticket, extended master secret, and renegotiation info Buffer = 03019f9cafa88664d9095f85dd64a39e5dd5c09f5a4a5362938af3718ee4e818af6a00c03000001aff01000100000b00040300010200230000000f00010100170000 Protocol = 0301 Ciphersuite = C030 -AdditionalData = 000B00170023FF01 +AdditionalData = 000B000F00170023FF01 Exception = # incorrect, corrupted @@ -45,4 +45,4 @@ Buffer = 03039f9cafa88664d9095f85dd64a39e5dd5c09f5a4a5362938af3718ee4e818af6a00c Protocol = 0303 Ciphersuite = C030 AdditionalData = 00170023FF01 -Exception = Invalid argument Decoding error: Invalid ServerHello: Expected 256 bytes remaining, only 9 left
\ No newline at end of file +Exception = Invalid argument Decoding error: Invalid ServerHello: Expected 256 bytes remaining, only 9 left diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index 0aa33d213..ae2146304 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -1,5 +1,5 @@ /* -* (C) 2014,2015 Jack Lloyd +* (C) 2014,2015,2018 Jack Lloyd * 2016 Matthias Gierlings * 2017 René Korthaus, Rohde & Schwarz Cybersecurity * 2017 Harry Reimann, Rohde & Schwarz Cybersecurity @@ -21,6 +21,8 @@ #include <botan/tls_client.h> #include <botan/tls_server.h> #include <botan/tls_policy.h> + #include <botan/tls_extensions.h> + #include <botan/internal/tls_reader.h> #include <botan/ec_group.h> #include <botan/hex.h> @@ -300,325 +302,339 @@ void alert_cb_with_data(Botan::TLS::Alert, const uint8_t[], size_t) { } -Test::Result test_tls_handshake(Botan::TLS::Protocol_Version offer_version, - Botan::Credentials_Manager& creds, - const Botan::TLS::Policy& client_policy, - const Botan::TLS::Policy& server_policy, - Botan::RandomNumberGenerator& rng, - Botan::TLS::Session_Manager& client_sessions, - Botan::TLS::Session_Manager& server_sessions) +class TLS_Handshake_Test final { - Test::Result result(offer_version.to_string()); - - result.start_timer(); - - for(size_t r = 1; r <= 4; ++r) - { - bool handshake_done = false; - - result.test_note("Test round " + std::to_string(r)); - - auto handshake_complete = [&](const Botan::TLS::Session& session) + public: + TLS_Handshake_Test(Botan::TLS::Protocol_Version offer_version, + Botan::Credentials_Manager& creds, + const Botan::TLS::Policy& client_policy, + const Botan::TLS::Policy& server_policy, + Botan::RandomNumberGenerator& rng, + Botan::TLS::Session_Manager& client_sessions, + Botan::TLS::Session_Manager& server_sessions) : + m_offer_version(offer_version), + m_results(offer_version.to_string()), // TODO descriptive constructor arg + m_creds(creds), + m_client_policy(client_policy), + m_client_sessions(client_sessions), + m_rng(rng) { - handshake_done = true; + m_server_cb.reset(new Test_Callbacks(m_results, offer_version, m_s2c, m_server_recv)); + m_client_cb.reset(new Test_Callbacks(m_results, offer_version, m_c2s, m_client_recv)); - const std::string session_report = - "Session established " + session.version().to_string() + " " + - session.ciphersuite().to_string() + " " + - Botan::hex_encode(session.session_id()); + m_server.reset( + new Botan::TLS::Server(*m_server_cb, server_sessions, m_creds, server_policy, m_rng) + ); - result.test_note(session_report); + } - if(session.version() != offer_version) - { - result.test_failure("Offered " + offer_version.to_string() + " got " + session.version().to_string()); - } + void go(); - if(r <= 2) - { - return true; - } - return false; - }; + const Test::Result& results() const { return m_results; } + private: - auto next_protocol_chooser = [&](std::vector<std::string> protos) -> std::string + class Test_Extension : public Botan::TLS::Extension { - if(r <= 2) - { - result.test_eq("protocol count", protos.size(), 2); - result.test_eq("protocol[0]", protos[0], "test/1"); - result.test_eq("protocol[1]", protos[1], "test/2"); - } - return "test/3"; - }; + public: + static Botan::TLS::Handshake_Extension_Type static_type() + { return static_cast<Botan::TLS::Handshake_Extension_Type>(666); } - const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; + Botan::TLS::Handshake_Extension_Type type() const override { return static_type(); } - try - { - std::vector<uint8_t> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + std::vector<uint8_t> serialize() const override { return m_buf; } - std::unique_ptr<Botan::TLS::Callbacks> server_cb(new Botan::TLS::Compat_Callbacks( - queue_inserter(s2c_traffic), - queue_inserter(server_recv), - std::function<void (Botan::TLS::Alert, const uint8_t[], size_t)>(alert_cb_with_data), - handshake_complete, - nullptr, - next_protocol_chooser)); + const std::vector<uint8_t>& value() const { return m_buf; } - // TLS::Server object constructed by new constructor using virtual callback interface. - std::unique_ptr<Botan::TLS::Server> server( - new Botan::TLS::Server(*server_cb, - server_sessions, - creds, - server_policy, - rng, - false)); + bool empty() const override { return false; } - std::unique_ptr<Botan::TLS::Callbacks> client_cb(new Botan::TLS::Compat_Callbacks( - queue_inserter(c2s_traffic), - queue_inserter(client_recv), - std::function<void (Botan::TLS::Alert, const uint8_t[], size_t)>(alert_cb_with_data), - handshake_complete)); + Test_Extension(Botan::TLS::Connection_Side side) + { + const uint8_t client_extn[6] = { 'c', 'l', 'i', 'e', 'n', 't' }; + const uint8_t server_extn[6] = { 's', 'e', 'r', 'v', 'e', 'r' }; - // TLS::Client object constructed by new constructor using virtual callback interface. - std::unique_ptr<Botan::TLS::Client> client( - new Botan::TLS::Client(*client_cb, - client_sessions, - creds, - client_policy, - rng, - Botan::TLS::Server_Information("server.example.com"), - offer_version, - protocols_offered)); + Botan::TLS::append_tls_length_value(m_buf, + (side == Botan::TLS::CLIENT) ? client_extn : server_extn, + 6, 1); + } - size_t rounds = 0; + Test_Extension(Botan::TLS::TLS_Data_Reader& reader, uint16_t) + { + m_buf = reader.get_range_vector<uint8_t>(1, 6, 6); + } + private: + std::vector<uint8_t> m_buf; + }; - // Test TLS using both new and legacy constructors. - for(size_t ctor_sel = 0; ctor_sel < 2; ctor_sel++) - { - if(ctor_sel == 1) + class Test_Callbacks : public Botan::TLS::Callbacks + { + public: + Test_Callbacks(Test::Result& results, + Botan::TLS::Protocol_Version expected_version, + std::vector<uint8_t>& outbound, + std::vector<uint8_t>& recv_buf) : + m_results(results), + m_expected_version(expected_version), + m_outbound(outbound), + m_recv(recv_buf) + {} + + void tls_emit_data(const uint8_t bits[], size_t len) override { - c2s_traffic.clear(); - s2c_traffic.clear(); - server_recv.clear(); - client_recv.clear(); - client_sent.clear(); - server_sent.clear(); + m_outbound.insert(m_outbound.end(), bits, bits + len); + } - // TLS::Server object constructed by legacy constructor. - server.reset( - new Botan::TLS::Server(queue_inserter(s2c_traffic), - queue_inserter(server_recv), - alert_cb_with_data, - handshake_complete, - server_sessions, - creds, - server_policy, - rng, - next_protocol_chooser, - false)); + void tls_record_received(uint64_t /*seq*/, const uint8_t bits[], size_t len) override + { + m_recv.insert(m_recv.end(), bits, bits + len); + } - // TLS::Client object constructed by legacy constructor. - client.reset( - new Botan::TLS::Client(queue_inserter(c2s_traffic), - queue_inserter(client_recv), - alert_cb_with_data, - handshake_complete, - client_sessions, - creds, - server_policy, - rng, - Botan::TLS::Server_Information("server.example.com"), - offer_version, - protocols_offered)); + void tls_alert(Botan::TLS::Alert /*alert*/) override + { + // TODO test that it is a no_renegotiation alert + // ignore } - while(true) + void tls_modify_extensions(Botan::TLS::Extensions& extn, Botan::TLS::Connection_Side which_side) override { - ++rounds; + extn.add(new Test_Extension(which_side)); + } - if(rounds > 25) - { - if(r <= 2) - { - result.test_failure("Still here after many rounds, deadlock?"); - } - break; - } + void tls_examine_extensions(const Botan::TLS::Extensions& extn, Botan::TLS::Connection_Side which_side) override + { + Botan::TLS::Extension* test_extn = extn.get(static_cast<Botan::TLS::Handshake_Extension_Type>(666)); - if(handshake_done && (client->is_closed() || server->is_closed())) + if(test_extn == nullptr) { - break; + m_results.test_failure("Did not receive test extension from peer"); } - - if(client->is_active() && client_sent.empty()) + else { - // Choose random application data to send - const size_t c_len = 1 + ((static_cast<size_t>(rng.next_byte()) << 4) ^ rng.next_byte()); - client_sent = unlock(rng.random_vec(c_len)); + Botan::TLS::Unknown_Extension* unknown_ext = dynamic_cast<Botan::TLS::Unknown_Extension*>(test_extn); - size_t sent_so_far = 0; - while(sent_so_far != client_sent.size()) - { - const size_t left = client_sent.size() - sent_so_far; - const size_t rnd12 = (rng.next_byte() << 4) ^ rng.next_byte(); - const size_t sending = std::min(left, rnd12); + const std::vector<uint8_t> val = unknown_ext->value(); - client->send(&client_sent[sent_so_far], sending); - sent_so_far += sending; + if(m_results.test_eq("Expected size for test extn", val.size(), 7)) + { + if(which_side == Botan::TLS::CLIENT) + m_results.test_eq("Expected extension value", val, "06636C69656E74"); + else + m_results.test_eq("Expected extension value", val, "06736572766572"); } - client->send_warning_alert(Botan::TLS::Alert::NO_RENEGOTIATION); + } + } - if(server->is_active() && server_sent.empty()) - { - result.test_eq("server->protocol", server->next_protocol(), "test/3"); + bool tls_session_established(const Botan::TLS::Session& session) override + { + const std::string session_report = + "Session established " + session.version().to_string() + " " + + session.ciphersuite().to_string() + " " + + Botan::hex_encode(session.session_id()); - const size_t s_len = 1 + ((static_cast<size_t>(rng.next_byte()) << 4) ^ rng.next_byte()); - server_sent = unlock(rng.random_vec(s_len)); + m_results.test_note(session_report); - size_t sent_so_far = 0; - while(sent_so_far != server_sent.size()) - { - const size_t left = server_sent.size() - sent_so_far; - const size_t rnd12 = (rng.next_byte() << 4) ^ rng.next_byte(); - const size_t sending = std::min(left, rnd12); + if(session.version() != m_expected_version) + { + m_results.test_failure("Expected " + m_expected_version.to_string() + + " negotiated " + session.version().to_string()); + } - server->send(&server_sent[sent_so_far], sending); - sent_so_far += sending; - } + return true; + } - server->send_warning_alert(Botan::TLS::Alert::NO_RENEGOTIATION); - } + std::string tls_server_choose_app_protocol(const std::vector<std::string>& protos) override + { + m_results.test_eq("ALPN protocol count", protos.size(), 2); + m_results.test_eq("ALPN protocol 1", protos[0], "test/1"); + m_results.test_eq("ALPN protocol 2", protos[1], "test/2"); + return "test/3"; + } - const bool corrupt_client_data = (r == 3); - const bool corrupt_server_data = (r == 4); + private: + Test::Result& m_results; + const Botan::TLS::Protocol_Version m_expected_version; + std::vector<uint8_t>& m_outbound; + std::vector<uint8_t>& m_recv; + }; - if(c2s_traffic.size() > 0) - { - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<uint8_t> input; - std::swap(c2s_traffic, input); + const Botan::TLS::Protocol_Version m_offer_version; + Test::Result m_results; - if(corrupt_server_data) - { - input = Test::mutate_vec(input, true, 5); - size_t needed = server->received_data(input.data(), input.size()); + Botan::Credentials_Manager& m_creds; + const Botan::TLS::Policy& m_client_policy; + Botan::TLS::Session_Manager& m_client_sessions; + Botan::RandomNumberGenerator& m_rng; - size_t total_consumed = needed; + std::unique_ptr<Test_Callbacks> m_client_cb; - while(needed > 0 && - result.test_lt("Never requesting more than max protocol len", needed, Botan::TLS::MAX_CIPHERTEXT_SIZE + 1) && - result.test_lt("Total requested is readonable", total_consumed, 128 * 1024)) - { - input.resize(needed); - rng.randomize(input.data(), input.size()); - needed = server->received_data(input.data(), input.size()); - total_consumed += needed; - } - } - else - { - size_t needed = server->received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); - } + std::unique_ptr<Test_Callbacks> m_server_cb; + std::unique_ptr<Botan::TLS::Server> m_server; - continue; - } + std::vector<uint8_t> m_c2s, m_s2c, m_client_recv, m_server_recv; + }; - if(s2c_traffic.size() > 0) - { - std::vector<uint8_t> input; - std::swap(s2c_traffic, input); +void TLS_Handshake_Test::go() + { + m_results.start_timer(); - if(corrupt_client_data) - { - input = Test::mutate_vec(input, true, 5); - size_t needed = client->received_data(input.data(), input.size()); + Botan::RandomNumberGenerator& rng = Test::rng(); - size_t total_consumed = 0; + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; - while(needed > 0 && - result.test_lt("Never requesting more than max protocol len", needed, Botan::TLS::MAX_CIPHERTEXT_SIZE + 1)) - { - input.resize(needed); - rng.randomize(input.data(), input.size()); - needed = client->received_data(input.data(), input.size()); - total_consumed += needed; - } - } - else - { - size_t needed = client->received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); - } + // Choose random application data to send + //const size_t c_len = 1 + ((static_cast<size_t>(rng.next_byte()) << 4) ^ rng.next_byte()); + const size_t c_len = 180; + std::vector<uint8_t> client_msg(c_len); + Test::rng().randomize(client_msg.data(), client_msg.size()); + bool client_has_written = false; - continue; - } + //const size_t s_len = 1 + ((static_cast<size_t>(rng.next_byte()) << 4) ^ rng.next_byte()); + const size_t s_len = 400; + std::vector<uint8_t> server_msg(s_len); + Test::rng().randomize(server_msg.data(), server_msg.size()); + bool server_has_written = false; - if(client_recv.size()) - { - result.test_eq("client recv", client_recv, server_sent); - } + std::unique_ptr<Botan::TLS::Client> client; + client.reset( + new Botan::TLS::Client(*m_client_cb, m_client_sessions, m_creds, + m_client_policy, m_rng, + Botan::TLS::Server_Information("server.example.com"), + m_offer_version, + protocols_offered)); - if(server_recv.size()) - { - result.test_eq("server->recv", server_recv, client_sent); - } + size_t rounds = 0; - if(r > 2) - { - if(client_recv.size() && server_recv.size()) - { - result.test_failure("Negotiated in the face of data corruption " + std::to_string(r)); - } - } + bool client_handshake_completed = false; + bool server_handshake_completed = false; - if(client->is_closed() && server->is_closed()) - { - break; - } + while(true) + { + ++rounds; - if(server_recv.size() && client_recv.size()) - { - Botan::SymmetricKey client_key = client->key_material_export("label", "context", 32); - Botan::SymmetricKey server_key = server->key_material_export("label", "context", 32); + if(rounds > 25) + { + m_results.test_failure("Still here after many rounds, deadlock?"); + break; + } - result.test_eq("TLS key material export", client_key.bits_of(), server_key.bits_of()); + if(client_handshake_completed == false && client->is_active()) + client_handshake_completed = true; - if(r % 2 == 0) - { - client->close(); - } - else - { - server->close(); - } - } - } - } + if(server_handshake_completed == false && m_server->is_active()) + server_handshake_completed = true; + + if(client->is_closed() || m_server->is_closed()) + { + break; } - catch(std::exception& e) + + if(client->is_active() && client_has_written == false) { - if(r > 2) + m_results.test_eq("client ALPN protocol", client->application_protocol(), "test/3"); + + size_t sent_so_far = 0; + while(sent_so_far != client_msg.size()) { - result.test_note("Corruption caused exception"); + const size_t left = client_msg.size() - sent_so_far; + const size_t rnd12 = (rng.next_byte() << 4) ^ rng.next_byte(); + const size_t sending = std::min(left, rnd12); + + client->send(&client_msg[sent_so_far], sending); + sent_so_far += sending; } - else + client->send_warning_alert(Botan::TLS::Alert::NO_RENEGOTIATION); + client_has_written = true; + } + + if(m_server->is_active() && server_has_written == false) + { + m_results.test_eq("server ALPN protocol", m_server->application_protocol(), "test/3"); + + size_t sent_so_far = 0; + while(sent_so_far != server_msg.size()) { - result.test_failure("TLS client", e.what()); + const size_t left = server_msg.size() - sent_so_far; + const size_t rnd12 = (rng.next_byte() << 4) ^ rng.next_byte(); + const size_t sending = std::min(left, rnd12); + + m_server->send(&server_msg[sent_so_far], sending); + sent_so_far += sending; } + + m_server->send_warning_alert(Botan::TLS::Alert::NO_RENEGOTIATION); + server_has_written = true; + } + + if(m_c2s.size() > 0) + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + std::vector<uint8_t> input; + std::swap(m_c2s, input); + + size_t needed = m_server->received_data(input.data(), input.size()); + m_results.test_eq("full packet received", needed, 0); + + continue; + } + + if(m_s2c.size() > 0) + { + std::vector<uint8_t> input; + std::swap(m_s2c, input); + + size_t needed = client->received_data(input.data(), input.size()); + m_results.test_eq("full packet received", needed, 0); + + continue; + } + + if(m_client_recv.size()) + { + m_results.test_eq("client recv", m_client_recv, server_msg); + } + + if(m_server_recv.size()) + { + m_results.test_eq("server recv", m_server_recv, client_msg); + } + + if(client->is_closed() && m_server->is_closed()) + { + break; + } + + if(m_server_recv.size() && m_client_recv.size()) + { + Botan::SymmetricKey client_key = client->key_material_export("label", "context", 32); + Botan::SymmetricKey server_key = m_server->key_material_export("label", "context", 32); + + m_results.test_eq("TLS key material export", client_key.bits_of(), server_key.bits_of()); + + client->close(); } } - result.end_timer(); + m_results.end_timer(); + } - return result; +Test::Result test_tls_handshake(Botan::TLS::Protocol_Version offer_version, + Botan::Credentials_Manager& creds, + const Botan::TLS::Policy& client_policy, + const Botan::TLS::Policy& server_policy, + Botan::RandomNumberGenerator& rng, + Botan::TLS::Session_Manager& client_sessions, + Botan::TLS::Session_Manager& server_sessions) + { + TLS_Handshake_Test test(offer_version, creds, + client_policy, server_policy, rng, + client_sessions, server_sessions); + + test.go(); + return test.results(); } Test::Result test_tls_handshake(Botan::TLS::Protocol_Version offer_version, |