diff options
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 7 | ||||
-rw-r--r-- | src/lib/tls/msg_server_hello.cpp | 7 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.h | 12 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.cpp | 9 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.h | 7 | ||||
-rw-r--r-- | src/lib/tls/tls_messages.h | 7 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 2 |
9 files changed, 56 insertions, 1 deletions
diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index eeeaf8c71..77068a928 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -10,6 +10,7 @@ #include <botan/tls_messages.h> #include <botan/tls_alert.h> #include <botan/tls_exceptn.h> +#include <botan/tls_callbacks.h> #include <botan/rng.h> #include <botan/hash.h> @@ -81,6 +82,7 @@ std::vector<uint8_t> Hello_Request::serialize() const Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello::Settings& client_settings, @@ -140,6 +142,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(), policy.allowed_signature_methods())); + cb.tls_modify_extensions(m_extensions); + if(policy.send_fallback_scsv(client_settings.protocol_version())) m_suites.push_back(TLS_FALLBACK_SCSV); @@ -152,6 +156,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Session& session, @@ -201,6 +206,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, if(reneg_info.empty() && !next_protocols.empty()) m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols)); + cb.tls_modify_extensions(m_extensions); + hash.update(io.send(*this)); } diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index 5e290eb68..815fd37f0 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -9,6 +9,7 @@ #include <botan/tls_messages.h> #include <botan/tls_extensions.h> +#include <botan/tls_callbacks.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_io.h> @@ -23,6 +24,7 @@ namespace TLS { Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello& client_hello, @@ -83,6 +85,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, } } + cb.tls_modify_extensions(m_extensions); + hash.update(io.send(*this)); } @@ -90,6 +94,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello& client_hello, @@ -130,6 +135,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, if(!next_protocol.empty() && client_hello.supports_alpn()) m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol)); + cb.tls_modify_extensions(m_extensions); + hash.update(io.send(*this)); } diff --git a/src/lib/tls/tls_callbacks.cpp b/src/lib/tls/tls_callbacks.cpp index b8f38589e..b13aa2406 100644 --- a/src/lib/tls/tls_callbacks.cpp +++ b/src/lib/tls/tls_callbacks.cpp @@ -32,6 +32,10 @@ std::string TLS::Callbacks::tls_server_choose_app_protocol(const std::vector<std return ""; } +void TLS::Callbacks::tls_modify_extensions(Extensions&) + { + } + void TLS::Callbacks::tls_verify_cert_chain( const std::vector<X509_Certificate>& cert_chain, const std::vector<std::shared_ptr<const OCSP::Response>>& ocsp_responses, diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h index 4437a222a..3ac1a9d2f 100644 --- a/src/lib/tls/tls_callbacks.h +++ b/src/lib/tls/tls_callbacks.h @@ -30,6 +30,7 @@ namespace TLS { class Handshake_Message; class Policy; +class Extensions; /** * Encapsulates the callbacks that a TLS channel will make which are due to @@ -250,6 +251,17 @@ class BOTAN_PUBLIC_API(2,0) Callbacks virtual std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos); /** + * Optional callback: examine/modify Extensions before sending. Both + * client and server will call this callback on the Extensions object + * before serializing it in the client/server hellos. This allows a client + * to modify which extensions are sent during the handshake. This also + * allows creating custom extensions. + * + * Default implementation does nothing. + */ + virtual void tls_modify_extensions(Extensions& extn); + + /** * Optional callback: error logging. (not currently called) * @param err An error message related to this connection. */ diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index c88b6a7db..e041c1a77 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -169,6 +169,7 @@ void Client::send_client_hello(Handshake_State& state_base, new Client_Hello(state.handshake_io(), state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_client_hello(), session_info, @@ -188,6 +189,7 @@ void Client::send_client_hello(Handshake_State& state_base, state.handshake_io(), state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_client_hello(), client_settings, diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp index d521f6bf8..6497c3c11 100644 --- a/src/lib/tls/tls_extensions.cpp +++ b/src/lib/tls/tls_extensions.cpp @@ -124,6 +124,15 @@ std::vector<uint8_t> Extensions::serialize() const return buf; } +bool Extensions::remove_extension(Handshake_Extension_Type typ) + { + auto i = m_extensions.find(typ); + if(i == m_extensions.end()) + return false; + m_extensions.erase(i); + return true; + } + std::set<Handshake_Extension_Type> Extensions::extension_types() const { std::set<Handshake_Extension_Type> offers; diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h index 221d8b46f..7b7b645bf 100644 --- a/src/lib/tls/tls_extensions.h +++ b/src/lib/tls/tls_extensions.h @@ -466,6 +466,13 @@ class BOTAN_UNSTABLE_API Extensions final void deserialize(TLS_Data_Reader& reader); + /** + * Remvoe an extension from this extensions object, if it exists. + * Returns true if the extension existed (and thus is now removed), + * otherwise false (the extension wasn't set in the first place). + */ + bool remove_extension(Handshake_Extension_Type typ); + Extensions() = default; explicit Extensions(TLS_Data_Reader& reader) { deserialize(reader); } diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h index 35ec3c83c..471467fc2 100644 --- a/src/lib/tls/tls_messages.h +++ b/src/lib/tls/tls_messages.h @@ -38,9 +38,10 @@ namespace TLS { class Session; class Handshake_IO; class Handshake_State; +class Callbacks; std::vector<uint8_t> make_hello_random(RandomNumberGenerator& rng, - const Policy& policy); + const Policy& policy); /** * DTLS Hello Verify Request @@ -148,6 +149,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Client_Hello::Settings& client_settings, @@ -156,6 +158,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message Client_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& reneg_info, const Session& resumed_session, @@ -286,6 +289,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& secure_reneg_info, const Client_Hello& client_hello, @@ -295,6 +299,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message Server_Hello(Handshake_IO& io, Handshake_Hash& hash, const Policy& policy, + Callbacks& cb, RandomNumberGenerator& rng, const std::vector<uint8_t>& secure_reneg_info, const Client_Hello& client_hello, diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 2d2fb769b..3ed1b120d 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -703,6 +703,7 @@ void Server::session_resume(Server_Handshake_State& pending_state, pending_state.handshake_io(), pending_state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_server_hello(), *pending_state.client_hello(), @@ -794,6 +795,7 @@ void Server::session_create(Server_Handshake_State& pending_state, pending_state.handshake_io(), pending_state.hash(), policy(), + callbacks(), rng(), secure_renegotiation_data_for_server_hello(), *pending_state.client_hello(), |