aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2018-01-27 13:38:04 -0500
committerJack Lloyd <[email protected]>2018-01-27 13:38:04 -0500
commite5cf7992ff53c3fbe4beb106d3fd80b8845957b7 (patch)
treea1732c98087a76ccc2bfc40f4b0ce846902632a5 /src/lib
parentcfe57137e5957b84b6b749db8d9f02c3ee1f8c1e (diff)
parent7f7feb41880d87ea170633b47f5dede30ea528de (diff)
Merge GH #1394 Add ability to use custom extensions, control which extensions are used
Diffstat (limited to 'src/lib')
-rw-r--r--src/lib/tls/msg_client_hello.cpp7
-rw-r--r--src/lib/tls/msg_server_hello.cpp7
-rw-r--r--src/lib/tls/tls_callbacks.cpp8
-rw-r--r--src/lib/tls/tls_callbacks.h34
-rw-r--r--src/lib/tls/tls_client.cpp4
-rw-r--r--src/lib/tls/tls_extensions.cpp30
-rw-r--r--src/lib/tls/tls_extensions.h48
-rw-r--r--src/lib/tls/tls_messages.h11
-rw-r--r--src/lib/tls/tls_server.cpp4
-rw-r--r--src/lib/tls/tls_server.h10
10 files changed, 149 insertions, 14 deletions
diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp
index eeeaf8c71..68753fa26 100644
--- a/src/lib/tls/msg_client_hello.cpp
+++ b/src/lib/tls/msg_client_hello.cpp
@@ -10,6 +10,7 @@
#include <botan/tls_messages.h>
#include <botan/tls_alert.h>
#include <botan/tls_exceptn.h>
+#include <botan/tls_callbacks.h>
#include <botan/rng.h>
#include <botan/hash.h>
@@ -81,6 +82,7 @@ std::vector<uint8_t> Hello_Request::serialize() const
Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
@@ -140,6 +142,8 @@ Client_Hello::Client_Hello(Handshake_IO& io,
m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(),
policy.allowed_signature_methods()));
+ cb.tls_modify_extensions(m_extensions, CLIENT);
+
if(policy.send_fallback_scsv(client_settings.protocol_version()))
m_suites.push_back(TLS_FALLBACK_SCSV);
@@ -152,6 +156,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
Client_Hello::Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& session,
@@ -201,6 +206,8 @@ Client_Hello::Client_Hello(Handshake_IO& io,
if(reneg_info.empty() && !next_protocols.empty())
m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols));
+ cb.tls_modify_extensions(m_extensions, CLIENT);
+
hash.update(io.send(*this));
}
diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp
index 5e290eb68..2d5a185f0 100644
--- a/src/lib/tls/msg_server_hello.cpp
+++ b/src/lib/tls/msg_server_hello.cpp
@@ -9,6 +9,7 @@
#include <botan/tls_messages.h>
#include <botan/tls_extensions.h>
+#include <botan/tls_callbacks.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_handshake_io.h>
@@ -23,6 +24,7 @@ namespace TLS {
Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello& client_hello,
@@ -83,6 +85,8 @@ Server_Hello::Server_Hello(Handshake_IO& io,
}
}
+ cb.tls_modify_extensions(m_extensions, SERVER);
+
hash.update(io.send(*this));
}
@@ -90,6 +94,7 @@ Server_Hello::Server_Hello(Handshake_IO& io,
Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello& client_hello,
@@ -130,6 +135,8 @@ Server_Hello::Server_Hello(Handshake_IO& io,
if(!next_protocol.empty() && client_hello.supports_alpn())
m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol));
+ cb.tls_modify_extensions(m_extensions, SERVER);
+
hash.update(io.send(*this));
}
diff --git a/src/lib/tls/tls_callbacks.cpp b/src/lib/tls/tls_callbacks.cpp
index b8f38589e..7a64291c8 100644
--- a/src/lib/tls/tls_callbacks.cpp
+++ b/src/lib/tls/tls_callbacks.cpp
@@ -32,6 +32,14 @@ std::string TLS::Callbacks::tls_server_choose_app_protocol(const std::vector<std
return "";
}
+void TLS::Callbacks::tls_modify_extensions(Extensions&, Connection_Side)
+ {
+ }
+
+void TLS::Callbacks::tls_examine_extensions(const Extensions&, Connection_Side)
+ {
+ }
+
void TLS::Callbacks::tls_verify_cert_chain(
const std::vector<X509_Certificate>& cert_chain,
const std::vector<std::shared_ptr<const OCSP::Response>>& ocsp_responses,
diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h
index 4437a222a..dd6ad2d4b 100644
--- a/src/lib/tls/tls_callbacks.h
+++ b/src/lib/tls/tls_callbacks.h
@@ -30,6 +30,7 @@ namespace TLS {
class Handshake_Message;
class Policy;
+class Extensions;
/**
* Encapsulates the callbacks that a TLS channel will make which are due to
@@ -250,6 +251,39 @@ class BOTAN_PUBLIC_API(2,0) Callbacks
virtual std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos);
/**
+ * Optional callback: examine/modify Extensions before sending.
+ *
+ * Both client and server will call this callback on the Extensions object
+ * before serializing it in the client/server hellos. This allows an
+ * application to modify which extensions are sent during the
+ * handshake.
+ *
+ * Default implementation does nothing.
+ *
+ * @param extn the extensions
+ * @param which_side will be CLIENT or SERVER which is the current
+ * applications role in the exchange.
+ */
+ virtual void tls_modify_extensions(Extensions& extn, Connection_Side which_side);
+
+ /**
+ * Optional callback: examine peer extensions.
+ *
+ * Both client and server will call this callback with the Extensions
+ * object after receiving it from the peer. This allows examining the
+ * Extensions, for example to implement a custom extension. It also allows
+ * an application to require that a particular extension be implemented;
+ * throw an exception from this function to abort the handshake.
+ *
+ * Default implementation does nothing.
+ *
+ * @param extn the extensions
+ * @param which_side will be CLIENT if these are are the clients extensions (ie we are
+ * the server) or SERVER if these are the server extensions (we are the client).
+ */
+ virtual void tls_examine_extensions(const Extensions& extn, Connection_Side which_side);
+
+ /**
* Optional callback: error logging. (not currently called)
* @param err An error message related to this connection.
*/
diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp
index c88b6a7db..5f84481ac 100644
--- a/src/lib/tls/tls_client.cpp
+++ b/src/lib/tls/tls_client.cpp
@@ -169,6 +169,7 @@ void Client::send_client_hello(Handshake_State& state_base,
new Client_Hello(state.handshake_io(),
state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_client_hello(),
session_info,
@@ -188,6 +189,7 @@ void Client::send_client_hello(Handshake_State& state_base,
state.handshake_io(),
state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_client_hello(),
client_settings,
@@ -294,6 +296,8 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
"Server replied with DTLS-SRTP alg we did not send");
}
+ callbacks().tls_examine_extensions(state.server_hello()->extensions(), SERVER);
+
state.set_version(state.server_hello()->version());
m_application_protocol = state.server_hello()->next_protocol();
diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp
index d521f6bf8..522cf4a4f 100644
--- a/src/lib/tls/tls_extensions.cpp
+++ b/src/lib/tls/tls_extensions.cpp
@@ -59,7 +59,8 @@ Extension* make_extension(TLS_Data_Reader& reader, uint16_t code, uint16_t size)
return new Session_Ticket(reader, size);
}
- return nullptr; // not known
+ return new Unknown_Extension(static_cast<Handshake_Extension_Type>(code),
+ reader, size);
}
}
@@ -82,10 +83,7 @@ void Extensions::deserialize(TLS_Data_Reader& reader)
extension_code,
extension_size);
- if(extn)
- this->add(extn);
- else // unknown/unhandled extension
- reader.discard_next(extension_size);
+ this->add(extn);
}
}
}
@@ -124,6 +122,15 @@ std::vector<uint8_t> Extensions::serialize() const
return buf;
}
+bool Extensions::remove_extension(Handshake_Extension_Type typ)
+ {
+ auto i = m_extensions.find(typ);
+ if(i == m_extensions.end())
+ return false;
+ m_extensions.erase(i);
+ return true;
+ }
+
std::set<Handshake_Extension_Type> Extensions::extension_types() const
{
std::set<Handshake_Extension_Type> offers;
@@ -132,6 +139,19 @@ std::set<Handshake_Extension_Type> Extensions::extension_types() const
return offers;
}
+Unknown_Extension::Unknown_Extension(Handshake_Extension_Type type,
+ TLS_Data_Reader& reader,
+ uint16_t extension_size) :
+ m_type(type),
+ m_value(reader.get_fixed<uint8_t>(extension_size))
+ {
+ }
+
+std::vector<uint8_t> Unknown_Extension::serialize() const
+ {
+ throw Invalid_State("Cannot encode an unknown TLS extension");
+ }
+
Server_Name_Indicator::Server_Name_Indicator(TLS_Data_Reader& reader,
uint16_t extension_size)
{
diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h
index 221d8b46f..5ba3c0b8e 100644
--- a/src/lib/tls/tls_extensions.h
+++ b/src/lib/tls/tls_extensions.h
@@ -432,6 +432,30 @@ class Certificate_Status_Request final : public Extension
};
/**
+* Unknown extensions are deserialized as this type
+*/
+class BOTAN_UNSTABLE_API Unknown_Extension final : public Extension
+ {
+ public:
+ Unknown_Extension(Handshake_Extension_Type type,
+ TLS_Data_Reader& reader,
+ uint16_t extension_size);
+
+ std::vector<uint8_t> serialize() const override; // always fails
+
+ const std::vector<uint8_t>& value() { return m_value; }
+
+ bool empty() const override { return false; }
+
+ Handshake_Extension_Type type() const override { return m_type; }
+
+ private:
+ Handshake_Extension_Type m_type;
+ std::vector<uint8_t> m_value;
+
+ };
+
+/**
* Represents a block of extensions in a hello message
*/
class BOTAN_UNSTABLE_API Extensions final
@@ -442,13 +466,7 @@ class BOTAN_UNSTABLE_API Extensions final
template<typename T>
T* get() const
{
- Handshake_Extension_Type type = T::static_type();
-
- auto i = m_extensions.find(type);
-
- if(i != m_extensions.end())
- return dynamic_cast<T*>(i->second.get());
- return nullptr;
+ return dynamic_cast<T*>(get(T::static_type()));
}
template<typename T>
@@ -462,10 +480,26 @@ class BOTAN_UNSTABLE_API Extensions final
m_extensions[extn->type()].reset(extn);
}
+ Extension* get(Handshake_Extension_Type type) const
+ {
+ auto i = m_extensions.find(type);
+
+ if(i != m_extensions.end())
+ return i->second.get();
+ return nullptr;
+ }
+
std::vector<uint8_t> serialize() const;
void deserialize(TLS_Data_Reader& reader);
+ /**
+ * Remvoe an extension from this extensions object, if it exists.
+ * Returns true if the extension existed (and thus is now removed),
+ * otherwise false (the extension wasn't set in the first place).
+ */
+ bool remove_extension(Handshake_Extension_Type typ);
+
Extensions() = default;
explicit Extensions(TLS_Data_Reader& reader) { deserialize(reader); }
diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h
index 35ec3c83c..75e65fa7f 100644
--- a/src/lib/tls/tls_messages.h
+++ b/src/lib/tls/tls_messages.h
@@ -38,9 +38,10 @@ namespace TLS {
class Session;
class Handshake_IO;
class Handshake_State;
+class Callbacks;
std::vector<uint8_t> make_hello_random(RandomNumberGenerator& rng,
- const Policy& policy);
+ const Policy& policy);
/**
* DTLS Hello Verify Request
@@ -145,9 +146,12 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message
std::set<Handshake_Extension_Type> extension_types() const
{ return m_extensions.extension_types(); }
+ const Extensions& extensions() const { return m_extensions; }
+
Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
@@ -156,6 +160,7 @@ class BOTAN_UNSTABLE_API Client_Hello final : public Handshake_Message
Client_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& resumed_session,
@@ -274,6 +279,8 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message
std::set<Handshake_Extension_Type> extension_types() const
{ return m_extensions.extension_types(); }
+ const Extensions& extensions() const { return m_extensions; }
+
bool prefers_compressed_ec_points() const
{
if(auto ecc_formats = m_extensions.get<Supported_Point_Formats>())
@@ -286,6 +293,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message
Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& secure_reneg_info,
const Client_Hello& client_hello,
@@ -295,6 +303,7 @@ class BOTAN_UNSTABLE_API Server_Hello final : public Handshake_Message
Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
+ Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& secure_reneg_info,
const Client_Hello& client_hello,
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp
index 2d2fb769b..38c5cf2ca 100644
--- a/src/lib/tls/tls_server.cpp
+++ b/src/lib/tls/tls_server.cpp
@@ -460,6 +460,8 @@ void Server::process_client_hello_msg(const Handshake_State* active_state,
pending_state.set_version(negotiated_version);
+ callbacks().tls_examine_extensions(pending_state.client_hello()->extensions(), CLIENT);
+
Session session_info;
const bool resuming =
pending_state.allow_session_resumption() &&
@@ -703,6 +705,7 @@ void Server::session_resume(Server_Handshake_State& pending_state,
pending_state.handshake_io(),
pending_state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_server_hello(),
*pending_state.client_hello(),
@@ -794,6 +797,7 @@ void Server::session_create(Server_Handshake_State& pending_state,
pending_state.handshake_io(),
pending_state.hash(),
policy(),
+ callbacks(),
rng(),
secure_renegotiation_data_for_server_hello(),
*pending_state.client_hello(),
diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h
index eb6e710e1..7c5d9668f 100644
--- a/src/lib/tls/tls_server.h
+++ b/src/lib/tls/tls_server.h
@@ -96,12 +96,20 @@ class BOTAN_PUBLIC_API(2,0) Server final : public Channel
/**
* Return the protocol notification set by the client (using the
- * NPN extension) for this connection, if any. This value is not
+ * ALPN extension) for this connection, if any. This value is not
* tied to the session and a later renegotiation of the same
* session can choose a new protocol.
*/
std::string next_protocol() const { return m_next_protocol; }
+ /**
+ * Return the protocol notification set by the client (using the
+ * ALPN extension) for this connection, if any. This value is not
+ * tied to the session and a later renegotiation of the same
+ * session can choose a new protocol.
+ */
+ std::string application_protocol() const { return m_next_protocol; }
+
private:
std::vector<X509_Certificate>
get_peer_cert_chain(const Handshake_State& state) const override;