diff options
-rw-r--r-- | src/lib/tls/msg_client_kex.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/msg_server_kex.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.cpp | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_callbacks.h | 11 |
4 files changed, 20 insertions, 3 deletions
diff --git a/src/lib/tls/msg_client_kex.cpp b/src/lib/tls/msg_client_kex.cpp index 2d0c2d019..b3dff072e 100644 --- a/src/lib/tls/msg_client_kex.cpp +++ b/src/lib/tls/msg_client_kex.cpp @@ -124,7 +124,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, "Server sent ECC curve prohibited by policy"); } - const std::string curve_name = group_param_to_string(curve_id); + const std::string curve_name = state.callbacks().tls_decode_group_param(curve_id); if(curve_name == "") throw Decoding_Error("Server sent unknown named curve " + diff --git a/src/lib/tls/msg_server_kex.cpp b/src/lib/tls/msg_server_kex.cpp index 4cb204c68..1b42cba99 100644 --- a/src/lib/tls/msg_server_kex.cpp +++ b/src/lib/tls/msg_server_kex.cpp @@ -81,7 +81,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, BOTAN_ASSERT(group_param_is_dh(shared_group), "DH groups for the DH ciphersuites god"); - const std::string group_name = group_param_to_string(shared_group); + const std::string group_name = state.callbacks().tls_decode_group_param(shared_group); std::unique_ptr<DH_PrivateKey> dh(new DH_PrivateKey(rng, DL_Group(group_name))); append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2); @@ -117,7 +117,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, { Group_Params curve = policy.choose_key_exchange_group(ec_groups); - const std::string curve_name = group_param_to_string(curve); + const std::string curve_name = state.callbacks().tls_decode_group_param(curve); EC_Group ec_group(curve_name); std::unique_ptr<ECDH_PrivateKey> ecdh(new ECDH_PrivateKey(rng, ec_group)); diff --git a/src/lib/tls/tls_callbacks.cpp b/src/lib/tls/tls_callbacks.cpp index b3b1b79bb..6919c36ca 100644 --- a/src/lib/tls/tls_callbacks.cpp +++ b/src/lib/tls/tls_callbacks.cpp @@ -8,6 +8,7 @@ #include <botan/tls_callbacks.h> #include <botan/tls_policy.h> +#include <botan/tls_algos.h> #include <botan/x509path.h> #include <botan/ocsp.h> #include <botan/dh.h> @@ -40,6 +41,11 @@ void TLS::Callbacks::tls_examine_extensions(const Extensions&, Connection_Side) { } +std::string TLS::Callbacks::tls_decode_group_param(Group_Params group_param) + { + return group_param_to_string(group_param); + } + void TLS::Callbacks::tls_verify_cert_chain( const std::vector<X509_Certificate>& cert_chain, const std::vector<std::shared_ptr<const OCSP::Response>>& ocsp_responses, diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h index dd6ad2d4b..88e502e89 100644 --- a/src/lib/tls/tls_callbacks.h +++ b/src/lib/tls/tls_callbacks.h @@ -284,6 +284,17 @@ class BOTAN_PUBLIC_API(2,0) Callbacks virtual void tls_examine_extensions(const Extensions& extn, Connection_Side which_side); /** + * Optional callback: decode TLS group ID + * + * TLS uses a 16-bit field to identify ECC and DH groups. This callback + * handles the decoding. You only need to implement this if you are using + * a custom ECC or DH group (this is extremely uncommon). + * + * Default implementation uses the standard (IETF-defined) mappings. + */ + virtual std::string tls_decode_group_param(Group_Params group_param); + + /** * Optional callback: error logging. (not currently called) * @param err An error message related to this connection. */ |