aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-01-25 12:49:29 +0000
committerlloyd <[email protected]>2012-01-25 12:49:29 +0000
commit50bcbb4d8f09189cc669bb482487858234da7f6e (patch)
tree0082a6a93e6929f7bdb671d7f46dc8c3918072ce
parent47ff984c0ae0f077b029d0921e7ce1b62fc8f72f (diff)
Move all key exchange mechanism code (eg DH/ECDH/SRP) out of the
server handshake flow and into the server and client key exchange message types. It already was hidden from the client handshake code.
-rw-r--r--src/tls/c_kex.cpp61
-rw-r--r--src/tls/s_kex.cpp105
-rw-r--r--src/tls/tls_handshake_state.cpp4
-rw-r--r--src/tls/tls_handshake_state.h3
-rw-r--r--src/tls/tls_messages.h12
-rw-r--r--src/tls/tls_server.cpp53
6 files changed, 143 insertions, 95 deletions
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index 901d9e004..ea2e91972 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -8,6 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/internal/assert.h>
#include <botan/pubkey.h>
#include <botan/dh.h>
#include <botan/ecdh.h>
@@ -108,8 +109,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
append_tls_length_value(key_material, priv_key.public_value(), 1);
}
else
- throw Internal_Error("Server key exchange type " + state->suite.kex_algo() +
- " not known");
+ throw Internal_Error("Unknown key exchange type " + state->suite.kex_algo());
}
else
{
@@ -169,18 +169,33 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
}
/*
-* Return the pre_master_secret
+* Return the pre_master_secret (server side implementation)
*/
SecureVector<byte>
Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
- const Private_Key* priv_key,
- Protocol_Version client_version)
+ const Handshake_State* state)
{
- if(const RSA_PrivateKey* rsa = dynamic_cast<const RSA_PrivateKey*>(priv_key))
+ const std::string kex_algo = state->suite.kex_algo();
+
+ if(kex_algo == "")
{
- PK_Decryptor_EME decryptor(*rsa, "PKCS1v15");
+ BOTAN_ASSERT(state->server_certs && !state->server_certs->cert_chain().empty(),
+ "No server certificate to use for RSA");
+
+ const Private_Key* private_key = state->server_rsa_kex_key;
+
+ if(!private_key)
+ throw Internal_Error("Expected RSA kex but no server kex key set");
+
+ if(!dynamic_cast<const RSA_PrivateKey*>(private_key))
+ throw Internal_Error("Expected RSA key but got " + private_key->algo_name());
+
+ PK_Decryptor_EME decryptor(*private_key, "PKCS1v15");
- try {
+ Protocol_Version client_version = state->client_hello->version();
+
+ try
+ {
pre_master = decryptor.decrypt(key_material);
if(pre_master.size() != 48 ||
@@ -189,7 +204,7 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
{
throw Decoding_Error("Client_Key_Exchange: Secret corrupted");
}
- }
+ }
catch(...)
{
pre_master = rng.random_vec(48);
@@ -199,18 +214,26 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
return pre_master;
}
-
- // DH or ECDH
- if(const PK_Key_Agreement_Key* dh = dynamic_cast<const PK_Key_Agreement_Key*>(priv_key))
+ else if(kex_algo == "DH" || kex_algo == "ECDH")
{
- try {
- PK_Key_Agreement ka(*dh, "Raw");
+ const Private_Key& private_key = state->server_kex->server_kex_key();
+
+ const PK_Key_Agreement_Key* ka_key =
+ dynamic_cast<const PK_Key_Agreement_Key*>(&private_key);
+
+ if(!ka_key)
+ throw Internal_Error("Expected key agreement key type but got " +
+ private_key.algo_name());
+
+ try
+ {
+ PK_Key_Agreement ka(*ka_key, "Raw");
- if(dh->algo_name() == "DH")
+ if(ka_key->algo_name() == "DH")
pre_master = strip_leading_zeros(ka.derive_key(0, key_material).bits_of());
else
pre_master = ka.derive_key(0, key_material).bits_of();
- }
+ }
catch(...)
{
/*
@@ -219,13 +242,13 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
* on, allowing the protocol to fail later in the finished
* checks.
*/
- pre_master = rng.random_vec(dh->public_value().size());
+ pre_master = rng.random_vec(ka_key->public_value().size());
}
return pre_master;
}
-
- throw Invalid_Argument("Client_Key_Exchange: Unknown key type " + priv_key->algo_name());
+ else
+ throw Internal_Error("Client_Key_Exchange: Unknown kex type " + kex_algo);
}
}
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index a62fa537a..5861d3494 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -8,10 +8,12 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/internal/assert.h>
#include <botan/loadstor.h>
#include <botan/pubkey.h>
#include <botan/dh.h>
#include <botan/ecdh.h>
+#include <botan/rsa.h>
#include <botan/oids.h>
#include <memory>
@@ -26,17 +28,39 @@ namespace TLS {
*/
Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
Handshake_State* state,
+ const Policy& policy,
RandomNumberGenerator& rng,
- const Private_Key* private_key)
+ const Private_Key* signing_key)
{
- if(const DH_PublicKey* dh = dynamic_cast<const DH_PublicKey*>(state->kex_priv))
+ const std::string kex_algo = state->suite.kex_algo();
+
+ if(kex_algo == "DH")
{
+ std::auto_ptr<DH_PrivateKey> dh(new DH_PrivateKey(rng, policy.dh_group()));
+
append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2);
append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_g()), 2);
append_tls_length_value(m_params, dh->public_value(), 2);
+ m_kex_key = dh.release();
}
- else if(const ECDH_PublicKey* ecdh = dynamic_cast<const ECDH_PublicKey*>(state->kex_priv))
+ else if(kex_algo == "ECDH")
{
+ const std::vector<std::string>& curves =
+ state->client_hello->supported_ecc_curves();
+
+ if(curves.empty())
+ throw Internal_Error("Client sent no ECC extension but we negotiated ECDH");
+
+ const std::string curve_name = policy.choose_curve(curves);
+
+ if(curve_name == "")
+ throw TLS_Exception(HANDSHAKE_FAILURE,
+ "Could not agree on an ECC curve with the client");
+
+ EC_Group ec_group(curve_name);
+
+ std::auto_ptr<ECDH_PrivateKey> ecdh(new ECDH_PrivateKey(rng, ec_group));
+
const std::string ecdh_domain_oid = ecdh->domain().get_oid();
const std::string domain = OIDS::lookup(OID(ecdh_domain_oid));
@@ -50,40 +74,28 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
m_params.push_back(get_byte(1, named_curve_id));
append_tls_length_value(m_params, ecdh->public_value(), 1);
+
+ m_kex_key = ecdh.release();
}
else
- throw Decoding_Error("Unsupported server key exchange type " +
- state->kex_priv->algo_name());
-
- std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(private_key, m_hash_algo, m_sig_algo, false);
+ throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo);
- PK_Signer signer(*private_key, format.first, format.second);
+ if(state->suite.sig_algo() != "")
+ {
+ BOTAN_ASSERT(signing_key, "No signing key set");
- signer.update(state->client_hello->random());
- signer.update(state->server_hello->random());
- signer.update(params());
- m_signature = signer.signature(rng);
+ std::pair<std::string, Signature_Format> format =
+ state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false);
- send(writer, state->hash);
- }
+ PK_Signer signer(*signing_key, format.first, format.second);
-/**
-* Serialize a Server Key Exchange message
-*/
-MemoryVector<byte> Server_Key_Exchange::serialize() const
- {
- MemoryVector<byte> buf = params();
-
- // This should be an explicit version check
- if(m_hash_algo != "" && m_sig_algo != "")
- {
- buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo));
- buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo));
+ signer.update(state->client_hello->random());
+ signer.update(state->server_hello->random());
+ signer.update(params());
+ m_signature = signer.signature(rng);
}
- append_tls_length_value(buf, m_signature, 2);
- return buf;
+ send(writer, state->hash);
}
/**
@@ -92,7 +104,8 @@ MemoryVector<byte> Server_Key_Exchange::serialize() const
Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
const std::string& kex_algo,
const std::string& sig_algo,
- Protocol_Version version)
+ Protocol_Version version) :
+ m_kex_key(0)
{
if(buf.size() < 6)
throw Decoding_Error("Server_Key_Exchange: Packet corrupted");
@@ -120,7 +133,7 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
const byte curve_type = reader.get_byte();
if(curve_type != 3)
- throw Decoding_Error("Server sent non-named ECC curve");
+ throw Decoding_Error("Server_Key_Exchange: Server sent non-named ECC curve");
const u16bit curve_id = reader.get_u16bit();
@@ -129,7 +142,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255);
if(name == "")
- throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id));
+ throw Decoding_Error("Server_Key_Exchange: Server sent unknown named curve " +
+ to_string(curve_id));
m_params.push_back(curve_type);
m_params.push_back(get_byte(0, curve_id));
@@ -137,7 +151,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
append_tls_length_value(m_params, ecdh_key, 1);
}
else
- throw Decoding_Error("Unsupported server key exchange type " + kex_algo);
+ throw Decoding_Error("Server_Key_Exchange: Unsupported server key exchange type " +
+ kex_algo);
if(sig_algo != "")
{
@@ -151,6 +166,25 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
}
}
+
+/**
+* Serialize a Server Key Exchange message
+*/
+MemoryVector<byte> Server_Key_Exchange::serialize() const
+ {
+ MemoryVector<byte> buf = params();
+
+ // This should be an explicit version check
+ if(m_hash_algo != "" && m_sig_algo != "")
+ {
+ buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo));
+ buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo));
+ }
+
+ append_tls_length_value(buf, m_signature, 2);
+ return buf;
+ }
+
/**
* Verify a Server Key Exchange message
*/
@@ -171,6 +205,11 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert,
return verifier.check_signature(m_signature);
}
+const Private_Key& Server_Key_Exchange::server_kex_key() const
+ {
+ BOTAN_ASSERT(m_kex_key, "Key is non-NULL");
+ return *m_kex_key;
+ }
}
}
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index c98b147d9..b22039f5b 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -91,7 +91,7 @@ Handshake_State::Handshake_State()
client_finished = 0;
server_finished = 0;
- kex_priv = 0;
+ server_rsa_kex_key = 0;
version = Protocol_Version::SSL_V3;
@@ -265,7 +265,7 @@ Handshake_State::~Handshake_State()
delete client_finished;
delete server_finished;
- delete kex_priv;
+ delete server_rsa_kex_key;
}
}
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 7339033c4..93846da52 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -78,7 +78,8 @@ class Handshake_State
class Finished* client_finished;
class Finished* server_finished;
- Private_Key* kex_priv;
+ // Used by the server only, in case of RSA key exchange
+ Private_Key* server_rsa_kex_key;
Ciphersuite suite;
Session_Keys keys;
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 7eb67f3b6..7d4905a0e 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -212,8 +212,7 @@ class Client_Key_Exchange : public Handshake_Message
{ return pre_master; }
SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng,
- const Private_Key* key,
- Protocol_Version version);
+ const Handshake_State* state);
Client_Key_Exchange(Record_Writer& output,
Handshake_State* state,
@@ -369,18 +368,25 @@ class Server_Key_Exchange : public Handshake_Message
bool verify(const X509_Certificate& cert,
Handshake_State* state) const;
+ const Private_Key& server_kex_key() const;
+
Server_Key_Exchange(Record_Writer& writer,
Handshake_State* state,
+ const Policy& policy,
RandomNumberGenerator& rng,
- const Private_Key* priv_key);
+ const Private_Key* signing_key = 0);
Server_Key_Exchange(const MemoryRegion<byte>& buf,
const std::string& kex_alg,
const std::string& sig_alg,
Protocol_Version version);
+
+ ~Server_Key_Exchange() { delete m_kex_key; }
private:
MemoryVector<byte> serialize() const;
+ Private_Key* m_kex_key;
+
MemoryVector<byte> m_params;
std::string m_sig_algo; // sig algo used to create signature
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 1b2e9b91e..1253a7327 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -10,8 +10,6 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/stl_util.h>
#include <botan/internal/assert.h>
-#include <botan/dh.h>
-#include <botan/ecdh.h>
#include <memory>
namespace Botan {
@@ -259,8 +257,6 @@ void Server::process_handshake_msg(Handshake_Type type,
const std::string sig_algo = state->suite.sig_algo();
const std::string kex_algo = state->suite.kex_algo();
- std::auto_ptr<Private_Key> private_key(0);
-
if(sig_algo != "")
{
BOTAN_ASSERT(!cert_chains[sig_algo].empty(),
@@ -269,43 +265,27 @@ void Server::process_handshake_msg(Handshake_Type type,
state->server_certs = new Certificate(writer,
state->hash,
cert_chains[sig_algo]);
-
- private_key.reset(creds.private_key_for(state->server_certs->cert_chain()[0],
- "tls-server",
- m_hostname));
}
- if(kex_algo != "")
+ std::auto_ptr<Private_Key> private_key(0);
+
+ if(kex_algo == "" || sig_algo != "")
{
- if(kex_algo == "DH")
- {
- state->kex_priv = new DH_PrivateKey(rng, policy.dh_group());
- }
- else if(kex_algo == "ECDH")
- {
- const std::vector<std::string>& curves =
- state->client_hello->supported_ecc_curves();
-
- if(curves.empty())
- throw Internal_Error("Client sent no ECC extension but we negotiated ECDH");
-
- const std::string curve_name = policy.choose_curve(curves);
-
- if(curve_name == "") // shouldn't happen
- throw Internal_Error("Could not agree on an ECC curve with the client");
-
- EC_Group ec_group(curve_name);
- state->kex_priv = new ECDH_PrivateKey(rng, ec_group);
- }
- else
- throw Internal_Error("Server: Unknown ciphersuite kex type " +
- kex_algo);
+ private_key.reset(
+ creds.private_key_for(state->server_certs->cert_chain()[0],
+ "tls-server",
+ m_hostname));
+ }
- state->server_kex =
- new Server_Key_Exchange(writer, state, rng, private_key.get());
+ if(kex_algo == "")
+ {
+ state->server_rsa_kex_key = private_key.release();
}
else
- state->kex_priv = private_key.release();
+ {
+ state->server_kex =
+ new Server_Key_Exchange(writer, state, policy, rng, private_key.get());
+ }
std::vector<X509_Certificate> client_auth_CAs =
creds.trusted_certificate_authorities("tls-server", m_hostname);
@@ -355,8 +335,7 @@ void Server::process_handshake_msg(Handshake_Type type,
state->version);
SecureVector<byte> pre_master =
- state->client_kex->pre_master_secret(rng, state->kex_priv,
- state->client_hello->version());
+ state->client_kex->pre_master_secret(rng, state);
state->keys = Session_Keys(state, pre_master, false);
}