aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/lib/tls/msg_client_hello.cpp7
-rw-r--r--src/lib/tls/msg_server_hello.cpp7
-rw-r--r--src/lib/tls/tls_callbacks.cpp4
-rw-r--r--src/lib/tls/tls_callbacks.h12
-rw-r--r--src/lib/tls/tls_client.cpp2
-rw-r--r--src/lib/tls/tls_extensions.cpp9
-rw-r--r--src/lib/tls/tls_extensions.h7
-rw-r--r--src/lib/tls/tls_messages.h7
-rw-r--r--src/lib/tls/tls_server.cpp2
9 files changed, 56 insertions, 1 deletions
diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp
index eeeaf8c71..77068a928 100644
--- a/src/lib/tls/msg_client_hello.cpp
+++ b/src/lib/tls/msg_client_hello.cpp
@@ -10,6 +10,7 @@
#include <botan/tls_messages.h>
#include <botan/tls_alert.h>
#include <botan/tls_exceptn.h>
+#include <botan/tls_callbacks.h>
#include <botan/rng.h>
#include <botan/hash.h>
@@ -81,6 +82,7 @@ std::vector<uint8_t> Hello_Request::serialize() const
Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
@@ -140,6 +142,8 @@ Client_Hello::Client_Hello(Handshake_IO& io,
m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(),
policy.allowed_signature_methods()));
+ cb.tls_modify_extensions(m_extensions);
+
if(policy.send_fallback_scsv(client_settings.protocol_version()))
m_suites.push_back(TLS_FALLBACK_SCSV);
@@ -152,6 +156,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& session,
@@ -201,6 +206,8 @@ Client_Hello::Client_Hello(Handshake_IO& io,
if(reneg_info.empty() && !next_protocols.empty())
m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols));
+ cb.tls_modify_extensions(m_extensions);
+
hash.update(io.send(*this));
}
diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp
index 5e290eb68..815fd37f0 100644
--- a/src/lib/tls/msg_server_hello.cpp
+++ b/src/lib/tls/msg_server_hello.cpp
@@ -9,6 +9,7 @@
#include <botan/tls_messages.h>
#include <botan/tls_extensions.h>
+#include <botan/tls_callbacks.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_handshake_io.h>
@@ -23,6 +24,7 @@ namespace TLS {
Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello& client_hello,
@@ -83,6 +85,8 @@ Server_Hello::Server_Hello(Handshake_IO& io,
}
}
+ cb.tls_modify_extensions(m_extensions);
+
hash.update(io.send(*this));
}
@@ -90,6 +94,7 @@ Server_Hello::Server_Hello(Handshake_IO& io,
Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello& client_hello,
@@ -130,6 +135,8 @@ Server_Hello::Server_Hello(Handshake_IO& io,
if(!next_protocol.empty() && client_hello.supports_alpn())
m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol));
+ cb.tls_modify_extensions(m_extensions);
+
hash.update(io.send(*this));
}
diff --git a/src/lib/tls/tls_callbacks.cpp b/src/lib/tls/tls_callbacks.cpp
index b8f38589e..b13aa2406 100644
--- a/src/lib/tls/tls_callbacks.cpp
+++ b/src/lib/tls/tls_callbacks.cpp
@@ -32,6 +32,10 @@ std::string TLS::Callbacks::tls_server_choose_app_protocol(const std::vector<std
return "";
}
+void TLS::Callbacks::tls_modify_extensions(Extensions&)
+ {
+ }
+
void TLS::Callbacks::tls_verify_cert_chain(
const std::vector<X509_Certificate>& cert_chain,
const std::vector<std::shared_ptr<const OCSP::Response>>& ocsp_responses,
diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h
index 4437a222a..3ac1a9d2f 100644
--- a/src/lib/tls/tls_callbacks.h
+++ b/src/lib/tls/tls_callbacks.h
@@ -30,6 +30,7 @@ namespace TLS {
class Handshake_Message;
class Policy;
+class Extensions;
/**
* Encapsulates the callbacks that a TLS channel will make which are due to
@@ -250,6 +251,17 @@ class BOTAN_PUBLIC_API(2,0) Callbacks
virtual std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos);
/**
+ * Optional callback: examine/modify Extensions before sending. Both
+ * client and server will call this callback on the Extensions object
+ * before serializing it in the client/server hellos. This allows a client
+ * to modify which extensions are sent during the handshake. This also
+ * allows creating custom extensions.
+ *
+ * Default implementation does nothing.
+ */
+ virtual void tls_modify_extensions(Extensions& extn);
+
+ /**
* Optional callback: error logging. (not currently called)
* @param err An error message related to this connection.
*/
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index c88b6a7db..e041c1a77 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -169,6 +169,7 @@ void Client::send_client_hello(Handshake_State& state_base,
new Client_Hello(state.handshake_io(),
state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_client_hello(),
session_info,
@@ -188,6 +189,7 @@ void Client::send_client_hello(Handshake_State& state_base,
state.handshake_io(),
state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_client_hello(),
client_settings,
diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp
index d521f6bf8..6497c3c11 100644
--- a/src/lib/tls/tls_extensions.cpp
+++ b/src/lib/tls/tls_extensions.cpp
@@ -124,6 +124,15 @@ std::vector<uint8_t> Extensions::serialize() const
return buf;
}
+bool Extensions::remove_extension(Handshake_Extension_Type typ)
+ {
+ auto i = m_extensions.find(typ);
+ if(i == m_extensions.end())
+ return false;
+ m_extensions.erase(i);
+ return true;
+ }
+
std::set<Handshake_Extension_Type> Extensions::extension_types() const
{
std::set<Handshake_Extension_Type> offers;
diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h
index 221d8b46f..7b7b645bf 100644
--- a/src/lib/tls/tls_extensions.h
+++ b/src/lib/tls/tls_extensions.h
@@ -466,6 +466,13 @@ class BOTAN_UNSTABLE_API Extensions final
void deserialize(TLS_Data_Reader& reader);
+ /**
+ * Remvoe an extension from this extensions object, if it exists.
+ * Returns true if the extension existed (and thus is now removed),
+ * otherwise false (the extension wasn't set in the first place).
+ */
+ bool remove_extension(Handshake_Extension_Type typ);
+
Extensions() = default;
explicit Extensions(TLS_Data_Reader& reader) { deserialize(reader); }
diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h
index 35ec3c83c..471467fc2 100644
--- a/src/lib/tls/tls_messages.h
+++ b/src/lib/tls/tls_messages.h
@@ -38,9 +38,10 @@ namespace TLS {
class Session;
class Handshake_IO;
class Handshake_State;
+class Callbacks;
std::vector<uint8_t> make_hello_random(RandomNumberGenerator& rng,
- const Policy& policy);
+ const Policy& policy);
/**
* DTLS Hello Verify Request
@@ -148,6 +149,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message
Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
@@ -156,6 +158,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message
Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& resumed_session,
@@ -286,6 +289,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message
Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& secure_reneg_info,
const Client_Hello& client_hello,
@@ -295,6 +299,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message
Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& secure_reneg_info,
const Client_Hello& client_hello,
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 2d2fb769b..3ed1b120d 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -703,6 +703,7 @@ void Server::session_resume(Server_Handshake_State& pending_state,
pending_state.handshake_io(),
pending_state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_server_hello(),
*pending_state.client_hello(),
@@ -794,6 +795,7 @@ void Server::session_create(Server_Handshake_State& pending_state,
pending_state.handshake_io(),
pending_state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_server_hello(),
*pending_state.client_hello(),