aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/tls/c_hello.cpp68
-rw-r--r--src/tls/cert_req.cpp54
-rw-r--r--src/tls/cert_ver.cpp47
-rw-r--r--src/tls/s_hello.cpp34
-rw-r--r--src/tls/s_kex.cpp20
-rw-r--r--src/tls/tls_client.cpp5
-rw-r--r--src/tls/tls_extensions.cpp19
-rw-r--r--src/tls/tls_extensions.h52
-rw-r--r--src/tls/tls_handshake_state.cpp69
-rw-r--r--src/tls/tls_handshake_state.h7
-rw-r--r--src/tls/tls_messages.h14
-rw-r--r--src/tls/tls_policy.h2
-rw-r--r--src/tls/tls_server.cpp10
13 files changed, 251 insertions, 150 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index a70713a80..71c0c3de9 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -147,20 +147,20 @@ MemoryVector<byte> Client_Hello::serialize() const
// Initial handshake
if(m_renegotiation_info.empty())
{
- extensions.push_back(new Renegotation_Extension(m_renegotiation_info));
- extensions.push_back(new Server_Name_Indicator(m_hostname));
- extensions.push_back(new SRP_Identifier(m_srp_identifier));
+ extensions.add(new Renegotation_Extension(m_renegotiation_info));
+ extensions.add(new Server_Name_Indicator(m_hostname));
+ extensions.add(new SRP_Identifier(m_srp_identifier));
if(m_version >= TLS_V12)
- extensions.push_back(new Signature_Algorithms());
+ extensions.add(new Signature_Algorithms());
if(m_next_protocol)
- extensions.push_back(new Next_Protocol_Notification());
+ extensions.add(new Next_Protocol_Notification());
}
else
{
// renegotiation
- extensions.push_back(new Renegotation_Extension(m_renegotiation_info));
+ extensions.add(new Renegotation_Extension(m_renegotiation_info));
}
buf += extensions.serialize();
@@ -237,35 +237,39 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf)
TLS_Extensions extensions(reader);
- for(size_t i = 0; i != extensions.count(); ++i)
+ if(Server_Name_Indicator* sni = extensions.get<Server_Name_Indicator>())
{
- TLS_Extension* extn = extensions.at(i);
+ m_hostname = sni->host_name();
+ }
- if(Server_Name_Indicator* sni = dynamic_cast<Server_Name_Indicator*>(extn))
- {
- m_hostname = sni->host_name();
- }
- else if(SRP_Identifier* srp = dynamic_cast<SRP_Identifier*>(extn))
- {
- m_srp_identifier = srp->identifier();
- }
- else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn))
- {
- if(!npn->protocols().empty())
- throw Decoding_Error("Client sent non-empty NPN extension");
+ if(SRP_Identifier* srp = extensions.get<SRP_Identifier>())
+ {
+ m_srp_identifier = srp->identifier();
+ }
- m_next_protocol = true;
- }
- else if(Maximum_Fragment_Length* frag = dynamic_cast<Maximum_Fragment_Length*>(extn))
- {
- m_fragment_size = frag->fragment_size();
- }
- else if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn))
- {
- // checked by TLS_Client / TLS_Server as they know the handshake state
- m_secure_renegotiation = true;
- m_renegotiation_info = reneg->renegotiation_info();
- }
+ if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>())
+ {
+ if(!npn->protocols().empty())
+ throw Decoding_Error("Client sent non-empty NPN extension");
+
+ m_next_protocol = true;
+ }
+
+ if(Maximum_Fragment_Length* frag = extensions.get<Maximum_Fragment_Length>())
+ {
+ m_fragment_size = frag->fragment_size();
+ }
+
+ if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>())
+ {
+ // checked by TLS_Client / TLS_Server as they know the handshake state
+ m_secure_renegotiation = true;
+ m_renegotiation_info = reneg->renegotiation_info();
+ }
+
+ if(Signature_Algorithms* sigs = extensions.get<Signature_Algorithms>())
+ {
+ // save in handshake state
}
if(value_exists(m_suites, static_cast<u16bit>(TLS_EMPTY_RENEGOTIATION_INFO_SCSV)))
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index bdb25057c..c3e46a5ae 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -7,11 +7,14 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
+#include <botan/internal/tls_extensions.h>
#include <botan/der_enc.h>
#include <botan/ber_dec.h>
#include <botan/loadstor.h>
#include <botan/secqueue.h>
+#include <stdio.h>
+
namespace Botan {
/**
@@ -20,18 +23,16 @@ namespace Botan {
Certificate_Req::Certificate_Req(Record_Writer& writer,
TLS_Handshake_Hash& hash,
const std::vector<X509_Certificate>& ca_certs,
- const std::vector<Certificate_Type>& cert_types)
+ Version_Code version)
{
for(size_t i = 0; i != ca_certs.size(); ++i)
names.push_back(ca_certs[i].subject_dn());
- if(cert_types.empty()) // default is RSA/DSA is OK
- {
- types.push_back(RSA_CERT);
- types.push_back(DSS_CERT);
- }
- else
- types = cert_types;
+ cert_types.push_back(RSA_CERT);
+ cert_types.push_back(DSS_CERT);
+
+ if(version >= TLS_V12)
+ sig_and_hash_algos = Signature_Algorithms().serialize();
send(writer, hash);
}
@@ -39,39 +40,36 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
/**
* Deserialize a Certificate Request message
*/
-Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf)
+Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf,
+ Version_Code version)
{
if(buf.size() < 4)
throw Decoding_Error("Certificate_Req: Bad certificate request");
- const size_t types_size = buf[0];
+ TLS_Data_Reader reader(buf);
- if(buf.size() < types_size + 3)
- throw Decoding_Error("Certificate_Req: Bad certificate request");
+ cert_types = reader.get_range_vector<byte>(1, 1, 255);
- for(size_t i = 0; i != types_size; ++i)
- types.push_back(static_cast<Certificate_Type>(buf[i+1]));
+ if(version >= TLS_V12)
+ {
+ std::vector<u16bit> sig_hash_algs = reader.get_range_vector<u16bit>(2, 2, 65534);
- const size_t names_size = make_u16bit(buf[types_size+1], buf[types_size+2]);
+ // FIXME, do something with this
+ }
- if(buf.size() != names_size + types_size + 3)
- throw Decoding_Error("Certificate_Req: Bad certificate request");
+ u16bit purported_size = reader.get_u16bit();
- size_t offset = types_size + 3;
+ if(reader.remaining_bytes() != purported_size)
+ throw Decoding_Error("Inconsistent length in certificate request");
- while(offset < buf.size())
+ while(reader.has_remaining())
{
- const size_t name_size = make_u16bit(buf[offset], buf[offset+1]);
-
- if(offset + 2 + name_size > buf.size())
- throw Decoding_Error("Certificate_Req: Bad certificate request");
+ std::vector<byte> name_bits = reader.get_range_vector<byte>(2, 0, 65535);
- BER_Decoder decoder(&buf[offset + 2], name_size);
+ BER_Decoder decoder(&name_bits[0], name_bits.size());
X509_DN name;
decoder.decode(name);
names.push_back(name);
-
- offset += (2 + name_size);
}
}
@@ -82,7 +80,9 @@ MemoryVector<byte> Certificate_Req::serialize() const
{
MemoryVector<byte> buf;
- append_tls_length_value(buf, types, 1);
+ append_tls_length_value(buf, cert_types, 1);
+
+ buf += sig_and_hash_algos;
for(size_t i = 0; i != names.size(); ++i)
{
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index f35202734..f7386dd13 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -7,6 +7,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
+#include <botan/internal/tls_extensions.h>
#include <botan/internal/assert.h>
#include <botan/tls_exceptn.h>
#include <botan/pubkey.h>
@@ -27,14 +28,8 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
{
BOTAN_ASSERT_NONNULL(priv_key);
- // FIXME: this should respect server's hash preferences
- if(state->version >= TLS_V12)
- hash_algo = TLS_ALGO_HASH_SHA256;
- else
- hash_algo = TLS_ALGO_NONE;
-
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(priv_key, hash_algo, true);
+ state->choose_sig_format(priv_key, hash_algo, sig_algo, true);
PK_Signer signer(*priv_key, format.first, format.second);
@@ -48,13 +43,10 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
else
signature = signer.sign_message(md5_sha, rng);
}
- else if(state->version == TLS_V10 || state->version == TLS_V11)
+ else
{
signature = signer.sign_message(state->hash.get_contents(), rng);
}
- else
- throw TLS_Exception(PROTOCOL_VERSION,
- "Unknown TLS version in certificate verification");
send(writer, state->hash);
}
@@ -62,9 +54,23 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
/*
* Deserialize a Certificate Verify message
*/
-Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf)
+Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf,
+ Version_Code version)
{
TLS_Data_Reader reader(buf);
+
+ if(version < TLS_V12)
+ {
+ // use old defaults
+ hash_algo = TLS_ALGO_NONE;
+ sig_algo = TLS_ALGO_NONE;
+ }
+ else
+ {
+ hash_algo = Signature_Algorithms::hash_algo_code(reader.get_byte());
+ sig_algo = Signature_Algorithms::sig_algo_code(reader.get_byte());
+ }
+
signature = reader.get_range<byte>(2, 0, 65535);
}
@@ -75,6 +81,12 @@ MemoryVector<byte> Certificate_Verify::serialize() const
{
MemoryVector<byte> buf;
+ if(hash_algo != TLS_ALGO_NONE)
+ {
+ buf.push_back(Signature_Algorithms::hash_algo_code(hash_algo));
+ buf.push_back(Signature_Algorithms::sig_algo_code(sig_algo));
+ }
+
const u16bit sig_len = signature.size();
buf.push_back(get_byte(0, sig_len));
buf.push_back(get_byte(1, sig_len));
@@ -92,7 +104,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert,
std::auto_ptr<Public_Key> key(cert.subject_public_key());
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(key.get(), hash_algo, true);
+ state->choose_sig_format(key.get(), hash_algo, sig_algo, true);
PK_Verifier verifier(*key, format.first, format.second);
@@ -104,13 +116,8 @@ bool Certificate_Verify::verify(const X509_Certificate& cert,
return verifier.verify_message(&md5_sha[16], md5_sha.size()-16,
&signature[0], signature.size());
}
- else if(state->version == TLS_V10 || state->version == TLS_V11)
- {
- return verifier.verify_message(state->hash.get_contents(), signature);
- }
- else
- throw TLS_Exception(PROTOCOL_VERSION,
- "Unknown TLS version in certificate verification");
+
+ return verifier.verify_message(state->hash.get_contents(), signature);
}
}
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index 652544806..e6aff94e3 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -123,25 +123,17 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf)
TLS_Extensions extensions(reader);
- for(size_t i = 0; i != extensions.count(); ++i)
+ if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>())
{
- TLS_Extension* extn = extensions.at(i);
-
- if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn))
- {
- // checked by TLS_Client / TLS_Server as they know the handshake state
- m_secure_renegotiation = true;
- m_renegotiation_info = reneg->renegotiation_info();
- }
- else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn))
- {
- m_next_protocols = npn->protocols();
- m_next_protocol = true;
- }
- else if(Signature_Algorithms* sigs = dynamic_cast<Signature_Algorithms*>(extn))
- {
- // save in handshake state
- }
+ // checked by TLS_Client / TLS_Server as they know the handshake state
+ m_secure_renegotiation = true;
+ m_renegotiation_info = reneg->renegotiation_info();
+ }
+
+ if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>())
+ {
+ m_next_protocols = npn->protocols();
+ m_next_protocol = true;
}
}
@@ -166,13 +158,13 @@ MemoryVector<byte> Server_Hello::serialize() const
TLS_Extensions extensions;
if(m_secure_renegotiation)
- extensions.push_back(new Renegotation_Extension(m_renegotiation_info));
+ extensions.add(new Renegotation_Extension(m_renegotiation_info));
if(m_fragment_size != 0)
- extensions.push_back(new Maximum_Fragment_Length(m_fragment_size));
+ extensions.add(new Maximum_Fragment_Length(m_fragment_size));
if(m_next_protocol)
- extensions.push_back(new Next_Protocol_Notification(m_next_protocols));
+ extensions.add(new Next_Protocol_Notification(m_next_protocols));
buf += extensions.serialize();
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index ac6ee15ee..6b87e6ac6 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -1,6 +1,6 @@
/*
* Server Key Exchange Message
-* (C) 2004-2010 Jack Lloyd
+* (C) 2004-2010,2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -35,20 +35,8 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
throw Invalid_Argument("Unknown key type " + state->kex_priv->algo_name() +
" for TLS key exchange");
- // FIXME: this should respect client's hash preferences
- if(state->version >= TLS_V12)
- {
- hash_algo = TLS_ALGO_HASH_SHA256;
- sig_algo = TLS_ALGO_SIGNER_RSA;
- }
- else
- {
- hash_algo = TLS_ALGO_NONE;
- sig_algo = TLS_ALGO_NONE;
- }
-
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(private_key, hash_algo, false);
+ state->choose_sig_format(private_key, hash_algo, sig_algo, false);
PK_Signer signer(*private_key, format.first, format.second);
@@ -153,10 +141,8 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert,
{
std::auto_ptr<Public_Key> key(cert.subject_public_key());
- printf("Checking %s vs code %d\n", key->algo_name().c_str(), sig_algo);
-
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(key.get(), hash_algo, false);
+ state->choose_sig_format(key.get(), hash_algo, sig_algo, false);
PK_Verifier verifier(*key, format.first, format.second);
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index c8fcd8144..ed7de501f 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -306,7 +306,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
else if(type == CERTIFICATE_REQUEST)
{
state->set_expected_next(SERVER_HELLO_DONE);
- state->cert_req = new Certificate_Req(contents);
+ state->cert_req = new Certificate_Req(contents, state->version);
}
else if(type == SERVER_HELLO_DONE)
{
@@ -316,8 +316,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
if(state->received_handshake_msg(CERTIFICATE_REQUEST))
{
- std::vector<Certificate_Type> types =
- state->cert_req->acceptable_types();
+ std::vector<byte> types = state->cert_req->acceptable_types();
std::vector<X509_Certificate> client_certs =
creds.cert_chain("", // use types here
diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp
index 9f80744f9..21c3b67fc 100644
--- a/src/tls/tls_extensions.cpp
+++ b/src/tls/tls_extensions.cpp
@@ -54,7 +54,7 @@ TLS_Extensions::TLS_Extensions(TLS_Data_Reader& reader)
extension_size);
if(extn)
- extensions.push_back(extn);
+ this->add(extn);
else // unknown/unhandled extension
reader.discard_next(extension_size);
}
@@ -65,14 +65,15 @@ MemoryVector<byte> TLS_Extensions::serialize() const
{
MemoryVector<byte> buf(2); // 2 bytes for length field
- for(size_t i = 0; i != extensions.size(); ++i)
+ for(std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i = extensions.begin();
+ i != extensions.end(); ++i)
{
- if(extensions[i]->empty())
+ if(i->second->empty())
continue;
- const u16bit extn_code = extensions[i]->type();
+ const u16bit extn_code = i->second->type();
- MemoryVector<byte> extn_val = extensions[i]->serialize();
+ MemoryVector<byte> extn_val = i->second->serialize();
buf.push_back(get_byte(0, extn_code));
buf.push_back(get_byte(1, extn_code));
@@ -97,8 +98,12 @@ MemoryVector<byte> TLS_Extensions::serialize() const
TLS_Extensions::~TLS_Extensions()
{
- for(size_t i = 0; i != extensions.size(); ++i)
- delete extensions[i];
+ for(std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i = extensions.begin();
+ i != extensions.end(); ++i)
+ {
+ delete i->second;
+ }
+
extensions.clear();
}
diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h
index 2f4f711c2..a90cb4f2b 100644
--- a/src/tls/tls_extensions.h
+++ b/src/tls/tls_extensions.h
@@ -12,6 +12,7 @@
#include <botan/tls_magic.h>
#include <vector>
#include <string>
+#include <map>
namespace Botan {
@@ -24,6 +25,7 @@ class TLS_Extension
{
public:
virtual TLS_Handshake_Extension_Type type() const = 0;
+
virtual MemoryVector<byte> serialize() const = 0;
virtual bool empty() const = 0;
@@ -37,9 +39,11 @@ class TLS_Extension
class Server_Name_Indicator : public TLS_Extension
{
public:
- TLS_Handshake_Extension_Type type() const
+ static TLS_Handshake_Extension_Type static_type()
{ return TLSEXT_SERVER_NAME_INDICATION; }
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
Server_Name_Indicator(const std::string& host_name) :
sni_host_name(host_name) {}
@@ -61,9 +65,11 @@ class Server_Name_Indicator : public TLS_Extension
class SRP_Identifier : public TLS_Extension
{
public:
- TLS_Handshake_Extension_Type type() const
+ static TLS_Handshake_Extension_Type static_type()
{ return TLSEXT_SRP_IDENTIFIER; }
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
SRP_Identifier(const std::string& identifier) :
srp_identifier(identifier) {}
@@ -85,9 +91,11 @@ class SRP_Identifier : public TLS_Extension
class Renegotation_Extension : public TLS_Extension
{
public:
- TLS_Handshake_Extension_Type type() const
+ static TLS_Handshake_Extension_Type static_type()
{ return TLSEXT_SAFE_RENEGOTIATION; }
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
Renegotation_Extension() {}
Renegotation_Extension(const MemoryRegion<byte>& bits) :
@@ -112,9 +120,11 @@ class Renegotation_Extension : public TLS_Extension
class Maximum_Fragment_Length : public TLS_Extension
{
public:
- TLS_Handshake_Extension_Type type() const
+ static TLS_Handshake_Extension_Type static_type()
{ return TLSEXT_MAX_FRAGMENT_LENGTH; }
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
bool empty() const { return val != 0; }
size_t fragment_size() const;
@@ -149,9 +159,11 @@ class Maximum_Fragment_Length : public TLS_Extension
class Next_Protocol_Notification : public TLS_Extension
{
public:
- TLS_Handshake_Extension_Type type() const
+ static TLS_Handshake_Extension_Type static_type()
{ return TLSEXT_NEXT_PROTOCOL; }
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
const std::vector<std::string>& protocols() const
{ return m_protocols; }
@@ -182,15 +194,17 @@ class Next_Protocol_Notification : public TLS_Extension
class Signature_Algorithms : public TLS_Extension
{
public:
+ static TLS_Handshake_Extension_Type static_type()
+ { return TLSEXT_SIGNATURE_ALGORITHMS; }
+
+ TLS_Handshake_Extension_Type type() const { return static_type(); }
+
static TLS_Ciphersuite_Algos hash_algo_code(byte code);
static byte hash_algo_code(TLS_Ciphersuite_Algos code);
static TLS_Ciphersuite_Algos sig_algo_code(byte code);
static byte sig_algo_code(TLS_Ciphersuite_Algos code);
- TLS_Handshake_Extension_Type type() const
- { return TLSEXT_SIGNATURE_ALGORITHMS; }
-
std::vector<std::pair<TLS_Ciphersuite_Algos, TLS_Ciphersuite_Algos> >
supported_signature_algorthms() const
{
@@ -215,12 +229,24 @@ class Signature_Algorithms : public TLS_Extension
class TLS_Extensions
{
public:
- size_t count() const { return extensions.size(); }
+ template<typename T>
+ T* get() const
+ {
+ TLS_Handshake_Extension_Type type = T::static_type();
- TLS_Extension* at(size_t idx) { return extensions.at(idx); }
+ std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i =
+ extensions.find(type);
- void push_back(TLS_Extension* extn)
- { extensions.push_back(extn); }
+ if(i != extensions.end())
+ return dynamic_cast<T*>(i->second);
+ return 0;
+ }
+
+ void add(TLS_Extension* extn)
+ {
+ delete extensions[extn->type()]; // or hard error if already exists?
+ extensions[extn->type()] = extn;
+ }
MemoryVector<byte> serialize() const;
@@ -233,7 +259,7 @@ class TLS_Extensions
TLS_Extensions(const TLS_Extensions&) {}
TLS_Extensions& operator=(const TLS_Extensions&) { return (*this); }
- std::vector<TLS_Extension*> extensions;
+ std::map<TLS_Handshake_Extension_Type, TLS_Extension*> extensions;
};
}
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index a816e9f6a..48fb70ae1 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -130,14 +130,79 @@ bool TLS_Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) c
}
std::pair<std::string, Signature_Format>
+TLS_Handshake_State::choose_sig_format(const Private_Key* key,
+ TLS_Ciphersuite_Algos& hash_algo,
+ TLS_Ciphersuite_Algos& sig_algo,
+ bool for_client_auth)
+ {
+ const std::string algo_name = key->algo_name();
+
+ hash_algo = TLS_ALGO_NONE;
+ sig_algo = TLS_ALGO_NONE;
+
+ /*
+ FIXME: This should respect the algo preferences in the client hello.
+ Either we are the client, and shouldn't confuse the server by claiming
+ one thing and doing another, or we're the server and the client might
+ be unhappy if we send it something it doesn't understand.
+ */
+
+ if(algo_name == "RSA")
+ {
+ std::string padding = "";
+
+ if(for_client_auth && this->version == SSL_V3)
+ padding = "EMSA3(Raw)";
+ else if(this->version == TLS_V10 || this->version == TLS_V11)
+ padding = "EMSA3(TLS.Digest.0)";
+ else
+ {
+ hash_algo = TLS_ALGO_HASH_SHA256; // should be policy
+ sig_algo = TLS_ALGO_SIGNER_RSA;
+
+ std::string hash = TLS_Cipher_Suite::hash_code_to_name(hash_algo);
+ padding = "EMSA3(" + hash + ")";
+ }
+
+ return std::make_pair(padding, IEEE_1363);
+ }
+ else if(algo_name == "DSA")
+ {
+ std::string padding = "";
+
+ if(for_client_auth && this->version == SSL_V3)
+ padding = "Raw";
+ else if(this->version == TLS_V10 || this->version == TLS_V11)
+ padding = "EMSA1(SHA-1)";
+ else
+ {
+ hash_algo = TLS_ALGO_HASH_SHA1; // should be policy
+ sig_algo = TLS_ALGO_SIGNER_DSA;
+
+ std::string hash = TLS_Cipher_Suite::hash_code_to_name(hash_algo);
+ padding = "EMSA1(" + hash + ")";
+ }
+
+ return std::make_pair(padding, DER_SEQUENCE);
+ }
+
+ throw Invalid_Argument(algo_name + " is invalid/unknown for TLS signatures");
+ }
+
+std::pair<std::string, Signature_Format>
TLS_Handshake_State::choose_sig_format(const Public_Key* key,
TLS_Ciphersuite_Algos hash_algo,
+ TLS_Ciphersuite_Algos sig_algo,
bool for_client_auth)
{
const std::string algo_name = key->algo_name();
if(algo_name == "RSA")
{
+ if(sig_algo != TLS_ALGO_NONE && sig_algo != TLS_ALGO_SIGNER_RSA)
+ throw TLS_Exception(DECODE_ERROR,
+ "Counterparty sent RSA key and non-RSA signature");
+
std::string padding = "";
if(for_client_auth && this->version == SSL_V3)
@@ -154,6 +219,10 @@ TLS_Handshake_State::choose_sig_format(const Public_Key* key,
}
else if(algo_name == "DSA")
{
+ if(sig_algo != TLS_ALGO_NONE && sig_algo != TLS_ALGO_SIGNER_DSA)
+ throw TLS_Exception(DECODE_ERROR,
+ "Counterparty sent RSA key and non-RSA signature");
+
std::string padding = "";
if(for_client_auth && this->version == SSL_V3)
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 1beaf74b3..3480ee85f 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -49,6 +49,13 @@ class TLS_Handshake_State
std::pair<std::string, Signature_Format>
choose_sig_format(const Public_Key* key,
TLS_Ciphersuite_Algos hash_algo,
+ TLS_Ciphersuite_Algos sig_algo,
+ bool for_client_auth);
+
+ std::pair<std::string, Signature_Format>
+ choose_sig_format(const Private_Key* key,
+ TLS_Ciphersuite_Algos& hash_algo,
+ TLS_Ciphersuite_Algos& sig_algo,
bool for_client_auth);
Version_Code version;
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 9ea0b1a2d..95c1ba0a0 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -258,21 +258,22 @@ class Certificate_Req : public Handshake_Message
public:
Handshake_Type type() const { return CERTIFICATE_REQUEST; }
- std::vector<Certificate_Type> acceptable_types() const { return types; }
+ std::vector<byte> acceptable_types() const { return cert_types; }
std::vector<X509_DN> acceptable_CAs() const { return names; }
Certificate_Req(Record_Writer& writer,
TLS_Handshake_Hash& hash,
const std::vector<X509_Certificate>& allowed_cas,
- const std::vector<Certificate_Type>& types =
- std::vector<Certificate_Type>());
+ Version_Code version);
- Certificate_Req(const MemoryRegion<byte>& buf);
+ Certificate_Req(const MemoryRegion<byte>& buf,
+ Version_Code version);
private:
MemoryVector<byte> serialize() const;
std::vector<X509_DN> names;
- std::vector<Certificate_Type> types;
+ std::vector<byte> cert_types;
+ MemoryVector<byte> sig_and_hash_algos; // for TLS 1.2
};
/**
@@ -296,7 +297,8 @@ class Certificate_Verify : public Handshake_Message
RandomNumberGenerator& rng,
const Private_Key* key);
- Certificate_Verify(const MemoryRegion<byte>& buf);
+ Certificate_Verify(const MemoryRegion<byte>& buf,
+ Version_Code version);
private:
MemoryVector<byte> serialize() const;
diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h
index a0bca4e7f..48ff9185e 100644
--- a/src/tls/tls_policy.h
+++ b/src/tls/tls_policy.h
@@ -52,7 +52,7 @@ class BOTAN_DLL TLS_Policy
/*
* @return the version we would prefer to negotiate
*/
- virtual Version_Code pref_version() const { return TLS_V12; }
+ virtual Version_Code pref_version() const { return TLS_V11; }
virtual bool check_cert(const std::vector<X509_Certificate>& cert_chain) const = 0;
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 503d55610..44f8ec2b4 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -278,8 +278,12 @@ void TLS_Server::process_handshake_msg(Handshake_Type type,
{
// FIXME: figure out the allowed CAs/cert types
- state->cert_req = new Certificate_Req(writer, state->hash,
- std::vector<X509_Certificate>());
+ std::vector<X509_Certificate> allowed_cas;
+
+ state->cert_req = new Certificate_Req(writer,
+ state->hash,
+ allowed_cas,
+ state->version);
state->set_expected_next(CERTIFICATE);
}
@@ -325,7 +329,7 @@ void TLS_Server::process_handshake_msg(Handshake_Type type,
}
else if(type == CERTIFICATE_VERIFY)
{
- state->client_verify = new Certificate_Verify(contents);
+ state->client_verify = new Certificate_Verify(contents, state->version);
const std::vector<X509_Certificate>& client_certs =
state->client_certs->cert_chain();