From 50bcbb4d8f09189cc669bb482487858234da7f6e Mon Sep 17 00:00:00 2001 From: lloyd Date: Wed, 25 Jan 2012 12:49:29 +0000 Subject: Move all key exchange mechanism code (eg DH/ECDH/SRP) out of the server handshake flow and into the server and client key exchange message types. It already was hidden from the client handshake code. --- src/tls/c_kex.cpp | 61 +++++++++++++++-------- src/tls/s_kex.cpp | 105 +++++++++++++++++++++++++++------------- src/tls/tls_handshake_state.cpp | 4 +- src/tls/tls_handshake_state.h | 3 +- src/tls/tls_messages.h | 12 +++-- src/tls/tls_server.cpp | 53 ++++++-------------- 6 files changed, 143 insertions(+), 95 deletions(-) diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index 901d9e004..ea2e91972 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -108,8 +109,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, append_tls_length_value(key_material, priv_key.public_value(), 1); } else - throw Internal_Error("Server key exchange type " + state->suite.kex_algo() + - " not known"); + throw Internal_Error("Unknown key exchange type " + state->suite.kex_algo()); } else { @@ -169,18 +169,33 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion& contents, } /* -* Return the pre_master_secret +* Return the pre_master_secret (server side implementation) */ SecureVector Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, - const Private_Key* priv_key, - Protocol_Version client_version) + const Handshake_State* state) { - if(const RSA_PrivateKey* rsa = dynamic_cast(priv_key)) + const std::string kex_algo = state->suite.kex_algo(); + + if(kex_algo == "") { - PK_Decryptor_EME decryptor(*rsa, "PKCS1v15"); + BOTAN_ASSERT(state->server_certs && !state->server_certs->cert_chain().empty(), + "No server certificate to use for RSA"); + + const Private_Key* private_key = state->server_rsa_kex_key; + + if(!private_key) + throw Internal_Error("Expected RSA kex but no server kex key set"); + + if(!dynamic_cast(private_key)) + throw Internal_Error("Expected RSA key but got " + private_key->algo_name()); + + PK_Decryptor_EME decryptor(*private_key, "PKCS1v15"); - try { + Protocol_Version client_version = state->client_hello->version(); + + try + { pre_master = decryptor.decrypt(key_material); if(pre_master.size() != 48 || @@ -189,7 +204,7 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, { throw Decoding_Error("Client_Key_Exchange: Secret corrupted"); } - } + } catch(...) { pre_master = rng.random_vec(48); @@ -199,18 +214,26 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, return pre_master; } - - // DH or ECDH - if(const PK_Key_Agreement_Key* dh = dynamic_cast(priv_key)) + else if(kex_algo == "DH" || kex_algo == "ECDH") { - try { - PK_Key_Agreement ka(*dh, "Raw"); + const Private_Key& private_key = state->server_kex->server_kex_key(); + + const PK_Key_Agreement_Key* ka_key = + dynamic_cast(&private_key); + + if(!ka_key) + throw Internal_Error("Expected key agreement key type but got " + + private_key.algo_name()); + + try + { + PK_Key_Agreement ka(*ka_key, "Raw"); - if(dh->algo_name() == "DH") + if(ka_key->algo_name() == "DH") pre_master = strip_leading_zeros(ka.derive_key(0, key_material).bits_of()); else pre_master = ka.derive_key(0, key_material).bits_of(); - } + } catch(...) { /* @@ -219,13 +242,13 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, * on, allowing the protocol to fail later in the finished * checks. */ - pre_master = rng.random_vec(dh->public_value().size()); + pre_master = rng.random_vec(ka_key->public_value().size()); } return pre_master; } - - throw Invalid_Argument("Client_Key_Exchange: Unknown key type " + priv_key->algo_name()); + else + throw Internal_Error("Client_Key_Exchange: Unknown kex type " + kex_algo); } } diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index a62fa537a..5861d3494 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -8,10 +8,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -26,17 +28,39 @@ namespace TLS { */ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, Handshake_State* state, + const Policy& policy, RandomNumberGenerator& rng, - const Private_Key* private_key) + const Private_Key* signing_key) { - if(const DH_PublicKey* dh = dynamic_cast(state->kex_priv)) + const std::string kex_algo = state->suite.kex_algo(); + + if(kex_algo == "DH") { + std::auto_ptr dh(new DH_PrivateKey(rng, policy.dh_group())); + append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2); append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_g()), 2); append_tls_length_value(m_params, dh->public_value(), 2); + m_kex_key = dh.release(); } - else if(const ECDH_PublicKey* ecdh = dynamic_cast(state->kex_priv)) + else if(kex_algo == "ECDH") { + const std::vector& curves = + state->client_hello->supported_ecc_curves(); + + if(curves.empty()) + throw Internal_Error("Client sent no ECC extension but we negotiated ECDH"); + + const std::string curve_name = policy.choose_curve(curves); + + if(curve_name == "") + throw TLS_Exception(HANDSHAKE_FAILURE, + "Could not agree on an ECC curve with the client"); + + EC_Group ec_group(curve_name); + + std::auto_ptr ecdh(new ECDH_PrivateKey(rng, ec_group)); + const std::string ecdh_domain_oid = ecdh->domain().get_oid(); const std::string domain = OIDS::lookup(OID(ecdh_domain_oid)); @@ -50,40 +74,28 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, m_params.push_back(get_byte(1, named_curve_id)); append_tls_length_value(m_params, ecdh->public_value(), 1); + + m_kex_key = ecdh.release(); } else - throw Decoding_Error("Unsupported server key exchange type " + - state->kex_priv->algo_name()); - - std::pair format = - state->choose_sig_format(private_key, m_hash_algo, m_sig_algo, false); + throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo); - PK_Signer signer(*private_key, format.first, format.second); + if(state->suite.sig_algo() != "") + { + BOTAN_ASSERT(signing_key, "No signing key set"); - signer.update(state->client_hello->random()); - signer.update(state->server_hello->random()); - signer.update(params()); - m_signature = signer.signature(rng); + std::pair format = + state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false); - send(writer, state->hash); - } + PK_Signer signer(*signing_key, format.first, format.second); -/** -* Serialize a Server Key Exchange message -*/ -MemoryVector Server_Key_Exchange::serialize() const - { - MemoryVector buf = params(); - - // This should be an explicit version check - if(m_hash_algo != "" && m_sig_algo != "") - { - buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo)); - buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo)); + signer.update(state->client_hello->random()); + signer.update(state->server_hello->random()); + signer.update(params()); + m_signature = signer.signature(rng); } - append_tls_length_value(buf, m_signature, 2); - return buf; + send(writer, state->hash); } /** @@ -92,7 +104,8 @@ MemoryVector Server_Key_Exchange::serialize() const Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf, const std::string& kex_algo, const std::string& sig_algo, - Protocol_Version version) + Protocol_Version version) : + m_kex_key(0) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); @@ -120,7 +133,7 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf, const byte curve_type = reader.get_byte(); if(curve_type != 3) - throw Decoding_Error("Server sent non-named ECC curve"); + throw Decoding_Error("Server_Key_Exchange: Server sent non-named ECC curve"); const u16bit curve_id = reader.get_u16bit(); @@ -129,7 +142,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf, MemoryVector ecdh_key = reader.get_range(1, 1, 255); if(name == "") - throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id)); + throw Decoding_Error("Server_Key_Exchange: Server sent unknown named curve " + + to_string(curve_id)); m_params.push_back(curve_type); m_params.push_back(get_byte(0, curve_id)); @@ -137,7 +151,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf, append_tls_length_value(m_params, ecdh_key, 1); } else - throw Decoding_Error("Unsupported server key exchange type " + kex_algo); + throw Decoding_Error("Server_Key_Exchange: Unsupported server key exchange type " + + kex_algo); if(sig_algo != "") { @@ -151,6 +166,25 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf, } } + +/** +* Serialize a Server Key Exchange message +*/ +MemoryVector Server_Key_Exchange::serialize() const + { + MemoryVector buf = params(); + + // This should be an explicit version check + if(m_hash_algo != "" && m_sig_algo != "") + { + buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo)); + buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo)); + } + + append_tls_length_value(buf, m_signature, 2); + return buf; + } + /** * Verify a Server Key Exchange message */ @@ -171,6 +205,11 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert, return verifier.check_signature(m_signature); } +const Private_Key& Server_Key_Exchange::server_kex_key() const + { + BOTAN_ASSERT(m_kex_key, "Key is non-NULL"); + return *m_kex_key; + } } } diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index c98b147d9..b22039f5b 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -91,7 +91,7 @@ Handshake_State::Handshake_State() client_finished = 0; server_finished = 0; - kex_priv = 0; + server_rsa_kex_key = 0; version = Protocol_Version::SSL_V3; @@ -265,7 +265,7 @@ Handshake_State::~Handshake_State() delete client_finished; delete server_finished; - delete kex_priv; + delete server_rsa_kex_key; } } diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 7339033c4..93846da52 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -78,7 +78,8 @@ class Handshake_State class Finished* client_finished; class Finished* server_finished; - Private_Key* kex_priv; + // Used by the server only, in case of RSA key exchange + Private_Key* server_rsa_kex_key; Ciphersuite suite; Session_Keys keys; diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 7eb67f3b6..7d4905a0e 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -212,8 +212,7 @@ class Client_Key_Exchange : public Handshake_Message { return pre_master; } SecureVector pre_master_secret(RandomNumberGenerator& rng, - const Private_Key* key, - Protocol_Version version); + const Handshake_State* state); Client_Key_Exchange(Record_Writer& output, Handshake_State* state, @@ -369,18 +368,25 @@ class Server_Key_Exchange : public Handshake_Message bool verify(const X509_Certificate& cert, Handshake_State* state) const; + const Private_Key& server_kex_key() const; + Server_Key_Exchange(Record_Writer& writer, Handshake_State* state, + const Policy& policy, RandomNumberGenerator& rng, - const Private_Key* priv_key); + const Private_Key* signing_key = 0); Server_Key_Exchange(const MemoryRegion& buf, const std::string& kex_alg, const std::string& sig_alg, Protocol_Version version); + + ~Server_Key_Exchange() { delete m_kex_key; } private: MemoryVector serialize() const; + Private_Key* m_kex_key; + MemoryVector m_params; std::string m_sig_algo; // sig algo used to create signature diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 1b2e9b91e..1253a7327 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -10,8 +10,6 @@ #include #include #include -#include -#include #include namespace Botan { @@ -259,8 +257,6 @@ void Server::process_handshake_msg(Handshake_Type type, const std::string sig_algo = state->suite.sig_algo(); const std::string kex_algo = state->suite.kex_algo(); - std::auto_ptr private_key(0); - if(sig_algo != "") { BOTAN_ASSERT(!cert_chains[sig_algo].empty(), @@ -269,43 +265,27 @@ void Server::process_handshake_msg(Handshake_Type type, state->server_certs = new Certificate(writer, state->hash, cert_chains[sig_algo]); - - private_key.reset(creds.private_key_for(state->server_certs->cert_chain()[0], - "tls-server", - m_hostname)); } - if(kex_algo != "") + std::auto_ptr private_key(0); + + if(kex_algo == "" || sig_algo != "") { - if(kex_algo == "DH") - { - state->kex_priv = new DH_PrivateKey(rng, policy.dh_group()); - } - else if(kex_algo == "ECDH") - { - const std::vector& curves = - state->client_hello->supported_ecc_curves(); - - if(curves.empty()) - throw Internal_Error("Client sent no ECC extension but we negotiated ECDH"); - - const std::string curve_name = policy.choose_curve(curves); - - if(curve_name == "") // shouldn't happen - throw Internal_Error("Could not agree on an ECC curve with the client"); - - EC_Group ec_group(curve_name); - state->kex_priv = new ECDH_PrivateKey(rng, ec_group); - } - else - throw Internal_Error("Server: Unknown ciphersuite kex type " + - kex_algo); + private_key.reset( + creds.private_key_for(state->server_certs->cert_chain()[0], + "tls-server", + m_hostname)); + } - state->server_kex = - new Server_Key_Exchange(writer, state, rng, private_key.get()); + if(kex_algo == "") + { + state->server_rsa_kex_key = private_key.release(); } else - state->kex_priv = private_key.release(); + { + state->server_kex = + new Server_Key_Exchange(writer, state, policy, rng, private_key.get()); + } std::vector client_auth_CAs = creds.trusted_certificate_authorities("tls-server", m_hostname); @@ -355,8 +335,7 @@ void Server::process_handshake_msg(Handshake_Type type, state->version); SecureVector pre_master = - state->client_kex->pre_master_secret(rng, state->kex_priv, - state->client_hello->version()); + state->client_kex->pre_master_secret(rng, state); state->keys = Session_Keys(state, pre_master, false); } -- cgit v1.2.3