diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/c_hello.cpp | 68 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 54 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 47 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 34 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 20 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 5 | ||||
-rw-r--r-- | src/tls/tls_extensions.cpp | 19 | ||||
-rw-r--r-- | src/tls/tls_extensions.h | 52 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 69 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 7 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 14 | ||||
-rw-r--r-- | src/tls/tls_policy.h | 2 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 10 |
13 files changed, 251 insertions, 150 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index a70713a80..71c0c3de9 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -147,20 +147,20 @@ MemoryVector<byte> Client_Hello::serialize() const // Initial handshake if(m_renegotiation_info.empty()) { - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - extensions.push_back(new Server_Name_Indicator(m_hostname)); - extensions.push_back(new SRP_Identifier(m_srp_identifier)); + extensions.add(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Server_Name_Indicator(m_hostname)); + extensions.add(new SRP_Identifier(m_srp_identifier)); if(m_version >= TLS_V12) - extensions.push_back(new Signature_Algorithms()); + extensions.add(new Signature_Algorithms()); if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification()); + extensions.add(new Next_Protocol_Notification()); } else { // renegotiation - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Renegotation_Extension(m_renegotiation_info)); } buf += extensions.serialize(); @@ -237,35 +237,39 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) TLS_Extensions extensions(reader); - for(size_t i = 0; i != extensions.count(); ++i) + if(Server_Name_Indicator* sni = extensions.get<Server_Name_Indicator>()) { - TLS_Extension* extn = extensions.at(i); + m_hostname = sni->host_name(); + } - if(Server_Name_Indicator* sni = dynamic_cast<Server_Name_Indicator*>(extn)) - { - m_hostname = sni->host_name(); - } - else if(SRP_Identifier* srp = dynamic_cast<SRP_Identifier*>(extn)) - { - m_srp_identifier = srp->identifier(); - } - else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn)) - { - if(!npn->protocols().empty()) - throw Decoding_Error("Client sent non-empty NPN extension"); + if(SRP_Identifier* srp = extensions.get<SRP_Identifier>()) + { + m_srp_identifier = srp->identifier(); + } - m_next_protocol = true; - } - else if(Maximum_Fragment_Length* frag = dynamic_cast<Maximum_Fragment_Length*>(extn)) - { - m_fragment_size = frag->fragment_size(); - } - else if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn)) - { - // checked by TLS_Client / TLS_Server as they know the handshake state - m_secure_renegotiation = true; - m_renegotiation_info = reneg->renegotiation_info(); - } + if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>()) + { + if(!npn->protocols().empty()) + throw Decoding_Error("Client sent non-empty NPN extension"); + + m_next_protocol = true; + } + + if(Maximum_Fragment_Length* frag = extensions.get<Maximum_Fragment_Length>()) + { + m_fragment_size = frag->fragment_size(); + } + + if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>()) + { + // checked by TLS_Client / TLS_Server as they know the handshake state + m_secure_renegotiation = true; + m_renegotiation_info = reneg->renegotiation_info(); + } + + if(Signature_Algorithms* sigs = extensions.get<Signature_Algorithms>()) + { + // save in handshake state } if(value_exists(m_suites, static_cast<u16bit>(TLS_EMPTY_RENEGOTIATION_INFO_SCSV))) diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index bdb25057c..c3e46a5ae 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -7,11 +7,14 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/loadstor.h> #include <botan/secqueue.h> +#include <stdio.h> + namespace Botan { /** @@ -20,18 +23,16 @@ namespace Botan { Certificate_Req::Certificate_Req(Record_Writer& writer, TLS_Handshake_Hash& hash, const std::vector<X509_Certificate>& ca_certs, - const std::vector<Certificate_Type>& cert_types) + Version_Code version) { for(size_t i = 0; i != ca_certs.size(); ++i) names.push_back(ca_certs[i].subject_dn()); - if(cert_types.empty()) // default is RSA/DSA is OK - { - types.push_back(RSA_CERT); - types.push_back(DSS_CERT); - } - else - types = cert_types; + cert_types.push_back(RSA_CERT); + cert_types.push_back(DSS_CERT); + + if(version >= TLS_V12) + sig_and_hash_algos = Signature_Algorithms().serialize(); send(writer, hash); } @@ -39,39 +40,36 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, /** * Deserialize a Certificate Request message */ -Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf) +Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf, + Version_Code version) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); - const size_t types_size = buf[0]; + TLS_Data_Reader reader(buf); - if(buf.size() < types_size + 3) - throw Decoding_Error("Certificate_Req: Bad certificate request"); + cert_types = reader.get_range_vector<byte>(1, 1, 255); - for(size_t i = 0; i != types_size; ++i) - types.push_back(static_cast<Certificate_Type>(buf[i+1])); + if(version >= TLS_V12) + { + std::vector<u16bit> sig_hash_algs = reader.get_range_vector<u16bit>(2, 2, 65534); - const size_t names_size = make_u16bit(buf[types_size+1], buf[types_size+2]); + // FIXME, do something with this + } - if(buf.size() != names_size + types_size + 3) - throw Decoding_Error("Certificate_Req: Bad certificate request"); + u16bit purported_size = reader.get_u16bit(); - size_t offset = types_size + 3; + if(reader.remaining_bytes() != purported_size) + throw Decoding_Error("Inconsistent length in certificate request"); - while(offset < buf.size()) + while(reader.has_remaining()) { - const size_t name_size = make_u16bit(buf[offset], buf[offset+1]); - - if(offset + 2 + name_size > buf.size()) - throw Decoding_Error("Certificate_Req: Bad certificate request"); + std::vector<byte> name_bits = reader.get_range_vector<byte>(2, 0, 65535); - BER_Decoder decoder(&buf[offset + 2], name_size); + BER_Decoder decoder(&name_bits[0], name_bits.size()); X509_DN name; decoder.decode(name); names.push_back(name); - - offset += (2 + name_size); } } @@ -82,7 +80,9 @@ MemoryVector<byte> Certificate_Req::serialize() const { MemoryVector<byte> buf; - append_tls_length_value(buf, types, 1); + append_tls_length_value(buf, cert_types, 1); + + buf += sig_and_hash_algos; for(size_t i = 0; i != names.size(); ++i) { diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index f35202734..f7386dd13 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -7,6 +7,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> #include <botan/internal/assert.h> #include <botan/tls_exceptn.h> #include <botan/pubkey.h> @@ -27,14 +28,8 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, { BOTAN_ASSERT_NONNULL(priv_key); - // FIXME: this should respect server's hash preferences - if(state->version >= TLS_V12) - hash_algo = TLS_ALGO_HASH_SHA256; - else - hash_algo = TLS_ALGO_NONE; - std::pair<std::string, Signature_Format> format = - state->choose_sig_format(priv_key, hash_algo, true); + state->choose_sig_format(priv_key, hash_algo, sig_algo, true); PK_Signer signer(*priv_key, format.first, format.second); @@ -48,13 +43,10 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, else signature = signer.sign_message(md5_sha, rng); } - else if(state->version == TLS_V10 || state->version == TLS_V11) + else { signature = signer.sign_message(state->hash.get_contents(), rng); } - else - throw TLS_Exception(PROTOCOL_VERSION, - "Unknown TLS version in certificate verification"); send(writer, state->hash); } @@ -62,9 +54,23 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, /* * Deserialize a Certificate Verify message */ -Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf) +Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf, + Version_Code version) { TLS_Data_Reader reader(buf); + + if(version < TLS_V12) + { + // use old defaults + hash_algo = TLS_ALGO_NONE; + sig_algo = TLS_ALGO_NONE; + } + else + { + hash_algo = Signature_Algorithms::hash_algo_code(reader.get_byte()); + sig_algo = Signature_Algorithms::sig_algo_code(reader.get_byte()); + } + signature = reader.get_range<byte>(2, 0, 65535); } @@ -75,6 +81,12 @@ MemoryVector<byte> Certificate_Verify::serialize() const { MemoryVector<byte> buf; + if(hash_algo != TLS_ALGO_NONE) + { + buf.push_back(Signature_Algorithms::hash_algo_code(hash_algo)); + buf.push_back(Signature_Algorithms::sig_algo_code(sig_algo)); + } + const u16bit sig_len = signature.size(); buf.push_back(get_byte(0, sig_len)); buf.push_back(get_byte(1, sig_len)); @@ -92,7 +104,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert, std::auto_ptr<Public_Key> key(cert.subject_public_key()); std::pair<std::string, Signature_Format> format = - state->choose_sig_format(key.get(), hash_algo, true); + state->choose_sig_format(key.get(), hash_algo, sig_algo, true); PK_Verifier verifier(*key, format.first, format.second); @@ -104,13 +116,8 @@ bool Certificate_Verify::verify(const X509_Certificate& cert, return verifier.verify_message(&md5_sha[16], md5_sha.size()-16, &signature[0], signature.size()); } - else if(state->version == TLS_V10 || state->version == TLS_V11) - { - return verifier.verify_message(state->hash.get_contents(), signature); - } - else - throw TLS_Exception(PROTOCOL_VERSION, - "Unknown TLS version in certificate verification"); + + return verifier.verify_message(state->hash.get_contents(), signature); } } diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 652544806..e6aff94e3 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -123,25 +123,17 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) TLS_Extensions extensions(reader); - for(size_t i = 0; i != extensions.count(); ++i) + if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>()) { - TLS_Extension* extn = extensions.at(i); - - if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn)) - { - // checked by TLS_Client / TLS_Server as they know the handshake state - m_secure_renegotiation = true; - m_renegotiation_info = reneg->renegotiation_info(); - } - else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn)) - { - m_next_protocols = npn->protocols(); - m_next_protocol = true; - } - else if(Signature_Algorithms* sigs = dynamic_cast<Signature_Algorithms*>(extn)) - { - // save in handshake state - } + // checked by TLS_Client / TLS_Server as they know the handshake state + m_secure_renegotiation = true; + m_renegotiation_info = reneg->renegotiation_info(); + } + + if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>()) + { + m_next_protocols = npn->protocols(); + m_next_protocol = true; } } @@ -166,13 +158,13 @@ MemoryVector<byte> Server_Hello::serialize() const TLS_Extensions extensions; if(m_secure_renegotiation) - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Renegotation_Extension(m_renegotiation_info)); if(m_fragment_size != 0) - extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); + extensions.add(new Maximum_Fragment_Length(m_fragment_size)); if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); + extensions.add(new Next_Protocol_Notification(m_next_protocols)); buf += extensions.serialize(); diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index ac6ee15ee..6b87e6ac6 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -1,6 +1,6 @@ /* * Server Key Exchange Message -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2010,2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -35,20 +35,8 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, throw Invalid_Argument("Unknown key type " + state->kex_priv->algo_name() + " for TLS key exchange"); - // FIXME: this should respect client's hash preferences - if(state->version >= TLS_V12) - { - hash_algo = TLS_ALGO_HASH_SHA256; - sig_algo = TLS_ALGO_SIGNER_RSA; - } - else - { - hash_algo = TLS_ALGO_NONE; - sig_algo = TLS_ALGO_NONE; - } - std::pair<std::string, Signature_Format> format = - state->choose_sig_format(private_key, hash_algo, false); + state->choose_sig_format(private_key, hash_algo, sig_algo, false); PK_Signer signer(*private_key, format.first, format.second); @@ -153,10 +141,8 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert, { std::auto_ptr<Public_Key> key(cert.subject_public_key()); - printf("Checking %s vs code %d\n", key->algo_name().c_str(), sig_algo); - std::pair<std::string, Signature_Format> format = - state->choose_sig_format(key.get(), hash_algo, false); + state->choose_sig_format(key.get(), hash_algo, sig_algo, false); PK_Verifier verifier(*key, format.first, format.second); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index c8fcd8144..ed7de501f 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -306,7 +306,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, else if(type == CERTIFICATE_REQUEST) { state->set_expected_next(SERVER_HELLO_DONE); - state->cert_req = new Certificate_Req(contents); + state->cert_req = new Certificate_Req(contents, state->version); } else if(type == SERVER_HELLO_DONE) { @@ -316,8 +316,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(state->received_handshake_msg(CERTIFICATE_REQUEST)) { - std::vector<Certificate_Type> types = - state->cert_req->acceptable_types(); + std::vector<byte> types = state->cert_req->acceptable_types(); std::vector<X509_Certificate> client_certs = creds.cert_chain("", // use types here diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp index 9f80744f9..21c3b67fc 100644 --- a/src/tls/tls_extensions.cpp +++ b/src/tls/tls_extensions.cpp @@ -54,7 +54,7 @@ TLS_Extensions::TLS_Extensions(TLS_Data_Reader& reader) extension_size); if(extn) - extensions.push_back(extn); + this->add(extn); else // unknown/unhandled extension reader.discard_next(extension_size); } @@ -65,14 +65,15 @@ MemoryVector<byte> TLS_Extensions::serialize() const { MemoryVector<byte> buf(2); // 2 bytes for length field - for(size_t i = 0; i != extensions.size(); ++i) + for(std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i = extensions.begin(); + i != extensions.end(); ++i) { - if(extensions[i]->empty()) + if(i->second->empty()) continue; - const u16bit extn_code = extensions[i]->type(); + const u16bit extn_code = i->second->type(); - MemoryVector<byte> extn_val = extensions[i]->serialize(); + MemoryVector<byte> extn_val = i->second->serialize(); buf.push_back(get_byte(0, extn_code)); buf.push_back(get_byte(1, extn_code)); @@ -97,8 +98,12 @@ MemoryVector<byte> TLS_Extensions::serialize() const TLS_Extensions::~TLS_Extensions() { - for(size_t i = 0; i != extensions.size(); ++i) - delete extensions[i]; + for(std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i = extensions.begin(); + i != extensions.end(); ++i) + { + delete i->second; + } + extensions.clear(); } diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h index 2f4f711c2..a90cb4f2b 100644 --- a/src/tls/tls_extensions.h +++ b/src/tls/tls_extensions.h @@ -12,6 +12,7 @@ #include <botan/tls_magic.h> #include <vector> #include <string> +#include <map> namespace Botan { @@ -24,6 +25,7 @@ class TLS_Extension { public: virtual TLS_Handshake_Extension_Type type() const = 0; + virtual MemoryVector<byte> serialize() const = 0; virtual bool empty() const = 0; @@ -37,9 +39,11 @@ class TLS_Extension class Server_Name_Indicator : public TLS_Extension { public: - TLS_Handshake_Extension_Type type() const + static TLS_Handshake_Extension_Type static_type() { return TLSEXT_SERVER_NAME_INDICATION; } + TLS_Handshake_Extension_Type type() const { return static_type(); } + Server_Name_Indicator(const std::string& host_name) : sni_host_name(host_name) {} @@ -61,9 +65,11 @@ class Server_Name_Indicator : public TLS_Extension class SRP_Identifier : public TLS_Extension { public: - TLS_Handshake_Extension_Type type() const + static TLS_Handshake_Extension_Type static_type() { return TLSEXT_SRP_IDENTIFIER; } + TLS_Handshake_Extension_Type type() const { return static_type(); } + SRP_Identifier(const std::string& identifier) : srp_identifier(identifier) {} @@ -85,9 +91,11 @@ class SRP_Identifier : public TLS_Extension class Renegotation_Extension : public TLS_Extension { public: - TLS_Handshake_Extension_Type type() const + static TLS_Handshake_Extension_Type static_type() { return TLSEXT_SAFE_RENEGOTIATION; } + TLS_Handshake_Extension_Type type() const { return static_type(); } + Renegotation_Extension() {} Renegotation_Extension(const MemoryRegion<byte>& bits) : @@ -112,9 +120,11 @@ class Renegotation_Extension : public TLS_Extension class Maximum_Fragment_Length : public TLS_Extension { public: - TLS_Handshake_Extension_Type type() const + static TLS_Handshake_Extension_Type static_type() { return TLSEXT_MAX_FRAGMENT_LENGTH; } + TLS_Handshake_Extension_Type type() const { return static_type(); } + bool empty() const { return val != 0; } size_t fragment_size() const; @@ -149,9 +159,11 @@ class Maximum_Fragment_Length : public TLS_Extension class Next_Protocol_Notification : public TLS_Extension { public: - TLS_Handshake_Extension_Type type() const + static TLS_Handshake_Extension_Type static_type() { return TLSEXT_NEXT_PROTOCOL; } + TLS_Handshake_Extension_Type type() const { return static_type(); } + const std::vector<std::string>& protocols() const { return m_protocols; } @@ -182,15 +194,17 @@ class Next_Protocol_Notification : public TLS_Extension class Signature_Algorithms : public TLS_Extension { public: + static TLS_Handshake_Extension_Type static_type() + { return TLSEXT_SIGNATURE_ALGORITHMS; } + + TLS_Handshake_Extension_Type type() const { return static_type(); } + static TLS_Ciphersuite_Algos hash_algo_code(byte code); static byte hash_algo_code(TLS_Ciphersuite_Algos code); static TLS_Ciphersuite_Algos sig_algo_code(byte code); static byte sig_algo_code(TLS_Ciphersuite_Algos code); - TLS_Handshake_Extension_Type type() const - { return TLSEXT_SIGNATURE_ALGORITHMS; } - std::vector<std::pair<TLS_Ciphersuite_Algos, TLS_Ciphersuite_Algos> > supported_signature_algorthms() const { @@ -215,12 +229,24 @@ class Signature_Algorithms : public TLS_Extension class TLS_Extensions { public: - size_t count() const { return extensions.size(); } + template<typename T> + T* get() const + { + TLS_Handshake_Extension_Type type = T::static_type(); - TLS_Extension* at(size_t idx) { return extensions.at(idx); } + std::map<TLS_Handshake_Extension_Type, TLS_Extension*>::const_iterator i = + extensions.find(type); - void push_back(TLS_Extension* extn) - { extensions.push_back(extn); } + if(i != extensions.end()) + return dynamic_cast<T*>(i->second); + return 0; + } + + void add(TLS_Extension* extn) + { + delete extensions[extn->type()]; // or hard error if already exists? + extensions[extn->type()] = extn; + } MemoryVector<byte> serialize() const; @@ -233,7 +259,7 @@ class TLS_Extensions TLS_Extensions(const TLS_Extensions&) {} TLS_Extensions& operator=(const TLS_Extensions&) { return (*this); } - std::vector<TLS_Extension*> extensions; + std::map<TLS_Handshake_Extension_Type, TLS_Extension*> extensions; }; } diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index a816e9f6a..48fb70ae1 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -130,14 +130,79 @@ bool TLS_Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) c } std::pair<std::string, Signature_Format> +TLS_Handshake_State::choose_sig_format(const Private_Key* key, + TLS_Ciphersuite_Algos& hash_algo, + TLS_Ciphersuite_Algos& sig_algo, + bool for_client_auth) + { + const std::string algo_name = key->algo_name(); + + hash_algo = TLS_ALGO_NONE; + sig_algo = TLS_ALGO_NONE; + + /* + FIXME: This should respect the algo preferences in the client hello. + Either we are the client, and shouldn't confuse the server by claiming + one thing and doing another, or we're the server and the client might + be unhappy if we send it something it doesn't understand. + */ + + if(algo_name == "RSA") + { + std::string padding = ""; + + if(for_client_auth && this->version == SSL_V3) + padding = "EMSA3(Raw)"; + else if(this->version == TLS_V10 || this->version == TLS_V11) + padding = "EMSA3(TLS.Digest.0)"; + else + { + hash_algo = TLS_ALGO_HASH_SHA256; // should be policy + sig_algo = TLS_ALGO_SIGNER_RSA; + + std::string hash = TLS_Cipher_Suite::hash_code_to_name(hash_algo); + padding = "EMSA3(" + hash + ")"; + } + + return std::make_pair(padding, IEEE_1363); + } + else if(algo_name == "DSA") + { + std::string padding = ""; + + if(for_client_auth && this->version == SSL_V3) + padding = "Raw"; + else if(this->version == TLS_V10 || this->version == TLS_V11) + padding = "EMSA1(SHA-1)"; + else + { + hash_algo = TLS_ALGO_HASH_SHA1; // should be policy + sig_algo = TLS_ALGO_SIGNER_DSA; + + std::string hash = TLS_Cipher_Suite::hash_code_to_name(hash_algo); + padding = "EMSA1(" + hash + ")"; + } + + return std::make_pair(padding, DER_SEQUENCE); + } + + throw Invalid_Argument(algo_name + " is invalid/unknown for TLS signatures"); + } + +std::pair<std::string, Signature_Format> TLS_Handshake_State::choose_sig_format(const Public_Key* key, TLS_Ciphersuite_Algos hash_algo, + TLS_Ciphersuite_Algos sig_algo, bool for_client_auth) { const std::string algo_name = key->algo_name(); if(algo_name == "RSA") { + if(sig_algo != TLS_ALGO_NONE && sig_algo != TLS_ALGO_SIGNER_RSA) + throw TLS_Exception(DECODE_ERROR, + "Counterparty sent RSA key and non-RSA signature"); + std::string padding = ""; if(for_client_auth && this->version == SSL_V3) @@ -154,6 +219,10 @@ TLS_Handshake_State::choose_sig_format(const Public_Key* key, } else if(algo_name == "DSA") { + if(sig_algo != TLS_ALGO_NONE && sig_algo != TLS_ALGO_SIGNER_DSA) + throw TLS_Exception(DECODE_ERROR, + "Counterparty sent RSA key and non-RSA signature"); + std::string padding = ""; if(for_client_auth && this->version == SSL_V3) diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 1beaf74b3..3480ee85f 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -49,6 +49,13 @@ class TLS_Handshake_State std::pair<std::string, Signature_Format> choose_sig_format(const Public_Key* key, TLS_Ciphersuite_Algos hash_algo, + TLS_Ciphersuite_Algos sig_algo, + bool for_client_auth); + + std::pair<std::string, Signature_Format> + choose_sig_format(const Private_Key* key, + TLS_Ciphersuite_Algos& hash_algo, + TLS_Ciphersuite_Algos& sig_algo, bool for_client_auth); Version_Code version; diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 9ea0b1a2d..95c1ba0a0 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -258,21 +258,22 @@ class Certificate_Req : public Handshake_Message public: Handshake_Type type() const { return CERTIFICATE_REQUEST; } - std::vector<Certificate_Type> acceptable_types() const { return types; } + std::vector<byte> acceptable_types() const { return cert_types; } std::vector<X509_DN> acceptable_CAs() const { return names; } Certificate_Req(Record_Writer& writer, TLS_Handshake_Hash& hash, const std::vector<X509_Certificate>& allowed_cas, - const std::vector<Certificate_Type>& types = - std::vector<Certificate_Type>()); + Version_Code version); - Certificate_Req(const MemoryRegion<byte>& buf); + Certificate_Req(const MemoryRegion<byte>& buf, + Version_Code version); private: MemoryVector<byte> serialize() const; std::vector<X509_DN> names; - std::vector<Certificate_Type> types; + std::vector<byte> cert_types; + MemoryVector<byte> sig_and_hash_algos; // for TLS 1.2 }; /** @@ -296,7 +297,8 @@ class Certificate_Verify : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion<byte>& buf); + Certificate_Verify(const MemoryRegion<byte>& buf, + Version_Code version); private: MemoryVector<byte> serialize() const; diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h index a0bca4e7f..48ff9185e 100644 --- a/src/tls/tls_policy.h +++ b/src/tls/tls_policy.h @@ -52,7 +52,7 @@ class BOTAN_DLL TLS_Policy /* * @return the version we would prefer to negotiate */ - virtual Version_Code pref_version() const { return TLS_V12; } + virtual Version_Code pref_version() const { return TLS_V11; } virtual bool check_cert(const std::vector<X509_Certificate>& cert_chain) const = 0; diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 503d55610..44f8ec2b4 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -278,8 +278,12 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, { // FIXME: figure out the allowed CAs/cert types - state->cert_req = new Certificate_Req(writer, state->hash, - std::vector<X509_Certificate>()); + std::vector<X509_Certificate> allowed_cas; + + state->cert_req = new Certificate_Req(writer, + state->hash, + allowed_cas, + state->version); state->set_expected_next(CERTIFICATE); } @@ -325,7 +329,7 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, } else if(type == CERTIFICATE_VERIFY) { - state->client_verify = new Certificate_Verify(contents); + state->client_verify = new Certificate_Verify(contents, state->version); const std::vector<X509_Certificate>& client_certs = state->client_certs->cert_chain(); |