aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/c_kex.cpp76
-rw-r--r--src/tls/s_kex.cpp53
-rw-r--r--src/tls/tls_messages.h3
3 files changed, 95 insertions, 37 deletions
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index 78c60c1cc..b821df1a9 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -7,8 +7,10 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
+#include <botan/internal/tls_extensions.h>
#include <botan/pubkey.h>
#include <botan/dh.h>
+#include <botan/ecdh.h>
#include <botan/rsa.h>
#include <botan/rng.h>
#include <botan/loadstor.h>
@@ -46,8 +48,6 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
const std::vector<X509_Certificate>& peer_certs,
RandomNumberGenerator& rng)
{
- include_length = true;
-
if(state->server_kex)
{
TLS_Data_Reader reader(state->server_kex->params());
@@ -77,10 +77,40 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
pre_master = strip_leading_zeros(
ka.derive_key(0, counterparty_key.public_value()).bits_of());
- key_material = priv_key.public_value();
+ append_tls_length_value(key_material, priv_key.public_value(), 2);
+ }
+ else if(state->suite.kex_algo() == "ECDH")
+ {
+ const byte curve_type = reader.get_byte();
+
+ if(curve_type != 3)
+ throw Decoding_Error("Server sent non-named ECC curve");
+
+ const u16bit curve_id = reader.get_u16bit();
+
+ const std::string name = Supported_Elliptic_Curves::curve_id_to_name(curve_id);
+
+ if(name == "")
+ throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id));
+
+ EC_Group group(name);
+
+ MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255);
+
+ ECDH_PublicKey counterparty_key(group, OS2ECP(ecdh_key, group.get_curve()));
+
+ ECDH_PrivateKey priv_key(rng, group);
+
+ PK_Key_Agreement ka(priv_key, "Raw");
+
+ pre_master = strip_leading_zeros(
+ ka.derive_key(0, counterparty_key.public_value()).bits_of());
+
+ append_tls_length_value(key_material, priv_key.public_value(), 1);
}
else
- throw Internal_Error("Server key exchange not a known key type");
+ throw Internal_Error("Server key exchange type " + state->suite.kex_algo() +
+ " not known");
}
else
{
@@ -101,10 +131,12 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
PK_Encryptor_EME encryptor(*rsa_pub, "PKCS1v15");
- key_material = encryptor.encrypt(pre_master, rng);
+ MemoryVector<byte> encrypted_key = encryptor.encrypt(pre_master, rng);
if(state->version == Protocol_Version::SSL_V3)
- include_length = false;
+ key_material = encrypted_key; // no length field
+ else
+ append_tls_length_value(key_material, encrypted_key, 2);
}
else
throw TLS_Exception(HANDSHAKE_FAILURE,
@@ -122,33 +154,19 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
const Ciphersuite& suite,
Protocol_Version using_version)
{
- include_length = true;
-
- if(using_version == Protocol_Version::SSL_V3 && (suite.kex_algo() == ""))
- include_length = false;
-
- if(include_length)
+ if(suite.kex_algo() == "" && using_version == Protocol_Version::SSL_V3)
+ key_material = contents;
+ else
{
TLS_Data_Reader reader(contents);
- key_material = reader.get_range<byte>(2, 0, 65535);
- }
- else
- key_material = contents;
- }
-/*
-* Serialize a Client Key Exchange message
-*/
-MemoryVector<byte> Client_Key_Exchange::serialize() const
- {
- if(include_length)
- {
- MemoryVector<byte> buf;
- append_tls_length_value(buf, key_material, 2);
- return buf;
+ if(suite.kex_algo() == "" || suite.kex_algo() == "DH")
+ key_material = reader.get_range<byte>(2, 0, 65535);
+ else if(suite.kex_algo() == "ECDH")
+ key_material = reader.get_range<byte>(1, 1, 255);
+ else
+ throw Internal_Error("Unknown client key exch type " + suite.kex_algo());
}
- else
- return key_material;
}
/*
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 4425285f0..a62fa537a 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -11,8 +11,12 @@
#include <botan/loadstor.h>
#include <botan/pubkey.h>
#include <botan/dh.h>
+#include <botan/ecdh.h>
+#include <botan/oids.h>
#include <memory>
+#include <stdio.h>
+
namespace Botan {
namespace TLS {
@@ -25,15 +29,31 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
RandomNumberGenerator& rng,
const Private_Key* private_key)
{
- if(const DH_PublicKey* dh_pub = dynamic_cast<const DH_PublicKey*>(state->kex_priv))
+ if(const DH_PublicKey* dh = dynamic_cast<const DH_PublicKey*>(state->kex_priv))
+ {
+ append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2);
+ append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_g()), 2);
+ append_tls_length_value(m_params, dh->public_value(), 2);
+ }
+ else if(const ECDH_PublicKey* ecdh = dynamic_cast<const ECDH_PublicKey*>(state->kex_priv))
{
- append_tls_length_value(m_params, BigInt::encode(dh_pub->get_domain().get_p()), 2);
- append_tls_length_value(m_params, BigInt::encode(dh_pub->get_domain().get_g()), 2);
- append_tls_length_value(m_params, dh_pub->public_value(), 2);
+ const std::string ecdh_domain_oid = ecdh->domain().get_oid();
+ const std::string domain = OIDS::lookup(OID(ecdh_domain_oid));
+
+ if(domain == "")
+ throw Internal_Error("Could not find name of ECDH domain " + ecdh_domain_oid);
+
+ const u16bit named_curve_id = Supported_Elliptic_Curves::name_to_curve_id(domain);
+
+ m_params.push_back(3); // named curve
+ m_params.push_back(get_byte(0, named_curve_id));
+ m_params.push_back(get_byte(1, named_curve_id));
+
+ append_tls_length_value(m_params, ecdh->public_value(), 1);
}
else
- throw Invalid_Argument("Unknown key type " + state->kex_priv->algo_name() +
- " for TLS key exchange");
+ throw Decoding_Error("Unsupported server key exchange type " +
+ state->kex_priv->algo_name());
std::pair<std::string, Signature_Format> format =
state->choose_sig_format(private_key, m_hash_algo, m_sig_algo, false);
@@ -95,6 +115,27 @@ Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf,
append_tls_length_value(m_params, BigInt::encode(v), 2);
}
}
+ else if(kex_algo == "ECDH")
+ {
+ const byte curve_type = reader.get_byte();
+
+ if(curve_type != 3)
+ throw Decoding_Error("Server sent non-named ECC curve");
+
+ const u16bit curve_id = reader.get_u16bit();
+
+ const std::string name = Supported_Elliptic_Curves::curve_id_to_name(curve_id);
+
+ MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255);
+
+ if(name == "")
+ throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id));
+
+ m_params.push_back(curve_type);
+ m_params.push_back(get_byte(0, curve_id));
+ m_params.push_back(get_byte(1, curve_id));
+ append_tls_length_value(m_params, ecdh_key, 1);
+ }
else
throw Decoding_Error("Unsupported server key exchange type " + kex_algo);
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 2b2e02baa..c3dbaaf42 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -220,10 +220,9 @@ class Client_Key_Exchange : public Handshake_Message
const Ciphersuite& suite,
Protocol_Version using_version);
private:
- MemoryVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const { return key_material; }
SecureVector<byte> key_material, pre_master;
- bool include_length;
};
/**