aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls/s_kex.cpp
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-01-25 12:49:29 +0000
committerlloyd <[email protected]>2012-01-25 12:49:29 +0000
commit50bcbb4d8f09189cc669bb482487858234da7f6e (patch)
tree0082a6a93e6929f7bdb671d7f46dc8c3918072ce /src/tls/s_kex.cpp
parent47ff984c0ae0f077b029d0921e7ce1b62fc8f72f (diff)
Move all key exchange mechanism code (eg DH/ECDH/SRP) out of the
server handshake flow and into the server and client key exchange message types. It already was hidden from the client handshake code.
Diffstat (limited to 'src/tls/s_kex.cpp')
-rw-r--r--src/tls/s_kex.cpp105
1 files changed, 72 insertions, 33 deletions
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index a62fa537a..5861d3494 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -8,10 +8,12 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/internal/assert.h>
#include <botan/loadstor.h>
#include <botan/pubkey.h>
#include <botan/dh.h>
#include <botan/ecdh.h>
+#include <botan/rsa.h>
#include <botan/oids.h>
#include <memory>
@@ -26,17 +28,39 @@ namespace TLS {
*/
Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
Handshake_State* state,
+ const Policy& policy,
RandomNumberGenerator& rng,
- const Private_Key* private_key)
+ const Private_Key* signing_key)
{
- if(const DH_PublicKey* dh = dynamic_cast<const DH_PublicKey*>(state->kex_priv))
+ const std::string kex_algo = state->suite.kex_algo();
+
+ if(kex_algo == "DH")
{
+ std::auto_ptr<DH_PrivateKey> dh(new DH_PrivateKey(rng, policy.dh_group()));
+
append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2);
append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_g()), 2);
append_tls_length_value(m_params, dh->public_value(), 2);
+ m_kex_key = dh.release();
}
- else if(const ECDH_PublicKey* ecdh = dynamic_cast<const ECDH_PublicKey*>(state->kex_priv))
+ else if(kex_algo == "ECDH")
{
+ const std::vector<std::string>& curves =
+ state->client_hello->supported_ecc_curves();
+
+ if(curves.empty())
+ throw Internal_Error("Client sent no ECC extension but we negotiated ECDH");
+
+ const std::string curve_name = policy.choose_curve(curves);
+
+ if(curve_name == "")
+ throw TLS_Exception(HANDSHAKE_FAILURE,
+ "Could not agree on an ECC curve with the client");
+
+ EC_Group ec_group(curve_name);
+
+ std::auto_ptr<ECDH_PrivateKey> ecdh(new ECDH_PrivateKey(rng, ec_group));
+
const std::string ecdh_domain_oid = ecdh->domain().get_oid();
const std::string domain = OIDS::lookup(OID(ecdh_domain_oid));
@@ -50,40 +74,28 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
m_params.push_back(get_byte(1, named_curve_id));
append_tls_length_value(m_params, ecdh->public_value(), 1);
+
+ m_kex_key = ecdh.release();
}
else
- throw Decoding_Error("Unsupported server key exchange type " +
- state->kex_priv->algo_name());
-
- std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(private_key, m_hash_algo, m_sig_algo, false);
+ throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo);
- PK_Signer signer(*private_key, format.first, format.second);
+ if(state->suite.sig_algo() != "")
+ {
+ BOTAN_ASSERT(signing_key, "No signing key set");
- signer.update(state->client_hello->random());
- signer.update(state->server_hello->random());
- signer.update(params());
- m_signature = signer.signature(rng);
+ std::pair<std::string, Signature_Format> format =
+ state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false);
- send(writer, state->hash);
- }
+ PK_Signer signer(*signing_key, format.first, format.second);
-/**
-* Serialize a Server Key Exchange message
-*/
-MemoryVector<byte> Server_Key_Exchange::serialize() const
- {
- MemoryVector<byte> buf = params();
-
- // This should be an explicit version check
- if(m_hash_algo != "" && m_sig_algo != "")
- {
- buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo));
- buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo));
+ signer.update(state->client_hello->random());
+ signer.update(state->server_hello->random());
+ signer.update(params());
+ m_signature = signer.signature(rng);
}
- append_tls_length_value(buf, m_signature, 2);
- return buf;
+ send(writer, state->hash);
}
/**
@@ -92,7 +104,8 @@ MemoryVector<byte> Server_Key_Exchange::serialize() const
Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
const std::string& kex_algo,
const std::string& sig_algo,
- Protocol_Version version)
+ Protocol_Version version) :
+ m_kex_key(0)
{
if(buf.size() < 6)
throw Decoding_Error("Server_Key_Exchange: Packet corrupted");
@@ -120,7 +133,7 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
const byte curve_type = reader.get_byte();
if(curve_type != 3)
- throw Decoding_Error("Server sent non-named ECC curve");
+ throw Decoding_Error("Server_Key_Exchange: Server sent non-named ECC curve");
const u16bit curve_id = reader.get_u16bit();
@@ -129,7 +142,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255);
if(name == "")
- throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id));
+ throw Decoding_Error("Server_Key_Exchange: Server sent unknown named curve " +
+ to_string(curve_id));
m_params.push_back(curve_type);
m_params.push_back(get_byte(0, curve_id));
@@ -137,7 +151,8 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
append_tls_length_value(m_params, ecdh_key, 1);
}
else
- throw Decoding_Error("Unsupported server key exchange type " + kex_algo);
+ throw Decoding_Error("Server_Key_Exchange: Unsupported server key exchange type " +
+ kex_algo);
if(sig_algo != "")
{
@@ -151,6 +166,25 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
}
}
+
+/**
+* Serialize a Server Key Exchange message
+*/
+MemoryVector<byte> Server_Key_Exchange::serialize() const
+ {
+ MemoryVector<byte> buf = params();
+
+ // This should be an explicit version check
+ if(m_hash_algo != "" && m_sig_algo != "")
+ {
+ buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo));
+ buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo));
+ }
+
+ append_tls_length_value(buf, m_signature, 2);
+ return buf;
+ }
+
/**
* Verify a Server Key Exchange message
*/
@@ -171,6 +205,11 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert,
return verifier.check_signature(m_signature);
}
+const Private_Key& Server_Key_Exchange::server_kex_key() const
+ {
+ BOTAN_ASSERT(m_kex_key, "Key is non-NULL");
+ return *m_kex_key;
+ }
}
}