diff options
author | lloyd <[email protected]> | 2012-03-16 17:32:40 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-03-16 17:32:40 +0000 |
commit | 1fce3fe2274cf3368e1f29827ed0f41cebba3726 (patch) | |
tree | 7a83f1d6d8fcf08dc8120aafba718dd11c1e52b7 /src/tls | |
parent | 4c6327c95bd01de54487b3159b77a5152ed39564 (diff) | |
parent | 7371f7c59ae722769fbc0dc810583a0cd0e38877 (diff) |
propagate from branch 'net.randombit.botan.tls-state-machine' (head c24b5d6b012131b177d38bddb8b06d73f81f70c4)
to branch 'net.randombit.botan.tls-session-ticket' (head 9977d4c118e1ac26425cef676ebf26cd5b2a470e)
Diffstat (limited to 'src/tls')
47 files changed, 3747 insertions, 2018 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 2455eae3b..e0fce03b5 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -15,6 +15,8 @@ namespace Botan { +namespace TLS { + MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng) { MemoryVector<byte> buf(32); @@ -25,34 +27,20 @@ MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng) } /* -* Encode and send a Handshake message +* Create a new Hello Request message */ -void Handshake_Message::send(Record_Writer& writer, TLS_Handshake_Hash& hash) const +Hello_Request::Hello_Request(Record_Writer& writer) { - MemoryVector<byte> buf = serialize(); - MemoryVector<byte> send_buf(4); - - const size_t buf_size = buf.size(); - - send_buf[0] = type(); - - for(size_t i = 1; i != 4; ++i) - send_buf[i] = get_byte<u32bit>(i, buf_size); - - send_buf += buf; - - hash.update(send_buf); - - writer.send(HANDSHAKE, &send_buf[0], send_buf.size()); + writer.send(*this); } /* -* Create a new Hello Request message +* Deserialize a Hello Request message */ -Hello_Request::Hello_Request(Record_Writer& writer) +Hello_Request::Hello_Request(const MemoryRegion<byte>& buf) { - TLS_Handshake_Hash dummy; // FIXME: *UGLY* - send(writer, dummy); + if(buf.size()) + throw Decoding_Error("Bad Hello_Request, has non-zero size"); } /* @@ -64,20 +52,11 @@ MemoryVector<byte> Hello_Request::serialize() const } /* -* Deserialize a Hello Request message -*/ -void Hello_Request::deserialize(const MemoryRegion<byte>& buf) - { - if(buf.size()) - throw Decoding_Error("Hello_Request: Must be empty, and is not"); - } - -/* * Create a new Client Hello message */ Client_Hello::Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const TLS_Policy& policy, + Handshake_Hash& hash, + const Policy& policy, RandomNumberGenerator& rng, const MemoryRegion<byte>& reneg_info, bool next_protocol, @@ -85,7 +64,7 @@ Client_Hello::Client_Hello(Record_Writer& writer, const std::string& srp_identifier) : m_version(policy.pref_version()), m_random(make_hello_random(rng)), - m_suites(policy.ciphersuites(srp_identifier != "")), + m_suites(policy.ciphersuite_list((srp_identifier != ""))), m_comp_methods(policy.compression()), m_hostname(hostname), m_srp_identifier(srp_identifier), @@ -94,16 +73,25 @@ Client_Hello::Client_Hello(Record_Writer& writer, m_secure_renegotiation(true), m_renegotiation_info(reneg_info) { - send(writer, hash); + std::vector<std::string> hashes = policy.allowed_hashes(); + std::vector<std::string> sigs = policy.allowed_signature_methods(); + + m_supported_curves = policy.allowed_ecc_curves(); + + for(size_t i = 0; i != hashes.size(); ++i) + for(size_t j = 0; j != sigs.size(); ++j) + m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); + + hash.update(writer.send(*this)); } /* * Create a new Client Hello message */ Client_Hello::Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, RandomNumberGenerator& rng, - const TLS_Session& session, + const Session& session, bool next_protocol) : m_version(session.version()), m_session_id(session.session_id()), @@ -114,10 +102,24 @@ Client_Hello::Client_Hello(Record_Writer& writer, m_fragment_size(session.fragment_size()), m_secure_renegotiation(session.secure_renegotiation()) { - m_suites.push_back(session.ciphersuite()); + m_suites.push_back(session.ciphersuite_code()); m_comp_methods.push_back(session.compression_method()); - send(writer, hash); + // set m_supported_algos + m_supported_curves here? + + hash.update(writer.send(*this)); + } + +Client_Hello::Client_Hello(const MemoryRegion<byte>& buf, Handshake_Type type) + { + m_next_protocol = false; + m_secure_renegotiation = false; + m_fragment_size = 0; + + if(type == CLIENT_HELLO) + deserialize(buf); + else + deserialize_sslv2(buf); } /* @@ -127,8 +129,8 @@ MemoryVector<byte> Client_Hello::serialize() const { MemoryVector<byte> buf; - buf.push_back(static_cast<byte>(m_version >> 8)); - buf.push_back(static_cast<byte>(m_version )); + buf.push_back(m_version.major_version()); + buf.push_back(m_version.minor_version()); buf += m_random; append_tls_length_value(buf, m_session_id, 1); @@ -142,22 +144,26 @@ MemoryVector<byte> Client_Hello::serialize() const * send that extension. */ - TLS_Extensions extensions; + Extensions extensions; // Initial handshake if(m_renegotiation_info.empty()) { - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - extensions.push_back(new Server_Name_Indicator(m_hostname)); - extensions.push_back(new SRP_Identifier(m_srp_identifier)); + extensions.add(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Server_Name_Indicator(m_hostname)); + extensions.add(new SRP_Identifier(m_srp_identifier)); + extensions.add(new Supported_Elliptic_Curves(m_supported_curves)); + + if(m_version >= Protocol_Version::TLS_V12) + extensions.add(new Signature_Algorithms(m_supported_algos)); if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification()); + extensions.add(new Next_Protocol_Notification()); } else { // renegotiation - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Renegotation_Extension(m_renegotiation_info)); } buf += extensions.serialize(); @@ -194,7 +200,7 @@ void Client_Hello::deserialize_sslv2(const MemoryRegion<byte>& buf) m_suites.push_back(make_u16bit(buf[i+1], buf[i+2])); } - m_version = static_cast<Version_Code>(make_u16bit(buf[1], buf[2])); + m_version = Protocol_Version(buf[1], buf[2]); m_random.resize(challenge_len); copy_mem(&m_random[0], &buf[9+cipher_spec_len+m_session_id_len], challenge_len); @@ -220,7 +226,11 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) TLS_Data_Reader reader(buf); - m_version = static_cast<Version_Code>(reader.get_u16bit()); + const byte major_version = reader.get_byte(); + const byte minor_version = reader.get_byte(); + + m_version = Protocol_Version(major_version, minor_version); + m_random = reader.get_fixed<byte>(32); m_session_id = reader.get_range<byte>(1, 0, 32); @@ -229,30 +239,70 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) m_comp_methods = reader.get_range_vector<byte>(1, 1, 255); - m_next_protocol = false; - m_secure_renegotiation = false; - m_fragment_size = 0; + Extensions extensions(reader); - TLS_Extensions extensions(reader); + if(Server_Name_Indicator* sni = extensions.get<Server_Name_Indicator>()) + { + m_hostname = sni->host_name(); + } - for(size_t i = 0; i != extensions.count(); ++i) + if(SRP_Identifier* srp = extensions.get<SRP_Identifier>()) { - TLS_Extension* extn = extensions.at(i); + m_srp_identifier = srp->identifier(); + } - if(Server_Name_Indicator* sni = dynamic_cast<Server_Name_Indicator*>(extn)) - { - m_hostname = sni->host_name(); - } - else if(SRP_Identifier* srp = dynamic_cast<SRP_Identifier*>(extn)) + if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>()) + { + if(!npn->protocols().empty()) + throw Decoding_Error("Client sent non-empty NPN extension"); + + m_next_protocol = true; + } + + if(Maximum_Fragment_Length* frag = extensions.get<Maximum_Fragment_Length>()) + { + m_fragment_size = frag->fragment_size(); + } + + if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>()) + { + // checked by Client / Server as they know the handshake state + m_secure_renegotiation = true; + m_renegotiation_info = reneg->renegotiation_info(); + } + + if(Supported_Elliptic_Curves* ecc = extensions.get<Supported_Elliptic_Curves>()) + m_supported_curves = ecc->curves(); + + if(Signature_Algorithms* sigs = extensions.get<Signature_Algorithms>()) + { + m_supported_algos = sigs->supported_signature_algorthms(); + } + else + { + if(m_version >= Protocol_Version::TLS_V12) { - m_srp_identifier = srp->identifier(); + /* + The rule for when a TLS 1.2 client not sending the extension + is strange; in theory, the server is supposed to act as if + the client had sent only SHA-1 using whatever signature + algorithm we end up negotiating. Right here, we don't know + what we'll end up negotiating (depends on policy), but we do + know that we'll only negotiate something the client sent, so + we can safely say it supports everything here and know that + we'll filter it out later. + */ + m_supported_algos.push_back(std::make_pair("SHA-1", "RSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "DSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "ECDSA")); } - else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn)) + else { - if(!npn->protocols().empty()) - throw Decoding_Error("Client sent non-empty NPN extension"); + // For versions before TLS 1.2, insert fake values for the old defaults - m_next_protocol = true; + m_supported_algos.push_back(std::make_pair("TLS.Digest.0", "RSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "DSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "ECDSA")); } else if(Maximum_Fragment_Length* frag = dynamic_cast<Maximum_Fragment_Length*>(extn)) { @@ -282,7 +332,7 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) { if(!m_renegotiation_info.empty()) { - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client send SCSV and non-empty extension"); } } @@ -304,3 +354,5 @@ bool Client_Hello::offered_suite(u16bit ciphersuite) const } } + +} diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index 22c0253c1..ed571852c 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -7,8 +7,13 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> +#include <botan/internal/assert.h> +#include <botan/credentials_manager.h> #include <botan/pubkey.h> #include <botan/dh.h> +#include <botan/ecdh.h> #include <botan/rsa.h> #include <botan/rng.h> #include <botan/loadstor.h> @@ -16,142 +21,338 @@ namespace Botan { +namespace TLS { + +namespace { + +SecureVector<byte> strip_leading_zeros(const MemoryRegion<byte>& input) + { + size_t leading_zeros = 0; + + for(size_t i = 0; i != input.size(); ++i) + { + if(input[i] != 0) + break; + ++leading_zeros; + } + + SecureVector<byte> output(&input[leading_zeros], + input.size() - leading_zeros); + return output; + } + +} + /* * Create a new Client Key Exchange message */ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, - TLS_Handshake_Hash& hash, - RandomNumberGenerator& rng, - const Public_Key* pub_key, - Version_Code using_version, - Version_Code pref_version) + Handshake_State* state, + Credentials_Manager& creds, + const std::vector<X509_Certificate>& peer_certs, + RandomNumberGenerator& rng) { - include_length = true; + const std::string kex_algo = state->suite.kex_algo(); - if(const DH_PublicKey* dh_pub = dynamic_cast<const DH_PublicKey*>(pub_key)) + if(kex_algo == "PSK") { - DH_PrivateKey priv_key(rng, dh_pub->get_domain()); + std::string identity_hint = ""; + + if(state->server_kex) + { + TLS_Data_Reader reader(state->server_kex->params()); + identity_hint = reader.get_string(2, 0, 65535); + } + + const std::string hostname = state->client_hello->sni_hostname(); - PK_Key_Agreement ka(priv_key, "Raw"); + const std::string psk_identity = creds.psk_identity("tls-client", + hostname, + identity_hint); - pre_master = ka.derive_key(0, dh_pub->public_value()).bits_of(); + append_tls_length_value(key_material, psk_identity, 2); - key_material = priv_key.public_value(); + SymmetricKey psk = creds.psk("tls-client", hostname, psk_identity); + + MemoryVector<byte> zeros(psk.length()); + + append_tls_length_value(pre_master, zeros, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); } - else if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key)) + else if(state->server_kex) { - pre_master = rng.random_vec(48); - pre_master[0] = (pref_version >> 8) & 0xFF; - pre_master[1] = (pref_version ) & 0xFF; + TLS_Data_Reader reader(state->server_kex->params()); - PK_Encryptor_EME encryptor(*rsa_pub, "PKCS1v15"); + SymmetricKey psk; - key_material = encryptor.encrypt(pre_master, rng); + if(kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") + { + std::string identity_hint = reader.get_string(2, 0, 65535); - if(using_version == SSL_V3) - include_length = false; - } - else - throw Invalid_Argument("Client_Key_Exchange: Key not RSA or DH"); + const std::string hostname = state->client_hello->sni_hostname(); - send(writer, hash); - } + const std::string psk_identity = creds.psk_identity("tls-client", + hostname, + identity_hint); -/* -* Read a Client Key Exchange message -*/ -Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, - const TLS_Cipher_Suite& suite, - Version_Code using_version) - { - include_length = true; + append_tls_length_value(key_material, psk_identity, 2); - if(using_version == SSL_V3 && (suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX)) - include_length = false; + psk = creds.psk("tls-client", hostname, psk_identity); + } - deserialize(contents); - } + if(kex_algo == "DH" || kex_algo == "DHE_PSK") + { + BigInt p = BigInt::decode(reader.get_range<byte>(2, 1, 65535)); + BigInt g = BigInt::decode(reader.get_range<byte>(2, 1, 65535)); + BigInt Y = BigInt::decode(reader.get_range<byte>(2, 1, 65535)); -/* -* Serialize a Client Key Exchange message -*/ -MemoryVector<byte> Client_Key_Exchange::serialize() const - { - if(include_length) - { - MemoryVector<byte> buf; - append_tls_length_value(buf, key_material, 2); - return buf; + if(reader.remaining_bytes()) + throw Decoding_Error("Bad params size for DH key exchange"); + + DL_Group group(p, g); + + if(!group.verify_group(rng, true)) + throw Internal_Error("DH group failed validation, possible attack"); + + DH_PublicKey counterparty_key(group, Y); + + // FIXME Check that public key is residue? + + DH_PrivateKey priv_key(rng, group); + + PK_Key_Agreement ka(priv_key, "Raw"); + + SecureVector<byte> dh_secret = strip_leading_zeros( + ka.derive_key(0, counterparty_key.public_value()).bits_of()); + + if(kex_algo == "DH") + pre_master = dh_secret; + else + { + append_tls_length_value(pre_master, dh_secret, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); + } + + append_tls_length_value(key_material, priv_key.public_value(), 2); + } + else if(kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") + { + const byte curve_type = reader.get_byte(); + + if(curve_type != 3) + throw Decoding_Error("Server sent non-named ECC curve"); + + const u16bit curve_id = reader.get_u16bit(); + + const std::string name = Supported_Elliptic_Curves::curve_id_to_name(curve_id); + + if(name == "") + throw Decoding_Error("Server sent unknown named curve " + to_string(curve_id)); + + EC_Group group(name); + + MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255); + + ECDH_PublicKey counterparty_key(group, OS2ECP(ecdh_key, group.get_curve())); + + ECDH_PrivateKey priv_key(rng, group); + + PK_Key_Agreement ka(priv_key, "Raw"); + + SecureVector<byte> ecdh_secret = ka.derive_key(0, counterparty_key.public_value()).bits_of(); + + if(kex_algo == "ECDH") + pre_master = ecdh_secret; + else + { + append_tls_length_value(pre_master, ecdh_secret, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); + } + + append_tls_length_value(key_material, priv_key.public_value(), 1); + } + else + { + throw Internal_Error("Client_Key_Exchange: Unknown kex " + + kex_algo); + } } else - return key_material; - } - -/* -* Deserialize a Client Key Exchange message -*/ -void Client_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) - { - if(include_length) { - TLS_Data_Reader reader(buf); - key_material = reader.get_range<byte>(2, 0, 65535); + // No server key exchange msg better mean RSA kex + RSA key in cert + + if(kex_algo != "RSA") + throw Unexpected_Message("No server kex but negotiated kex " + kex_algo); + + if(peer_certs.empty()) + throw Internal_Error("No certificate and no server key exchange"); + + std::auto_ptr<Public_Key> pub_key(peer_certs[0].subject_public_key()); + + if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key.get())) + { + const Protocol_Version pref_version = state->client_hello->version(); + + pre_master = rng.random_vec(48); + pre_master[0] = pref_version.major_version(); + pre_master[1] = pref_version.minor_version(); + + PK_Encryptor_EME encryptor(*rsa_pub, "PKCS1v15"); + + MemoryVector<byte> encrypted_key = encryptor.encrypt(pre_master, rng); + + if(state->version() == Protocol_Version::SSL_V3) + key_material = encrypted_key; // no length field + else + append_tls_length_value(key_material, encrypted_key, 2); + } + else + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Expected a RSA key in server cert but got " + + pub_key->algo_name()); } - else - key_material = buf; + + state->hash.update(writer.send(*this)); } /* -* Return the pre_master_secret +* Read a Client Key Exchange message */ -SecureVector<byte> -Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, - const Private_Key* priv_key, - Version_Code version) +Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, + const Handshake_State* state, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng) { + const std::string kex_algo = state->suite.kex_algo(); - if(const DH_PrivateKey* dh_priv = dynamic_cast<const DH_PrivateKey*>(priv_key)) + if(kex_algo == "RSA") { - try { - PK_Key_Agreement ka(*dh_priv, "Raw"); + BOTAN_ASSERT(state->server_certs && !state->server_certs->cert_chain().empty(), + "No server certificate to use for RSA"); - pre_master = ka.derive_key(0, key_material).bits_of(); - } - catch(...) - { - /* - * Something failed in the DH computation. To avoid possible - * timing attacks, randomize the pre-master output and carry - * on, allowing the protocol to fail later in the finished - * checks. - */ - pre_master = rng.random_vec(dh_priv->public_value().size()); - } + const Private_Key* private_key = state->server_rsa_kex_key; - return pre_master; - } - else if(const RSA_PrivateKey* rsa_priv = dynamic_cast<const RSA_PrivateKey*>(priv_key)) - { - PK_Decryptor_EME decryptor(*rsa_priv, "PKCS1v15"); + if(!private_key) + throw Internal_Error("Expected RSA kex but no server kex key set"); + + if(!dynamic_cast<const RSA_PrivateKey*>(private_key)) + throw Internal_Error("Expected RSA key but got " + private_key->algo_name()); + + PK_Decryptor_EME decryptor(*private_key, "PKCS1v15"); - try { - pre_master = decryptor.decrypt(key_material); + Protocol_Version client_version = state->client_hello->version(); + + try + { + if(state->version() == Protocol_Version::SSL_V3) + { + pre_master = decryptor.decrypt(contents); + } + else + { + TLS_Data_Reader reader(contents); + pre_master = decryptor.decrypt(reader.get_range<byte>(2, 0, 65535)); + } if(pre_master.size() != 48 || - make_u16bit(pre_master[0], pre_master[1]) != version) + client_version.major_version() != pre_master[0] || + client_version.minor_version() != pre_master[1]) + { throw Decoding_Error("Client_Key_Exchange: Secret corrupted"); - } + } + } catch(...) { + // Randomize the hide timing channel pre_master = rng.random_vec(48); - pre_master[0] = (version >> 8) & 0xFF; - pre_master[1] = (version ) & 0xFF; + pre_master[0] = client_version.major_version(); + pre_master[1] = client_version.minor_version(); } - - return pre_master; } else - throw Invalid_Argument("Client_Key_Exchange: Bad key for decrypt"); + { + TLS_Data_Reader reader(contents); + + SymmetricKey psk; + + if(kex_algo == "PSK" || kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") + { + const std::string psk_identity = reader.get_string(2, 0, 65535); + + psk = creds.psk("tls-server", + state->client_hello->sni_hostname(), + psk_identity); + + if(psk.length() == 0) + { + if(policy.hide_unknown_users()) + psk = SymmetricKey(rng, 16); + else + throw TLS_Exception(Alert::UNKNOWN_PSK_IDENTITY, + "No PSK for identifier " + psk_identity); + } + + } + + if(kex_algo == "PSK") + { + MemoryVector<byte> zeros(psk.length()); + append_tls_length_value(pre_master, zeros, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); + } + else if(kex_algo == "DH" || kex_algo == "DHE_PSK" || + kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") + { + const Private_Key& private_key = state->server_kex->server_kex_key(); + + const PK_Key_Agreement_Key* ka_key = + dynamic_cast<const PK_Key_Agreement_Key*>(&private_key); + + if(!ka_key) + throw Internal_Error("Expected key agreement key type but got " + + private_key.algo_name()); + + try + { + PK_Key_Agreement ka(*ka_key, "Raw"); + + MemoryVector<byte> client_pubkey; + + if(ka_key->algo_name() == "DH") + client_pubkey = reader.get_range<byte>(2, 0, 65535); + else + client_pubkey = reader.get_range<byte>(1, 0, 255); + + SecureVector<byte> shared_secret = ka.derive_key(0, client_pubkey).bits_of(); + + if(ka_key->algo_name() == "DH") + shared_secret = strip_leading_zeros(shared_secret); + + if(kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") + { + append_tls_length_value(pre_master, shared_secret, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); + } + else + pre_master = shared_secret; + } + catch(std::exception &e) + { + /* + * Something failed in the DH computation. To avoid possible + * timing attacks, randomize the pre-master output and carry + * on, allowing the protocol to fail later in the finished + * checks. + */ + pre_master = rng.random_vec(ka_key->public_value().size()); + } + } + else + throw Internal_Error("Client_Key_Exchange: Unknown kex type " + kex_algo); + } } } + +} diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 78c786262..df70cc43d 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -1,84 +1,137 @@ /* * Certificate Request Message -* (C) 2004-2006 Jack Lloyd +* (C) 2004-2006,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/loadstor.h> -#include <botan/secqueue.h> namespace Botan { -/** -* Create a new Certificate Request message -*/ -Certificate_Req::Certificate_Req(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const std::vector<X509_Certificate>& ca_certs, - const std::vector<Certificate_Type>& cert_types) - { - for(size_t i = 0; i != ca_certs.size(); ++i) - names.push_back(ca_certs[i].subject_dn()); +namespace TLS { + +namespace { - if(cert_types.empty()) // default is RSA/DSA is OK +std::string cert_type_code_to_name(byte code) + { + switch(code) { - types.push_back(RSA_CERT); - types.push_back(DSS_CERT); + case 1: + return "RSA"; + case 2: + return "DSA"; + case 64: + return "ECDSA"; + default: + return ""; // DH or something else } - else - types = cert_types; + } - send(writer, hash); +byte cert_type_name_to_code(const std::string& name) + { + if(name == "RSA") + return 1; + if(name == "DSA") + return 2; + if(name == "ECDSA") + return 64; + + throw Invalid_Argument("Unknown cert type " + name); } +} + /** -* Serialize a Certificate Request message +* Create a new Certificate Request message */ -MemoryVector<byte> Certificate_Req::serialize() const +Certificate_Req::Certificate_Req(Record_Writer& writer, + Handshake_Hash& hash, + const Policy& policy, + const std::vector<X509_Certificate>& ca_certs, + Protocol_Version version) { - MemoryVector<byte> buf; + for(size_t i = 0; i != ca_certs.size(); ++i) + names.push_back(ca_certs[i].subject_dn()); - append_tls_length_value(buf, types, 1); + cert_key_types.push_back("RSA"); + cert_key_types.push_back("DSA"); + cert_key_types.push_back("ECDSA"); - DER_Encoder encoder; - for(size_t i = 0; i != names.size(); ++i) - encoder.encode(names[i]); + if(version >= Protocol_Version::TLS_V12) + { + std::vector<std::string> hashes = policy.allowed_hashes(); + std::vector<std::string> sigs = policy.allowed_signature_methods(); - append_tls_length_value(buf, encoder.get_contents(), 2); + for(size_t i = 0; i != hashes.size(); ++i) + for(size_t j = 0; j != sigs.size(); ++j) + m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); + } - return buf; + hash.update(writer.send(*this)); } /** * Deserialize a Certificate Request message */ -void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) +Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf, + Protocol_Version version) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); - size_t types_size = buf[0]; + TLS_Data_Reader reader(buf); - if(buf.size() < types_size + 3) - throw Decoding_Error("Certificate_Req: Bad certificate request"); + std::vector<byte> cert_type_codes = reader.get_range_vector<byte>(1, 1, 255); - for(size_t i = 0; i != types_size; ++i) - types.push_back(static_cast<Certificate_Type>(buf[i+1])); + for(size_t i = 0; i != cert_type_codes.size(); ++i) + { + const std::string cert_type_name = cert_type_code_to_name(cert_type_codes[i]); - size_t names_size = make_u16bit(buf[types_size+2], buf[types_size+3]); + if(cert_type_name == "") // something we don't know + continue; - if(buf.size() != names_size + types_size + 3) - throw Decoding_Error("Certificate_Req: Bad certificate request"); + cert_key_types.push_back(cert_type_name); + } + + if(version >= Protocol_Version::TLS_V12) + { + std::vector<byte> sig_hash_algs = reader.get_range_vector<byte>(2, 2, 65534); + + if(sig_hash_algs.size() % 2 != 0) + throw Decoding_Error("Bad length for signature IDs in certificate request"); + + for(size_t i = 0; i != sig_hash_algs.size(); i += 2) + { + std::string hash = Signature_Algorithms::hash_algo_name(sig_hash_algs[i]); + std::string sig = Signature_Algorithms::sig_algo_name(sig_hash_algs[i+1]); + m_supported_algos.push_back(std::make_pair(hash, sig)); + } + } + else + { + // The hardcoded settings from previous protocol versions + m_supported_algos.push_back(std::make_pair("TLS.Digest.0", "RSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "DSA")); + m_supported_algos.push_back(std::make_pair("SHA-1", "ECDSA")); + } + + u16bit purported_size = reader.get_u16bit(); - BER_Decoder decoder(&buf[types_size + 3], names_size); + if(reader.remaining_bytes() != purported_size) + throw Decoding_Error("Inconsistent length in certificate request"); - while(decoder.more_items()) + while(reader.has_remaining()) { + std::vector<byte> name_bits = reader.get_range_vector<byte>(2, 0, 65535); + + BER_Decoder decoder(&name_bits[0], name_bits.size()); X509_DN name; decoder.decode(name); names.push_back(name); @@ -86,71 +139,101 @@ void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) } /** -* Create a new Certificate message +* Serialize a Certificate Request message */ -Certificate::Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const std::vector<X509_Certificate>& cert_list) +MemoryVector<byte> Certificate_Req::serialize() const { - certs = cert_list; - send(writer, hash); - } + MemoryVector<byte> buf; -/** -* Serialize a Certificate message -*/ -MemoryVector<byte> Certificate::serialize() const - { - MemoryVector<byte> buf(3); + std::vector<byte> cert_types; + + for(size_t i = 0; i != cert_key_types.size(); ++i) + cert_types.push_back(cert_type_name_to_code(cert_key_types[i])); + + append_tls_length_value(buf, cert_types, 1); - for(size_t i = 0; i != certs.size(); ++i) + if(!m_supported_algos.empty()) { - MemoryVector<byte> raw_cert = certs[i].BER_encode(); - const size_t cert_size = raw_cert.size(); - for(size_t i = 0; i != 3; ++i) - buf.push_back(get_byte<u32bit>(i+1, cert_size)); - buf += raw_cert; + buf += Signature_Algorithms(m_supported_algos).serialize(); } - const size_t buf_size = buf.size() - 3; - for(size_t i = 0; i != 3; ++i) - buf[i] = get_byte<u32bit>(i+1, buf_size); + for(size_t i = 0; i != names.size(); ++i) + { + DER_Encoder encoder; + encoder.encode(names[i]); + + append_tls_length_value(buf, encoder.get_contents(), 2); + } return buf; } /** +* Create a new Certificate message +*/ +Certificate::Certificate(Record_Writer& writer, + Handshake_Hash& hash, + const std::vector<X509_Certificate>& cert_list) : + m_certs(cert_list) + { + hash.update(writer.send(*this)); + } + +/** * Deserialize a Certificate message */ -void Certificate::deserialize(const MemoryRegion<byte>& buf) +Certificate::Certificate(const MemoryRegion<byte>& buf) { if(buf.size() < 3) throw Decoding_Error("Certificate: Message malformed"); const size_t total_size = make_u32bit(0, buf[0], buf[1], buf[2]); - SecureQueue queue; - queue.write(&buf[3], buf.size() - 3); - - if(queue.size() != total_size) + if(total_size != buf.size() - 3) throw Decoding_Error("Certificate: Message malformed"); - while(queue.size()) + const byte* certs = &buf[3]; + + while(certs != buf.end()) { - if(queue.size() < 3) + if(buf.end() - certs < 3) throw Decoding_Error("Certificate: Message malformed"); - byte len[3]; - queue.read(len, 3); - - const size_t cert_size = make_u32bit(0, len[0], len[1], len[2]); - const size_t original_size = queue.size(); + const size_t cert_size = make_u32bit(0, certs[0], certs[1], certs[2]); - X509_Certificate cert(queue); - if(queue.size() + cert_size != original_size) + if(buf.end() - certs < (3 + cert_size)) throw Decoding_Error("Certificate: Message malformed"); - certs.push_back(cert); + + DataSource_Memory cert_buf(&certs[3], cert_size); + m_certs.push_back(X509_Certificate(cert_buf)); + + certs += cert_size + 3; } } +/** +* Serialize a Certificate message +*/ +MemoryVector<byte> Certificate::serialize() const + { + MemoryVector<byte> buf(3); + + for(size_t i = 0; i != m_certs.size(); ++i) + { + MemoryVector<byte> raw_cert = m_certs[i].BER_encode(); + const size_t cert_size = raw_cert.size(); + for(size_t i = 0; i != 3; ++i) + buf.push_back(get_byte<u32bit>(i+1, cert_size)); + buf += raw_cert; + } + + const size_t buf_size = buf.size() - 3; + for(size_t i = 0; i != 3; ++i) + buf[i] = get_byte<u32bit>(i+1, buf_size); + + return buf; + } + +} + } diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 5a20e3029..73acf3de1 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -1,50 +1,69 @@ /* * Certificate Verify Message -* (C) 2004-2011 Jack Lloyd +* (C) 2004,2006,2011,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/internal/assert.h> -#include <botan/tls_exceptn.h> -#include <botan/pubkey.h> -#include <botan/rsa.h> -#include <botan/dsa.h> -#include <botan/loadstor.h> #include <memory> namespace Botan { +namespace TLS { + /* * Create a new Certificate Verify message */ Certificate_Verify::Certificate_Verify(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, RandomNumberGenerator& rng, const Private_Key* priv_key) { BOTAN_ASSERT_NONNULL(priv_key); - std::string padding = ""; - Signature_Format format = IEEE_1363; + std::pair<std::string, Signature_Format> format = + state->choose_sig_format(priv_key, hash_algo, sig_algo, true); + + PK_Signer signer(*priv_key, format.first, format.second); - if(priv_key->algo_name() == "RSA") - padding = "EMSA3(TLS.Digest.0)"; - else if(priv_key->algo_name() == "DSA") + if(state->version() == Protocol_Version::SSL_V3) { - padding = "EMSA1(SHA-1)"; - format = DER_SEQUENCE; + SecureVector<byte> md5_sha = state->hash.final_ssl3( + state->keys.master_secret()); + + if(priv_key->algo_name() == "DSA") + signature = signer.sign_message(&md5_sha[16], md5_sha.size()-16, rng); + else + signature = signer.sign_message(md5_sha, rng); } else - throw Invalid_Argument(priv_key->algo_name() + - " is invalid/unknown for TLS signatures"); + { + signature = signer.sign_message(state->hash.get_contents(), rng); + } + + state->hash.update(writer.send(*this)); + } + +/* +* Deserialize a Certificate Verify message +*/ +Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf, + Protocol_Version version) + { + TLS_Data_Reader reader(buf); - PK_Signer signer(*priv_key, padding, format); + if(version >= Protocol_Version::TLS_V12) + { + hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte()); + sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte()); + } - signature = signer.sign_message(hash.final(), rng); - send(writer, hash); + signature = reader.get_range<byte>(2, 0, 65535); } /* @@ -54,6 +73,12 @@ MemoryVector<byte> Certificate_Verify::serialize() const { MemoryVector<byte> buf; + if(hash_algo != "" && sig_algo != "") + { + buf.push_back(Signature_Algorithms::hash_algo_code(hash_algo)); + buf.push_back(Signature_Algorithms::sig_algo_code(sig_algo)); + } + const u16bit sig_len = signature.size(); buf.push_back(get_byte(0, sig_len)); buf.push_back(get_byte(1, sig_len)); @@ -63,57 +88,30 @@ MemoryVector<byte> Certificate_Verify::serialize() const } /* -* Deserialize a Certificate Verify message -*/ -void Certificate_Verify::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - signature = reader.get_range<byte>(2, 0, 65535); - } - -/* * Verify a Certificate Verify message */ bool Certificate_Verify::verify(const X509_Certificate& cert, - TLS_Handshake_Hash& hash, - Version_Code version, - const SecureVector<byte>& master_secret) + Handshake_State* state) { std::auto_ptr<Public_Key> key(cert.subject_public_key()); - std::string padding = ""; - Signature_Format format = IEEE_1363; + std::pair<std::string, Signature_Format> format = + state->understand_sig_format(key.get(), hash_algo, sig_algo, true); - if(key->algo_name() == "RSA") - { - padding = "EMSA3(TLS.Digest.0)"; - } - else if(key->algo_name() == "DSA") - { - if(version == SSL_V3) - padding = "Raw"; - else - padding = "EMSA1(SHA-1)"; - format = DER_SEQUENCE; - } - else - throw Invalid_Argument(key->algo_name() + - " is invalid/unknown for TLS signatures"); - - PK_Verifier verifier(*key, padding, format); + PK_Verifier verifier(*key, format.first, format.second); - if(version == SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { - SecureVector<byte> md5_sha = hash.final_ssl3(master_secret); + SecureVector<byte> md5_sha = state->hash.final_ssl3( + state->keys.master_secret()); return verifier.verify_message(&md5_sha[16], md5_sha.size()-16, &signature[0], signature.size()); } - else if(version == TLS_V10 || version == TLS_V11) - return verifier.verify_message(hash.get_contents(), signature); - else - throw TLS_Exception(PROTOCOL_VERSION, - "Unknown TLS version in certificate verification"); + + return verifier.verify_message(state->hash.get_contents(), signature); } } + +} diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index 70b714bfd..a494bf932 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -1,71 +1,33 @@ /* * Finished Message -* (C) 2004-2006 Jack Lloyd +* (C) 2004-2006,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/internal/tls_messages.h> -#include <botan/prf_tls.h> +#include <botan/tls_record.h> +#include <memory> namespace Botan { -/* -* Create a new Finished message -*/ -Finished::Finished(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, - Connection_Side side, - const MemoryRegion<byte>& master_secret) - { - verification_data = compute_verify(master_secret, hash, side, version); - send(writer, hash); - } - -/* -* Serialize a Finished message -*/ -MemoryVector<byte> Finished::serialize() const - { - return verification_data; - } - -/* -* Deserialize a Finished message -*/ -void Finished::deserialize(const MemoryRegion<byte>& buf) - { - verification_data = buf; - } +namespace TLS { -/* -* Verify a Finished message -*/ -bool Finished::verify(const MemoryRegion<byte>& secret, - Version_Code version, - const TLS_Handshake_Hash& hash, - Connection_Side side) - { - MemoryVector<byte> computed = compute_verify(secret, hash, side, version); - if(computed == verification_data) - return true; - return false; - } +namespace { /* * Compute the verify_data */ -MemoryVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, - TLS_Handshake_Hash hash, - Connection_Side side, - Version_Code version) +MemoryVector<byte> finished_compute_verify(Handshake_State* state, + Connection_Side side) { - if(version == SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 }; const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 }; + Handshake_Hash hash = state->hash; // don't modify state + MemoryVector<byte> ssl3_finished; if(side == CLIENT) @@ -73,9 +35,9 @@ MemoryVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, else hash.update(SSL_SERVER_LABEL, sizeof(SSL_SERVER_LABEL)); - return hash.final_ssl3(secret); + return hash.final_ssl3(state->keys.master_secret()); } - else if(version == TLS_V10 || version == TLS_V11) + else { const byte TLS_CLIENT_LABEL[] = { 0x63, 0x6C, 0x69, 0x65, 0x6E, 0x74, 0x20, 0x66, 0x69, 0x6E, 0x69, @@ -85,19 +47,58 @@ MemoryVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x66, 0x69, 0x6E, 0x69, 0x73, 0x68, 0x65, 0x64 }; - TLS_PRF prf; + std::auto_ptr<KDF> prf(state->protocol_specific_prf()); MemoryVector<byte> input; if(side == CLIENT) input += std::make_pair(TLS_CLIENT_LABEL, sizeof(TLS_CLIENT_LABEL)); else input += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL)); - input += hash.final(); - return prf.derive_key(12, secret, input); + input += state->hash.final(state->version(), state->suite.mac_algo()); + + return prf->derive_key(12, state->keys.master_secret(), input); } - else - throw Invalid_Argument("Finished message: Unknown protocol version"); } } + +/* +* Create a new Finished message +*/ +Finished::Finished(Record_Writer& writer, + Handshake_State* state, + Connection_Side side) + { + verification_data = finished_compute_verify(state, side); + state->hash.update(writer.send(*this)); + } + +/* +* Serialize a Finished message +*/ +MemoryVector<byte> Finished::serialize() const + { + return verification_data; + } + +/* +* Deserialize a Finished message +*/ +Finished::Finished(const MemoryRegion<byte>& buf) + { + verification_data = buf; + } + +/* +* Verify a Finished message +*/ +bool Finished::verify(Handshake_State* state, + Connection_Side side) + { + return (verification_data == finished_compute_verify(state, side)); + } + +} + +} diff --git a/src/tls/info.txt b/src/tls/info.txt index 9473b0ae2..822914a3d 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -1,13 +1,14 @@ -define SSL_TLS +define TLS <comment> -The SSL/TLS code is complex, new, and not yet reviewed, there may be +The TLS code is complex, new, and not yet reviewed, there may be serious bugs or security issues. </comment> uses_tr1 yes <header:public> +tls_alert.h tls_channel.h tls_client.h tls_exceptn.h @@ -17,13 +18,14 @@ tls_record.h tls_server.h tls_session.h tls_session_manager.h -tls_suites.h +tls_ciphersuite.h +tls_version.h </header:public> <header:internal> -tls_alerts.h tls_extensions.h tls_handshake_hash.h +tls_handshake_reader.h tls_handshake_state.h tls_messages.h tls_reader.h @@ -31,11 +33,13 @@ tls_session_key.h </header:internal> <source> +tls_alert.cpp c_hello.cpp c_kex.cpp cert_req.cpp cert_ver.cpp finished.cpp +hello_verify.cpp next_protocol.cpp rec_read.cpp rec_wri.cpp @@ -45,13 +49,15 @@ tls_channel.cpp tls_client.cpp tls_extensions.cpp tls_handshake_hash.cpp +tls_handshake_reader.cpp tls_handshake_state.cpp tls_policy.cpp tls_server.cpp tls_session.cpp tls_session_key.cpp tls_session_manager.cpp -tls_suites.cpp +tls_ciphersuite.cpp +tls_version.cpp </source> <requires> @@ -61,6 +67,8 @@ asn1 des dh dsa +ecdh +ecdsa eme_pkcs emsa3 filters @@ -70,7 +78,9 @@ prf_ssl3 prf_tls rng rsa +seed sha1 +sha2_32 ssl3mac x509cert </requires> diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index 2d2e2e599..17b77fb6e 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -8,15 +8,27 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> +#include <botan/tls_record.h> namespace Botan { +namespace TLS { + Next_Protocol::Next_Protocol(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const std::string& protocol) : m_protocol(protocol) { - send(writer, hash); + hash.update(writer.send(*this)); + } + +Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf) + { + TLS_Data_Reader reader(buf); + + m_protocol = reader.get_string(1, 0, 255); + + reader.get_range_vector<byte>(1, 0, 255); // padding, ignored } MemoryVector<byte> Next_Protocol::serialize() const @@ -38,13 +50,6 @@ MemoryVector<byte> Next_Protocol::serialize() const return buf; } -void Next_Protocol::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - - m_protocol = reader.get_string(1, 0, 255); - - reader.get_range_vector<byte>(1, 0, 255); // padding, ignored - } +} } diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index f8e6bab26..d1fab4692 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -14,6 +14,8 @@ namespace Botan { +namespace TLS { + Record_Reader::Record_Reader() : m_readbuf(TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE), m_mac(0) @@ -39,7 +41,7 @@ void Record_Reader::reset() m_block_size = 0; m_iv_size = 0; - m_major = m_minor = 0; + m_version = Protocol_Version(); m_seq_no = 0; set_maximum_fragment_size(0); } @@ -55,27 +57,27 @@ void Record_Reader::set_maximum_fragment_size(size_t max_fragment) /* * Set the version to use */ -void Record_Reader::set_version(Version_Code version) +void Record_Reader::set_version(Protocol_Version version) { - if(version != SSL_V3 && version != TLS_V10 && version != TLS_V11) - throw Invalid_Argument("Record_Reader: Invalid protocol version"); - - m_major = (version >> 8) & 0xFF; - m_minor = (version & 0xFF); + m_version = version; } /* * Set the keys for reading */ -void Record_Reader::activate(const TLS_Cipher_Suite& suite, - const SessionKeys& keys, - Connection_Side side) +void Record_Reader::activate(Connection_Side side, + const Ciphersuite& suite, + const Session_Keys& keys, + byte compression_method) { m_cipher.reset(); delete m_mac; m_mac = 0; m_seq_no = 0; + if(compression_method != NO_COMPRESSION) + throw Internal_Error("Negotiated unknown compression algorithm"); + SymmetricKey mac_key, cipher_key; InitializationVector iv; @@ -103,7 +105,7 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite, ); m_block_size = block_size_of(cipher_algo); - if(m_major > 3 || (m_major == 3 && m_minor >= 2)) + if(m_version >= Protocol_Version::TLS_V11) m_iv_size = m_block_size; else m_iv_size = 0; @@ -121,7 +123,7 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite, { Algorithm_Factory& af = global_state().algorithm_factory(); - if(m_major == 3 && m_minor == 0) + if(m_version == Protocol_Version::SSL_V3) m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")"); else m_mac = af.make_mac("HMAC(" + mac_algo + ")"); @@ -145,7 +147,7 @@ size_t Record_Reader::fill_buffer_to(const byte*& input, const size_t taken = std::min(input_size, desired - m_readbuf_pos); if(taken > space_available) - throw TLS_Exception(RECORD_OVERFLOW, + throw TLS_Exception(Alert::RECORD_OVERFLOW, "Record is larger than allowed maximum size"); copy_mem(&m_readbuf[m_readbuf_pos], input, taken); @@ -182,7 +184,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, if((!m_mac) && (m_readbuf[0] & 0x80) && (m_readbuf[2] == 1)) { if(m_readbuf[3] == 0 && m_readbuf[4] == 2) - throw TLS_Exception(PROTOCOL_VERSION, + throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client claims to only support SSLv2, rejecting"); if(m_readbuf[3] >= 3) // SSLv2 mapped TLS hello, then? @@ -216,20 +218,24 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, m_readbuf[0] != HANDSHAKE && m_readbuf[0] != APPLICATION_DATA) { - throw TLS_Exception(UNEXPECTED_MESSAGE, - "Unknown record type " + to_string(m_readbuf[0]) + - " from counterparty"); + throw Unexpected_Message( + "Unknown record type " + to_string(m_readbuf[0]) + " from counterparty"); } - const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]); const size_t record_len = make_u16bit(m_readbuf[3], m_readbuf[4]); - if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor)) - throw TLS_Exception(PROTOCOL_VERSION, - "Got unexpected version from counterparty"); + if(m_version.major_version()) + { + if(m_readbuf[1] != m_version.major_version() || + m_readbuf[2] != m_version.minor_version()) + { + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Got unexpected version from counterparty"); + } + } if(record_len > MAX_CIPHERTEXT_SIZE) - throw TLS_Exception(RECORD_OVERFLOW, + throw TLS_Exception(Alert::RECORD_OVERFLOW, "Got message that exceeds maximum size"); if(size_t needed = fill_buffer_to(input, input_sz, consumed, @@ -247,7 +253,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, m_readbuf[0] != ALERT && m_readbuf[0] != HANDSHAKE) { - throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake"); + throw Decoding_Error("Invalid msg type received during handshake"); } msg_type = m_readbuf[0]; @@ -283,7 +289,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, * This particular countermeasure is recommended in the TLS 1.2 * spec (RFC 5246) in section 6.2.3.2 */ - if(version == SSL_V3) + if(m_version == Protocol_Version::SSL_V3) { if(pad_value > m_block_size) pad_size = 0; @@ -309,14 +315,16 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, const u16bit plain_length = record_len - mac_pad_iv_size; if(plain_length > m_max_fragment) - throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large"); + throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); m_mac->update_be(m_seq_no); m_mac->update(m_readbuf[0]); // msg_type - if(version != SSL_V3) - for(size_t i = 0; i != 2; ++i) - m_mac->update(get_byte(i, version)); + if(m_version != Protocol_Version::SSL_V3) + { + m_mac->update(m_version.major_version()); + m_mac->update(m_version.minor_version()); + } m_mac->update_be(plain_length); m_mac->update(&m_readbuf[TLS_HEADER_SIZE + m_iv_size], plain_length); @@ -328,7 +336,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, const size_t mac_offset = record_len - (m_macbuf.size() + pad_size); if(!same_mem(&m_readbuf[TLS_HEADER_SIZE + mac_offset], &m_macbuf[0], m_macbuf.size())) - throw TLS_Exception(BAD_RECORD_MAC, "Message authentication failure"); + throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); msg_type = m_readbuf[0]; @@ -339,3 +347,5 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, } } + +} diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp index fdcb98cc9..cc7c6f79a 100644 --- a/src/tls/rec_wri.cpp +++ b/src/tls/rec_wri.cpp @@ -6,6 +6,7 @@ */ #include <botan/tls_record.h> +#include <botan/internal/tls_messages.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_hash.h> #include <botan/lookup.h> @@ -16,6 +17,8 @@ namespace Botan { +namespace TLS { + /* * Record_Writer Constructor */ @@ -46,8 +49,7 @@ void Record_Writer::reset() delete m_mac; m_mac = 0; - m_major = 0; - m_minor = 0; + m_version = Protocol_Version(); m_block_size = 0; m_mac_size = 0; m_iv_size = 0; @@ -58,26 +60,26 @@ void Record_Writer::reset() /* * Set the version to use */ -void Record_Writer::set_version(Version_Code version) +void Record_Writer::set_version(Protocol_Version version) { - if(version != SSL_V3 && version != TLS_V10 && version != TLS_V11) - throw Invalid_Argument("Record_Writer: Invalid protocol version"); - - m_major = (version >> 8) & 0xFF; - m_minor = (version & 0xFF); + m_version = version; } /* * Set the keys for writing */ -void Record_Writer::activate(const TLS_Cipher_Suite& suite, - const SessionKeys& keys, - Connection_Side side) +void Record_Writer::activate(Connection_Side side, + const Ciphersuite& suite, + const Session_Keys& keys, + byte compression_method) { m_cipher.reset(); delete m_mac; m_mac = 0; + if(compression_method != NO_COMPRESSION) + throw Internal_Error("Negotiated unknown compression algorithm"); + /* RFC 4346: A sequence number is incremented after each record: specifically, @@ -113,7 +115,7 @@ void Record_Writer::activate(const TLS_Cipher_Suite& suite, ); m_block_size = block_size_of(cipher_algo); - if(m_major > 3 || (m_major == 3 && m_minor >= 2)) + if(m_version >= Protocol_Version::TLS_V11) m_iv_size = m_block_size; else m_iv_size = 0; @@ -131,7 +133,7 @@ void Record_Writer::activate(const TLS_Cipher_Suite& suite, { Algorithm_Factory& af = global_state().algorithm_factory(); - if(m_major == 3 && m_minor == 0) + if(m_version == Protocol_Version::SSL_V3) m_mac = af.make_mac("SSL3-MAC(" + mac_algo + ")"); else m_mac = af.make_mac("HMAC(" + mac_algo + ")"); @@ -143,6 +145,25 @@ void Record_Writer::activate(const TLS_Cipher_Suite& suite, throw Invalid_Argument("Record_Writer: Unknown hash " + mac_algo); } +MemoryVector<byte> Record_Writer::send(Handshake_Message& msg) + { + const MemoryVector<byte> buf = msg.serialize(); + MemoryVector<byte> send_buf(4); + + const size_t buf_size = buf.size(); + + send_buf[0] = msg.type(); + + for(size_t i = 1; i != 4; ++i) + send_buf[i] = get_byte<u32bit>(i, buf_size); + + send_buf += buf; + + send(HANDSHAKE, &send_buf[0], send_buf.size()); + + return send_buf; + } + /* * Send one or more records to the other side */ @@ -185,105 +206,112 @@ void Record_Writer::send(byte type, const byte input[], size_t length) void Record_Writer::send_record(byte type, const byte input[], size_t length) { if(length >= MAX_PLAINTEXT_SIZE) - throw TLS_Exception(INTERNAL_ERROR, - "Record_Writer: Compressed packet is too big"); + throw Internal_Error("Record_Writer: Compressed packet is too big"); - if(m_mac_size == 0) + if(m_mac_size == 0) // initial unencrypted handshake records { const byte header[TLS_HEADER_SIZE] = { type, - m_major, - m_minor, + m_version.major_version(), + m_version.minor_version(), get_byte<u16bit>(0, length), get_byte<u16bit>(1, length) }; m_output_fn(header, TLS_HEADER_SIZE); m_output_fn(input, length); + return; } - else + + m_mac->update_be(m_seq_no); + m_mac->update(type); + + if(m_version != Protocol_Version::SSL_V3) { - m_mac->update_be(m_seq_no); - m_mac->update(type); + m_mac->update(m_version.major_version()); + m_mac->update(m_version.minor_version()); + } - if(m_major > 3 || (m_major == 3 && m_minor != 0)) - { - m_mac->update(m_major); - m_mac->update(m_minor); - } + m_mac->update(get_byte<u16bit>(0, length)); + m_mac->update(get_byte<u16bit>(1, length)); + m_mac->update(input, length); - m_mac->update(get_byte<u16bit>(0, length)); - m_mac->update(get_byte<u16bit>(1, length)); - m_mac->update(input, length); + const size_t buf_size = round_up(m_iv_size + length + + m_mac->output_length() + + (m_block_size ? 1 : 0), + m_block_size); - const size_t buf_size = round_up(m_iv_size + length + - m_mac->output_length() + - (m_block_size ? 1 : 0), - m_block_size); + if(buf_size >= MAX_CIPHERTEXT_SIZE) + throw Internal_Error("Record_Writer: Record is too big"); - if(buf_size >= MAX_CIPHERTEXT_SIZE) - throw TLS_Exception(INTERNAL_ERROR, - "Record_Writer: Record is too big"); + BOTAN_ASSERT(m_writebuf.size() >= TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE, + "Write buffer is big enough"); - BOTAN_ASSERT(m_writebuf.size() >= TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE, - "Write buffer is big enough"); + // TLS record header + m_writebuf[0] = type; + m_writebuf[1] = m_version.major_version(); + m_writebuf[2] = m_version.minor_version(); + m_writebuf[3] = get_byte<u16bit>(0, buf_size); + m_writebuf[4] = get_byte<u16bit>(1, buf_size); - // TLS record header - m_writebuf[0] = type; - m_writebuf[1] = m_major; - m_writebuf[2] = m_minor; - m_writebuf[3] = get_byte<u16bit>(0, buf_size); - m_writebuf[4] = get_byte<u16bit>(1, buf_size); + byte* buf_write_ptr = &m_writebuf[TLS_HEADER_SIZE]; - byte* buf_write_ptr = &m_writebuf[TLS_HEADER_SIZE]; + if(m_iv_size) + { + RandomNumberGenerator& rng = global_state().global_rng(); + rng.randomize(buf_write_ptr, m_iv_size); + buf_write_ptr += m_iv_size; + } - if(m_iv_size) - { - RandomNumberGenerator& rng = global_state().global_rng(); - rng.randomize(buf_write_ptr, m_iv_size); - buf_write_ptr += m_iv_size; - } + copy_mem(buf_write_ptr, input, length); + buf_write_ptr += length; - copy_mem(buf_write_ptr, input, length); - buf_write_ptr += length; + m_mac->final(buf_write_ptr); + buf_write_ptr += m_mac->output_length(); - m_mac->final(buf_write_ptr); - buf_write_ptr += m_mac->output_length(); + if(m_block_size) + { + const size_t pad_val = + buf_size - (m_iv_size + length + m_mac->output_length() + 1); - if(m_block_size) + for(size_t i = 0; i != pad_val + 1; ++i) { - const size_t pad_val = - buf_size - (m_iv_size + length + m_mac->output_length() + 1); - - for(size_t i = 0; i != pad_val + 1; ++i) - { - *buf_write_ptr = pad_val; - buf_write_ptr += 1; - } + *buf_write_ptr = pad_val; + buf_write_ptr += 1; } + } - // FIXME: this could be done in-place without copying - m_cipher.process_msg(&m_writebuf[TLS_HEADER_SIZE], buf_size); - const size_t got_back = m_cipher.read(&m_writebuf[TLS_HEADER_SIZE], buf_size, Pipe::LAST_MESSAGE); + // FIXME: this could be done in-place without copying + m_cipher.process_msg(&m_writebuf[TLS_HEADER_SIZE], buf_size); - BOTAN_ASSERT_EQUAL(got_back, buf_size, "Cipher encrypted full amount"); + const size_t ctext_size = m_cipher.remaining(Pipe::LAST_MESSAGE); - BOTAN_ASSERT_EQUAL(m_cipher.remaining(Pipe::LAST_MESSAGE), 0, - "No data remains in pipe"); + BOTAN_ASSERT_EQUAL(ctext_size, buf_size, "Cipher encrypted full amount"); - m_output_fn(&m_writebuf[0], TLS_HEADER_SIZE + buf_size); + if(ctext_size > MAX_CIPHERTEXT_SIZE) + throw Internal_Error("Produced ciphertext larger than protocol allows"); - m_seq_no++; - } + m_cipher.read(&m_writebuf[TLS_HEADER_SIZE], ctext_size, Pipe::LAST_MESSAGE); + + BOTAN_ASSERT_EQUAL(m_cipher.remaining(Pipe::LAST_MESSAGE), 0, + "No data remains in pipe"); + + m_output_fn(&m_writebuf[0], TLS_HEADER_SIZE + buf_size); + + m_seq_no++; } /* * Send an alert */ -void Record_Writer::alert(Alert_Level level, Alert_Type type) +void Record_Writer::send_alert(const Alert& alert) { - byte alert[2] = { level, type }; - send(ALERT, alert, sizeof(alert)); + const byte alert_bits[2] = { alert.is_fatal() ? 2 : 1, + alert.type() }; + + send(ALERT, alert_bits, sizeof(alert_bits)); } } + +} diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 90e18ae90..9bcbdb5e9 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -14,15 +14,17 @@ namespace Botan { +namespace TLS { + /* * Create a new Server Hello message */ Server_Hello::Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, + Handshake_Hash& hash, + Protocol_Version version, const Client_Hello& c_hello, - const std::vector<X509_Certificate>& certs, - const TLS_Policy& policy, + const std::vector<std::string>& available_cert_types, + const Policy& policy, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, bool client_has_npn, @@ -37,36 +39,28 @@ Server_Hello::Server_Hello(Record_Writer& writer, m_next_protocol(client_has_npn), m_next_protocols(next_protocols) { - bool have_rsa = false, have_dsa = false; - - for(size_t i = 0; i != certs.size(); ++i) - { - Public_Key* key = certs[i].subject_public_key(); - if(key->algo_name() == "RSA") - have_rsa = true; - - if(key->algo_name() == "DSA") - have_dsa = true; - } - - suite = policy.choose_suite(c_hello.ciphersuites(), have_rsa, have_dsa, false); + suite = policy.choose_suite( + c_hello.ciphersuites(), + available_cert_types, + policy.choose_curve(c_hello.supported_ecc_curves()) != "", + false); if(suite == 0) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Can't agree on a ciphersuite with client"); comp_method = policy.choose_compression(c_hello.compression_methods()); - send(writer, hash); + hash.update(writer.send(*this)); } /* * Create a new Server Hello message */ Server_Hello::Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const MemoryRegion<byte>& session_id, - Version_Code ver, + Protocol_Version ver, u16bit ciphersuite, byte compression, size_t max_fragment_size, @@ -86,47 +80,13 @@ Server_Hello::Server_Hello(Record_Writer& writer, m_next_protocol(client_has_npn), m_next_protocols(next_protocols) { - send(writer, hash); - } - -/* -* Serialize a Server Hello message -*/ -MemoryVector<byte> Server_Hello::serialize() const - { - MemoryVector<byte> buf; - - buf.push_back(static_cast<byte>(s_version >> 8)); - buf.push_back(static_cast<byte>(s_version )); - buf += s_random; - - append_tls_length_value(buf, m_session_id, 1); - - buf.push_back(get_byte(0, suite)); - buf.push_back(get_byte(1, suite)); - - buf.push_back(comp_method); - - TLS_Extensions extensions; - - if(m_secure_renegotiation) - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - - if(m_fragment_size != 0) - extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); - - if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); - - buf += extensions.serialize(); - - return buf; + hash.update(writer.send(*this)); } /* * Deserialize a Server Hello message */ -void Server_Hello::deserialize(const MemoryRegion<byte>& buf) +Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) { m_secure_renegotiation = false; m_next_protocol = false; @@ -136,11 +96,17 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf) TLS_Data_Reader reader(buf); - s_version = static_cast<Version_Code>(reader.get_u16bit()); + const byte major_version = reader.get_byte(); + const byte minor_version = reader.get_byte(); + + s_version = Protocol_Version(major_version, minor_version); - if(s_version != SSL_V3 && s_version != TLS_V10 && s_version != TLS_V11) + if(s_version != Protocol_Version::SSL_V3 && + s_version != Protocol_Version::TLS_V10 && + s_version != Protocol_Version::TLS_V11 && + s_version != Protocol_Version::TLS_V12) { - throw TLS_Exception(PROTOCOL_VERSION, + throw TLS_Exception(Alert::PROTOCOL_VERSION, "Server_Hello: Unsupported server version"); } @@ -152,50 +118,82 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf) comp_method = reader.get_byte(); - TLS_Extensions extensions(reader); + Extensions extensions(reader); - for(size_t i = 0; i != extensions.count(); ++i) + if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>()) { - TLS_Extension* extn = extensions.at(i); - - if(Renegotation_Extension* reneg = dynamic_cast<Renegotation_Extension*>(extn)) - { - // checked by TLS_Client / TLS_Server as they know the handshake state - m_secure_renegotiation = true; - m_renegotiation_info = reneg->renegotiation_info(); - } - else if(Next_Protocol_Notification* npn = dynamic_cast<Next_Protocol_Notification*>(extn)) - { - m_next_protocols = npn->protocols(); - m_next_protocol = true; - } + // checked by Client / Server as they know the handshake state + m_secure_renegotiation = true; + m_renegotiation_info = reneg->renegotiation_info(); + } + + if(Next_Protocol_Notification* npn = extensions.get<Next_Protocol_Notification>()) + { + m_next_protocols = npn->protocols(); + m_next_protocol = true; } } /* -* Create a new Server Hello Done message +* Serialize a Server Hello message */ -Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, - TLS_Handshake_Hash& hash) +MemoryVector<byte> Server_Hello::serialize() const { - send(writer, hash); + MemoryVector<byte> buf; + + buf.push_back(s_version.major_version()); + buf.push_back(s_version.minor_version()); + buf += s_random; + + append_tls_length_value(buf, m_session_id, 1); + + buf.push_back(get_byte(0, suite)); + buf.push_back(get_byte(1, suite)); + + buf.push_back(comp_method); + + Extensions extensions; + + if(m_secure_renegotiation) + extensions.add(new Renegotation_Extension(m_renegotiation_info)); + + if(m_fragment_size != 0) + extensions.add(new Maximum_Fragment_Length(m_fragment_size)); + + if(m_next_protocol) + extensions.add(new Next_Protocol_Notification(m_next_protocols)); + + buf += extensions.serialize(); + + return buf; } /* -* Serialize a Server Hello Done message +* Create a new Server Hello Done message */ -MemoryVector<byte> Server_Hello_Done::serialize() const +Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, + Handshake_Hash& hash) { - return MemoryVector<byte>(); + hash.update(writer.send(*this)); } /* * Deserialize a Server Hello Done message */ -void Server_Hello_Done::deserialize(const MemoryRegion<byte>& buf) +Server_Hello_Done::Server_Hello_Done(const MemoryRegion<byte>& buf) { if(buf.size()) throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); } +/* +* Serialize a Server Hello Done message +*/ +MemoryVector<byte> Server_Hello_Done::serialize() const + { + return MemoryVector<byte>(); + } + +} + } diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 150f13474..6707d2611 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -1,181 +1,235 @@ /* * Server Key Exchange Message -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2010,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> +#include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> +#include <botan/internal/assert.h> +#include <botan/credentials_manager.h> +#include <botan/loadstor.h> #include <botan/pubkey.h> #include <botan/dh.h> +#include <botan/ecdh.h> #include <botan/rsa.h> -#include <botan/dsa.h> -#include <botan/loadstor.h> +#include <botan/oids.h> #include <memory> namespace Botan { +namespace TLS { + /** * Create a new Server Key Exchange message */ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, + const Policy& policy, + Credentials_Manager& creds, RandomNumberGenerator& rng, - const Public_Key* kex_key, - const Private_Key* priv_key, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random) + const Private_Key* signing_key) : + m_kex_key(0) { - const DH_PublicKey* dh_pub = dynamic_cast<const DH_PublicKey*>(kex_key); - const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(kex_key); + const std::string kex_algo = state->suite.kex_algo(); - if(dh_pub) + if(kex_algo == "PSK" || kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") { - params.push_back(dh_pub->get_domain().get_p()); - params.push_back(dh_pub->get_domain().get_g()); - params.push_back(BigInt::decode(dh_pub->public_value())); + std::string identity_hint = + creds.psk_identity_hint("tls-server", + state->client_hello->sni_hostname()); + + append_tls_length_value(m_params, identity_hint, 2); } - else if(rsa_pub) + + if(kex_algo == "DH" || kex_algo == "DHE_PSK") { - params.push_back(rsa_pub->get_n()); - params.push_back(rsa_pub->get_e()); + std::auto_ptr<DH_PrivateKey> dh(new DH_PrivateKey(rng, policy.dh_group())); + + append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_p()), 2); + append_tls_length_value(m_params, BigInt::encode(dh->get_domain().get_g()), 2); + append_tls_length_value(m_params, dh->public_value(), 2); + m_kex_key = dh.release(); } - else - throw Invalid_Argument("Bad key for TLS key exchange: not DH or RSA"); + else if(kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") + { + const std::vector<std::string>& curves = + state->client_hello->supported_ecc_curves(); - // FIXME: cut and paste - std::string padding = ""; - Signature_Format format = IEEE_1363; + if(curves.empty()) + throw Internal_Error("Client sent no ECC extension but we negotiated ECDH"); - if(priv_key->algo_name() == "RSA") - padding = "EMSA3(TLS.Digest.0)"; - else if(priv_key->algo_name() == "DSA") - { - padding = "EMSA1(SHA-1)"; - format = DER_SEQUENCE; - } - else - throw Invalid_Argument(priv_key->algo_name() + - " is invalid/unknown for TLS signatures"); + const std::string curve_name = policy.choose_curve(curves); - PK_Signer signer(*priv_key, padding, format); + if(curve_name == "") + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Could not agree on an ECC curve with the client"); - signer.update(c_random); - signer.update(s_random); - signer.update(serialize_params()); - signature = signer.signature(rng); + EC_Group ec_group(curve_name); - send(writer, hash); - } + std::auto_ptr<ECDH_PrivateKey> ecdh(new ECDH_PrivateKey(rng, ec_group)); -/** -* Serialize a Server Key Exchange message -*/ -MemoryVector<byte> Server_Key_Exchange::serialize() const - { - MemoryVector<byte> buf = serialize_params(); - append_tls_length_value(buf, signature, 2); - return buf; - } + const std::string ecdh_domain_oid = ecdh->domain().get_oid(); + const std::string domain = OIDS::lookup(OID(ecdh_domain_oid)); -/** -* Serialize the ServerParams structure -*/ -MemoryVector<byte> Server_Key_Exchange::serialize_params() const - { - MemoryVector<byte> buf; + if(domain == "") + throw Internal_Error("Could not find name of ECDH domain " + ecdh_domain_oid); - for(size_t i = 0; i != params.size(); ++i) - append_tls_length_value(buf, BigInt::encode(params[i]), 2); + const u16bit named_curve_id = Supported_Elliptic_Curves::name_to_curve_id(domain); - return buf; + m_params.push_back(3); // named curve + m_params.push_back(get_byte(0, named_curve_id)); + m_params.push_back(get_byte(1, named_curve_id)); + + append_tls_length_value(m_params, ecdh->public_value(), 1); + + m_kex_key = ecdh.release(); + } + else if(kex_algo != "PSK") + throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo); + + if(state->suite.sig_algo() != "") + { + BOTAN_ASSERT(signing_key, "No signing key set"); + + std::pair<std::string, Signature_Format> format = + state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false); + + PK_Signer signer(*signing_key, format.first, format.second); + + signer.update(state->client_hello->random()); + signer.update(state->server_hello->random()); + signer.update(params()); + m_signature = signer.signature(rng); + } + + state->hash.update(writer.send(*this)); } /** * Deserialize a Server Key Exchange message */ -void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) +Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf, + const std::string& kex_algo, + const std::string& sig_algo, + Protocol_Version version) : + m_kex_key(0) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); - MemoryVector<byte> values[4]; - size_t so_far = 0; + TLS_Data_Reader reader(buf); + + /* + * We really are just serializing things back to what they were + * before, but unfortunately to know where the signature is we need + * to be able to parse the whole thing anyway. + */ + + if(kex_algo == "PSK" || kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") + { + const std::string identity_hint = reader.get_string(2, 0, 65535); + append_tls_length_value(m_params, identity_hint, 2); + } + + if(kex_algo == "DH" || kex_algo == "DHE_PSK") + { + // 3 bigints, DH p, g, Y - for(size_t i = 0; i != 4; ++i) + for(size_t i = 0; i != 3; ++i) + { + BigInt v = BigInt::decode(reader.get_range<byte>(2, 1, 65535)); + append_tls_length_value(m_params, BigInt::encode(v), 2); + } + } + else if(kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") { - const u16bit len = make_u16bit(buf[so_far], buf[so_far+1]); - so_far += 2; + const byte curve_type = reader.get_byte(); + + if(curve_type != 3) + throw Decoding_Error("Server_Key_Exchange: Server sent non-named ECC curve"); - if(len + so_far > buf.size()) - throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); + const u16bit curve_id = reader.get_u16bit(); - values[i].resize(len); - copy_mem(&values[i][0], &buf[so_far], len); - so_far += len; + const std::string name = Supported_Elliptic_Curves::curve_id_to_name(curve_id); - if(i == 2 && so_far == buf.size()) - break; + MemoryVector<byte> ecdh_key = reader.get_range<byte>(1, 1, 255); + + if(name == "") + throw Decoding_Error("Server_Key_Exchange: Server sent unknown named curve " + + to_string(curve_id)); + + m_params.push_back(curve_type); + m_params.push_back(get_byte(0, curve_id)); + m_params.push_back(get_byte(1, curve_id)); + append_tls_length_value(m_params, ecdh_key, 1); } + else if(kex_algo != "PSK") + throw Decoding_Error("Server_Key_Exchange: Unsupported kex type " + kex_algo); - params.push_back(BigInt::decode(values[0])); - params.push_back(BigInt::decode(values[1])); - if(values[3].size()) + if(sig_algo != "") { - params.push_back(BigInt::decode(values[2])); - signature = values[3]; + if(version >= Protocol_Version::TLS_V12) + { + m_hash_algo = Signature_Algorithms::hash_algo_name(reader.get_byte()); + m_sig_algo = Signature_Algorithms::sig_algo_name(reader.get_byte()); + } + + m_signature = reader.get_range<byte>(2, 0, 65535); } - else - signature = values[2]; } + /** -* Return the public key +* Serialize a Server Key Exchange message */ -Public_Key* Server_Key_Exchange::key() const +MemoryVector<byte> Server_Key_Exchange::serialize() const { - if(params.size() == 2) - return new RSA_PublicKey(params[0], params[1]); - else if(params.size() == 3) - return new DH_PublicKey(DL_Group(params[0], params[1]), params[2]); - else - throw Internal_Error("Server_Key_Exchange::key: No key set"); + MemoryVector<byte> buf = params(); + + if(m_signature.size()) + { + // This should be an explicit version check + if(m_hash_algo != "" && m_sig_algo != "") + { + buf.push_back(Signature_Algorithms::hash_algo_code(m_hash_algo)); + buf.push_back(Signature_Algorithms::sig_algo_code(m_sig_algo)); + } + + append_tls_length_value(buf, m_signature, 2); + } + + return buf; } /** * Verify a Server Key Exchange message */ bool Server_Key_Exchange::verify(const X509_Certificate& cert, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random) const + Handshake_State* state) const { - std::auto_ptr<Public_Key> key(cert.subject_public_key()); - // FIXME: cut and paste - std::string padding = ""; - Signature_Format format = IEEE_1363; + std::pair<std::string, Signature_Format> format = + state->understand_sig_format(key.get(), m_hash_algo, m_sig_algo, false); - if(key->algo_name() == "RSA") - padding = "EMSA3(TLS.Digest.0)"; - else if(key->algo_name() == "DSA") - { - padding = "EMSA1(SHA-1)"; - format = DER_SEQUENCE; - } - else - throw Invalid_Argument(key->algo_name() + - " is invalid/unknown for TLS signatures"); + PK_Verifier verifier(*key, format.first, format.second); - PK_Verifier verifier(*key, padding, format); + verifier.update(state->client_hello->random()); + verifier.update(state->server_hello->random()); + verifier.update(params()); - MemoryVector<byte> params_got = serialize_params(); - verifier.update(c_random); - verifier.update(s_random); - verifier.update(params_got); + return verifier.check_signature(m_signature); + } - return verifier.check_signature(signature); +const Private_Key& Server_Key_Exchange::server_kex_key() const + { + BOTAN_ASSERT(m_kex_key, "Key is non-NULL"); + return *m_kex_key; } +} } diff --git a/src/tls/tls_alert.cpp b/src/tls/tls_alert.cpp new file mode 100644 index 000000000..b526eeac3 --- /dev/null +++ b/src/tls/tls_alert.cpp @@ -0,0 +1,115 @@ +/* +* Alert Message +* (C) 2004-2006,2011 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_alert.h> +#include <botan/exceptn.h> + +namespace Botan { + +namespace TLS { + +Alert::Alert(const MemoryRegion<byte>& buf) + { + if(buf.size() != 2) + throw Decoding_Error("Alert: Bad size " + to_string(buf.size()) + + " for alert message"); + + if(buf[0] == 1) fatal = false; + else if(buf[0] == 2) fatal = true; + else + throw Decoding_Error("Alert: Bad code for alert level"); + + const byte dc = buf[1]; + + /* + * This is allowed by the specification but is not allocated and we're + * using it internally as a special 'no alert' type. + */ + if(dc == 255) + throw Internal_Error("Alert: description code 255, rejecting"); + + type_code = static_cast<Type>(dc); + } + +std::string Alert::type_string() const + { + switch(type()) + { + case CLOSE_NOTIFY: + return "close_notify"; + case UNEXPECTED_MESSAGE: + return "unexpected_message"; + case BAD_RECORD_MAC: + return "bad_record_mac"; + case DECRYPTION_FAILED: + return "decryption_failed"; + case RECORD_OVERFLOW: + return "record_overflow"; + case DECOMPRESSION_FAILURE: + return "decompression_failure"; + case HANDSHAKE_FAILURE: + return "handshake_failure"; + case NO_CERTIFICATE: + return "no_certificate"; + case BAD_CERTIFICATE: + return "bad_certificate"; + case UNSUPPORTED_CERTIFICATE: + return "unsupported_certificate"; + case CERTIFICATE_REVOKED: + return "certificate_revoked"; + case CERTIFICATE_EXPIRED: + return "certificate_expired"; + case CERTIFICATE_UNKNOWN: + return "certificate_unknown"; + case ILLEGAL_PARAMETER: + return "illegal_parameter"; + case UNKNOWN_CA: + return "unknown_ca"; + case ACCESS_DENIED: + return "access_denied"; + case DECODE_ERROR: + return "decode_error"; + case DECRYPT_ERROR: + return "decrypt_error"; + case EXPORT_RESTRICTION: + return "export_restriction"; + case PROTOCOL_VERSION: + return "protocol_version"; + case INSUFFICIENT_SECURITY: + return "insufficient_security"; + case INTERNAL_ERROR: + return "internal_error"; + case USER_CANCELED: + return "user_canceled"; + case NO_RENEGOTIATION: + return "no_renegotiation"; + + case UNSUPPORTED_EXTENSION: + return "unsupported_extension"; + case UNRECOGNIZED_NAME: + return "unrecognized_name"; + + case UNKNOWN_PSK_IDENTITY: + return "unknown_psk_identity"; + + case NULL_ALERT: + return ""; + } + + /* + * This is effectively the default case for the switch above, but we + * leave it out so that when an alert type is added to the enum the + * compiler can warn us that it is not included in the switch + * statement. + */ + return "unrecognized_alert_" + to_string(type()); + } + + +} + +} diff --git a/src/tls/tls_alert.h b/src/tls/tls_alert.h new file mode 100644 index 000000000..0446a8c30 --- /dev/null +++ b/src/tls/tls_alert.h @@ -0,0 +1,97 @@ +/* +* Alert Message +* (C) 2004-2006,2011 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_ALERT_H__ +#define BOTAN_TLS_ALERT_H__ + +#include <botan/secmem.h> +#include <string> + +namespace Botan { + +namespace TLS { + +/** +* SSL/TLS Alert Message +*/ +class BOTAN_DLL Alert + { + public: + enum Type { + CLOSE_NOTIFY = 0, + UNEXPECTED_MESSAGE = 10, + BAD_RECORD_MAC = 20, + DECRYPTION_FAILED = 21, + RECORD_OVERFLOW = 22, + DECOMPRESSION_FAILURE = 30, + HANDSHAKE_FAILURE = 40, + NO_CERTIFICATE = 41, // SSLv3 only + BAD_CERTIFICATE = 42, + UNSUPPORTED_CERTIFICATE = 43, + CERTIFICATE_REVOKED = 44, + CERTIFICATE_EXPIRED = 45, + CERTIFICATE_UNKNOWN = 46, + ILLEGAL_PARAMETER = 47, + UNKNOWN_CA = 48, + ACCESS_DENIED = 49, + DECODE_ERROR = 50, + DECRYPT_ERROR = 51, + EXPORT_RESTRICTION = 60, + PROTOCOL_VERSION = 70, + INSUFFICIENT_SECURITY = 71, + INTERNAL_ERROR = 80, + USER_CANCELED = 90, + NO_RENEGOTIATION = 100, + + UNSUPPORTED_EXTENSION = 110, + UNRECOGNIZED_NAME = 112, + + UNKNOWN_PSK_IDENTITY = 115, + + NULL_ALERT = 255 + }; + + /** + * @return true iff this alert is non-empty + */ + bool is_valid() const { return (type_code != NULL_ALERT); } + + /** + * @return if this alert is a fatal one or not + */ + bool is_fatal() const { return fatal; } + + /** + * @return type of alert + */ + Type type() const { return type_code; } + + /** + * @return type of alert + */ + std::string type_string() const; + + /** + * Deserialize an Alert message + * @param buf the serialized alert + */ + Alert(const MemoryRegion<byte>& buf); + + Alert(Type alert_type, bool is_fatal = false) : + fatal(is_fatal), type_code(alert_type) {} + + Alert() : fatal(false), type_code(NULL_ALERT) {} + private: + bool fatal; + Type type_code; + }; + +} + +} + +#endif diff --git a/src/tls/tls_alerts.h b/src/tls/tls_alerts.h deleted file mode 100644 index 0634d6763..000000000 --- a/src/tls/tls_alerts.h +++ /dev/null @@ -1,60 +0,0 @@ -/* -* Alert Message -* (C) 2004-2006,2011 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#ifndef BOTAN_TLS_ALERT_H__ -#define BOTAN_TLS_ALERT_H__ - -#include <botan/tls_exceptn.h> - -namespace Botan { - -/** -* SSL/TLS Alert Message -*/ -class Alert - { - public: - /** - * @return if this alert is a fatal one or not - */ - bool is_fatal() const { return fatal; } - - /** - * @return type of alert - */ - Alert_Type type() const { return type_code; } - - /** - * Deserialize an Alert message - * @param buf the serialized alert - */ - Alert(const MemoryRegion<byte>& buf) - { - if(buf.size() != 2) - throw Decoding_Error("Alert: Bad size " + to_string(buf.size()) + - " for alert message"); - - if(buf[0] == 1) fatal = false; - else if(buf[0] == 2) fatal = true; - else - throw Decoding_Error("Alert: Bad code for alert level"); - - const byte dc = buf[1]; - - if(dc == 255) - throw Decoding_Error("Alert: description code 255, rejecting"); - - type_code = static_cast<Alert_Type>(dc); - } - private: - bool fatal; - Alert_Type type_code; - }; - -} - -#endif diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 46c6d36cd..f45ce4bda 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -6,16 +6,18 @@ */ #include <botan/tls_channel.h> -#include <botan/internal/tls_alerts.h> #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> #include <botan/internal/assert.h> #include <botan/loadstor.h> namespace Botan { -TLS_Channel::TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_complete) : +namespace TLS { + +Channel::Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_complete) : proc_fn(proc_fn), handshake_fn(handshake_complete), writer(socket_output_fn), @@ -25,13 +27,13 @@ TLS_Channel::TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_ { } -TLS_Channel::~TLS_Channel() +Channel::~Channel() { delete state; state = 0; } -size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) +size_t Channel::received_data(const byte buf[], size_t buf_size) { try { @@ -64,7 +66,7 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) * following record. Avoid spurious callbacks. */ if(record.size() > 0) - proc_fn(&record[0], record.size(), NULL_ALERT); + proc_fn(&record[0], record.size(), Alert()); } else { @@ -79,16 +81,16 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) { Alert alert_msg(record); - alert_notify(alert_msg.is_fatal(), alert_msg.type()); + alert_notify(alert_msg); - proc_fn(0, 0, alert_msg.type()); + proc_fn(0, 0, alert_msg); - if(alert_msg.type() == CLOSE_NOTIFY) + if(alert_msg.type() == Alert::CLOSE_NOTIFY) { if(connection_closed) reader.reset(); else - alert(WARNING, CLOSE_NOTIFY); // reply in kind + send_alert(Alert(Alert::CLOSE_NOTIFY)); // reply in kind } else if(alert_msg.is_fatal()) { @@ -111,17 +113,22 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) } catch(TLS_Exception& e) { - alert(FATAL, e.type()); + send_alert(Alert(e.type(), true)); throw; } catch(Decoding_Error& e) { - alert(FATAL, DECODE_ERROR); + send_alert(Alert(Alert::DECODE_ERROR, true)); + throw; + } + catch(Internal_Error& e) + { + send_alert(Alert(Alert::INTERNAL_ERROR, true)); throw; } catch(std::exception& e) { - alert(FATAL, INTERNAL_ERROR); + send_alert(Alert(Alert::INTERNAL_ERROR, true)); throw; } } @@ -129,80 +136,68 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) /* * Split up and process handshake messages */ -void TLS_Channel::read_handshake(byte rec_type, - const MemoryRegion<byte>& rec_buf) +void Channel::read_handshake(byte rec_type, + const MemoryRegion<byte>& rec_buf) { if(rec_type == HANDSHAKE) { if(!state) - state = new Handshake_State; - state->queue.write(&rec_buf[0], rec_buf.size()); + state = new Handshake_State(new Stream_Handshake_Reader); + state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size()); } + BOTAN_ASSERT(state, "Handshake message recieved without state in place"); + while(true) { Handshake_Type type = HANDSHAKE_NONE; - MemoryVector<byte> contents; if(rec_type == HANDSHAKE) { - if(state->queue.size() >= 4) + if(state->handshake_reader()->have_full_record()) { - byte head[4] = { 0 }; - state->queue.peek(head, 4); - - const size_t length = make_u32bit(0, head[1], head[2], head[3]); - - if(state->queue.size() >= length + 4) - { - type = static_cast<Handshake_Type>(head[0]); - contents.resize(length); - state->queue.read(head, 4); - state->queue.read(&contents[0], contents.size()); - } + std::pair<Handshake_Type, MemoryVector<byte> > msg = + state->handshake_reader()->get_next_record(); + process_handshake_msg(msg.first, msg.second); } + else + break; } else if(rec_type == CHANGE_CIPHER_SPEC) { - if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1) - type = HANDSHAKE_CCS; + if(state->handshake_reader()->empty() && rec_buf.size() == 1 && rec_buf[0] == 1) + process_handshake_msg(HANDSHAKE_CCS, MemoryVector<byte>()); else throw Decoding_Error("Malformed ChangeCipherSpec message"); } else throw Decoding_Error("Unknown message type in handshake processing"); - if(type == HANDSHAKE_NONE) - break; - - process_handshake_msg(type, contents); - - if(type == HANDSHAKE_CCS || !state) + if(type == HANDSHAKE_CCS || !state || !state->handshake_reader()->have_full_record()) break; } } -void TLS_Channel::queue_for_sending(const byte buf[], size_t buf_size) +void Channel::send(const byte buf[], size_t buf_size) { - if(!handshake_completed) - throw std::runtime_error("Application data cannot be queued before handshake"); + if(!is_active()) + throw std::runtime_error("Data cannot be sent on inactive TLS connection"); writer.send(APPLICATION_DATA, buf, buf_size); } -void TLS_Channel::alert(Alert_Level alert_level, Alert_Type alert_code) +void Channel::send_alert(const Alert& alert) { - if(alert_code != NULL_ALERT && !connection_closed) + if(alert.is_valid() && !connection_closed) { try { - writer.alert(alert_level, alert_code); + writer.send_alert(alert); } catch(...) { /* swallow it */ } } - if(!connection_closed && - (alert_code == CLOSE_NOTIFY || alert_level == FATAL)) + if(!connection_closed && (alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal())) { connection_closed = true; @@ -213,7 +208,7 @@ void TLS_Channel::alert(Alert_Level alert_level, Alert_Type alert_code) } } -void TLS_Channel::Secure_Renegotiation_State::update(Client_Hello* client_hello) +void Channel::Secure_Renegotiation_State::update(Client_Hello* client_hello) { if(initial_handshake) { @@ -222,7 +217,7 @@ void TLS_Channel::Secure_Renegotiation_State::update(Client_Hello* client_hello) else { if(secure_renegotiation != client_hello->secure_renegotiation()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client changed its mind about secure renegotiation"); } @@ -233,19 +228,19 @@ void TLS_Channel::Secure_Renegotiation_State::update(Client_Hello* client_hello) if(initial_handshake) { if(!data.empty()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client sent renegotiation data on initial handshake"); } else { if(data != for_client_hello()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client sent bad renegotiation data"); } } } -void TLS_Channel::Secure_Renegotiation_State::update(Server_Hello* server_hello) +void Channel::Secure_Renegotiation_State::update(Server_Hello* server_hello) { if(initial_handshake) { @@ -257,7 +252,7 @@ void TLS_Channel::Secure_Renegotiation_State::update(Server_Hello* server_hello) else { if(secure_renegotiation != server_hello->secure_renegotiation()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server changed its mind about secure renegotiation"); } @@ -268,13 +263,13 @@ void TLS_Channel::Secure_Renegotiation_State::update(Server_Hello* server_hello) if(initial_handshake) { if(!data.empty()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server sent renegotiation data on initial handshake"); } else { if(data != for_server_hello()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server sent bad renegotiation data"); } } @@ -282,7 +277,7 @@ void TLS_Channel::Secure_Renegotiation_State::update(Server_Hello* server_hello) initial_handshake = false; } -void TLS_Channel::Secure_Renegotiation_State::update(Finished* client_finished, +void Channel::Secure_Renegotiation_State::update(Finished* client_finished, Finished* server_finished) { client_verify = client_finished->verify_data(); @@ -290,3 +285,5 @@ void TLS_Channel::Secure_Renegotiation_State::update(Finished* client_finished, } } + +} diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 0306d1a74..53af0bdfc 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -11,15 +11,18 @@ #include <botan/tls_policy.h> #include <botan/tls_record.h> #include <botan/tls_session.h> +#include <botan/tls_alert.h> #include <botan/x509cert.h> #include <vector> namespace Botan { +namespace TLS { + /** * Generic interface for TLS endpoint */ -class BOTAN_DLL TLS_Channel +class BOTAN_DLL Channel { public: /** @@ -32,12 +35,12 @@ class BOTAN_DLL TLS_Channel /** * Inject plaintext intended for counterparty */ - virtual void queue_for_sending(const byte buf[], size_t buf_size); + virtual void send(const byte buf[], size_t buf_size); /** * Send a close notification alert */ - void close() { alert(WARNING, CLOSE_NOTIFY); } + void close() { send_alert(Alert(Alert::CLOSE_NOTIFY)); } /** * @return true iff the connection is active for sending application data @@ -59,11 +62,11 @@ class BOTAN_DLL TLS_Channel */ std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; } - TLS_Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_complete); + Channel(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_complete); - virtual ~TLS_Channel(); + virtual ~Channel(); protected: /** @@ -72,7 +75,7 @@ class BOTAN_DLL TLS_Channel * @param level is warning or fatal * @param type is the type of alert */ - void alert(Alert_Level level, Alert_Type type); + void send_alert(const Alert& alert); virtual void read_handshake(byte rec_type, const MemoryRegion<byte>& rec_buf); @@ -80,10 +83,10 @@ class BOTAN_DLL TLS_Channel virtual void process_handshake_msg(Handshake_Type type, const MemoryRegion<byte>& contents) = 0; - virtual void alert_notify(bool fatal_alert, Alert_Type type) = 0; + virtual void alert_notify(const Alert& alert) = 0; - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn; - std::tr1::function<bool (const TLS_Session&)> handshake_fn; + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn; + std::tr1::function<bool (const Session&)> handshake_fn; Record_Writer writer; Record_Reader reader; @@ -131,4 +134,6 @@ class BOTAN_DLL TLS_Channel } +} + #endif diff --git a/src/tls/tls_ciphersuite.cpp b/src/tls/tls_ciphersuite.cpp new file mode 100644 index 000000000..247948464 --- /dev/null +++ b/src/tls/tls_ciphersuite.cpp @@ -0,0 +1,347 @@ +/* +* TLS Cipher Suites +* (C) 2004-2010,2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_ciphersuite.h> +#include <botan/tls_magic.h> +#include <botan/parsing.h> +#include <sstream> +#include <stdexcept> + +namespace Botan { + +namespace TLS { + +/** +* Convert an SSL/TLS ciphersuite to algorithm fields +*/ +Ciphersuite Ciphersuite::by_id(u16bit suite) + { + switch(static_cast<Ciphersuite_Code>(suite)) + { + // RSA ciphersuites + + case TLS_RSA_WITH_AES_128_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "AES-128", 16); + + case TLS_RSA_WITH_AES_256_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "AES-256", 32); + + case TLS_RSA_WITH_AES_128_CBC_SHA256: + return Ciphersuite("RSA", "RSA", "SHA-256", "AES-128", 16); + + case TLS_RSA_WITH_AES_256_CBC_SHA256: + return Ciphersuite("RSA", "RSA", "SHA-256", "AES-256", 32); + + case TLS_RSA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "3DES", 24); + + case TLS_RSA_WITH_RC4_128_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "ARC4", 16); + + case TLS_RSA_WITH_RC4_128_MD5: + return Ciphersuite("RSA", "RSA", "MD5", "ARC4", 16); + + case TLS_RSA_WITH_CAMELLIA_128_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "Camellia", 16); + + case TLS_RSA_WITH_CAMELLIA_256_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "Camellia", 32); + + case TLS_RSA_WITH_SEED_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "SEED", 16); + +#if defined(BOTAN_HAS_IDEA) + case TLS_RSA_WITH_IDEA_CBC_SHA: + return Ciphersuite("RSA", "RSA", "SHA-1", "IDEA", 16); +#endif + + // DH/DSS ciphersuites + + case TLS_DHE_DSS_WITH_AES_128_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "AES-128", 16); + + case TLS_DHE_DSS_WITH_AES_256_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "AES-256", 32); + + case TLS_DHE_DSS_WITH_AES_128_CBC_SHA256: + return Ciphersuite("DSA", "DH", "SHA-256", "AES-128", 16); + + case TLS_DHE_DSS_WITH_AES_256_CBC_SHA256: + return Ciphersuite("DSA", "DH", "SHA-256", "AES-256", 32); + + case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "3DES", 24); + + case TLS_DHE_DSS_WITH_RC4_128_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "ARC4", 16); + + case TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "Camellia", 16); + + case TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "Camellia", 32); + + case TLS_DHE_DSS_WITH_SEED_CBC_SHA: + return Ciphersuite("DSA", "DH", "SHA-1", "SEED", 16); + + // DH/RSA ciphersuites + + case TLS_DHE_RSA_WITH_AES_128_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "AES-128", 16); + + case TLS_DHE_RSA_WITH_AES_256_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "AES-256", 32); + + case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256: + return Ciphersuite("RSA", "DH", "SHA-256", "AES-128", 16); + + case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: + return Ciphersuite("RSA", "DH", "SHA-256", "AES-256", 32); + + case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "3DES", 24); + + case TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "Camellia", 16); + + case TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "Camellia", 32); + + case TLS_DHE_RSA_WITH_SEED_CBC_SHA: + return Ciphersuite("RSA", "DH", "SHA-1", "SEED", 16); + + // ECDH/RSA ciphersuites + case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: + return Ciphersuite("RSA", "ECDH", "SHA-1", "AES-128", 16); + + case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + return Ciphersuite("RSA", "ECDH", "SHA-1", "AES-256", 32); + + case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: + return Ciphersuite("RSA", "ECDH", "SHA-256", "AES-128", 16); + + case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384: + return Ciphersuite("RSA", "ECDH", "SHA-384", "AES-256", 32); + + case TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("RSA", "ECDH", "SHA-1", "3DES", 24); + + case TLS_ECDHE_RSA_WITH_RC4_128_SHA: + return Ciphersuite("RSA", "ECDH", "SHA-1", "ARC4", 16); + + // ECDH/ECDSA ciphersuites + + case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: + return Ciphersuite("ECDSA", "ECDH", "SHA-1", "AES-128", 16); + + case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: + return Ciphersuite("ECDSA", "ECDH", "SHA-1", "AES-256", 32); + + case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: + return Ciphersuite("ECDSA", "ECDH", "SHA-256", "AES-128", 16); + + case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: + return Ciphersuite("ECDSA", "ECDH", "SHA-384", "AES-256", 32); + + case TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: + return Ciphersuite("ECDSA", "ECDH", "SHA-1", "ARC4", 16); + + case TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("ECDSA", "ECDH", "SHA-1", "3DES", 24); + + // PSK ciphersuites + + case TLS_PSK_WITH_RC4_128_SHA: + return Ciphersuite("", "PSK", "SHA-1", "ARC4", 16); + + case TLS_PSK_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("", "PSK", "SHA-1", "3DES", 24); + + case TLS_PSK_WITH_AES_128_CBC_SHA: + return Ciphersuite("", "PSK", "SHA-1", "AES-128", 16); + + case TLS_PSK_WITH_AES_128_CBC_SHA256: + return Ciphersuite("", "PSK", "SHA-256", "AES-128", 16); + + case TLS_PSK_WITH_AES_256_CBC_SHA: + return Ciphersuite("", "PSK", "SHA-1", "AES-256", 32); + + case TLS_PSK_WITH_AES_256_CBC_SHA384: + return Ciphersuite("", "PSK", "SHA-384", "AES-256", 32); + + // PSK+DH ciphersuites + + case TLS_DHE_PSK_WITH_RC4_128_SHA: + return Ciphersuite("", "DHE_PSK", "SHA-1", "ARC4", 16); + + case TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("", "DHE_PSK", "SHA-1", "3DES", 24); + + case TLS_DHE_PSK_WITH_AES_128_CBC_SHA: + return Ciphersuite("", "DHE_PSK", "SHA-1", "AES-128", 16); + + case TLS_DHE_PSK_WITH_AES_128_CBC_SHA256: + return Ciphersuite("", "DHE_PSK", "SHA-256", "AES-128", 16); + + case TLS_DHE_PSK_WITH_AES_256_CBC_SHA: + return Ciphersuite("", "DHE_PSK", "SHA-1", "AES-256", 32); + + case TLS_DHE_PSK_WITH_AES_256_CBC_SHA384: + return Ciphersuite("", "DHE_PSK", "SHA-384", "AES-256", 32); + + // PSK+ECDH ciphersuites + + case TLS_ECDHE_PSK_WITH_RC4_128_SHA: + return Ciphersuite("", "ECDHE_PSK", "SHA-1", "ARC4", 16); + + case TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("", "ECDHE_PSK", "SHA-1", "3DES", 24); + + case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA: + return Ciphersuite("", "ECDHE_PSK", "SHA-1", "AES-128", 16); + + case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: + return Ciphersuite("", "ECDHE_PSK", "SHA-256", "AES-128", 16); + + case TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA: + return Ciphersuite("", "ECDHE_PSK", "SHA-1", "AES-256", 32); + + case TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384: + return Ciphersuite("", "ECDHE_PSK", "SHA-384", "AES-256", 32); + + // SRP ciphersuites + + case TLS_SRP_SHA_WITH_AES_128_CBC_SHA: + return Ciphersuite("", "SRP", "SHA-1", "AES-128", 16); + + case TLS_SRP_SHA_WITH_AES_256_CBC_SHA: + return Ciphersuite("", "SRP", "SHA-1", "AES-256", 32); + + case TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("", "SRP", "SHA-1", "3DES", 24); + + // SRP/RSA ciphersuites + + case TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA: + return Ciphersuite("RSA", "SRP", "SHA-1", "AES-128", 16); + + case TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA: + return Ciphersuite("RSA", "SRP", "SHA-1", "AES-256", 32); + + case TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("RSA", "SRP", "SHA-1", "3DES", 24); + + // SRP/DSA ciphersuites + + case TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA: + return Ciphersuite("DSA", "SRP", "SHA-1", "AES-128", 16); + + case TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA: + return Ciphersuite("DSA", "SRP", "SHA-1", "AES-256", 32); + + case TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA: + return Ciphersuite("DSA", "SRP", "SHA-1", "3DES", 24); + + // Signaling ciphersuite values + + case TLS_EMPTY_RENEGOTIATION_INFO_SCSV: + return Ciphersuite(); + } + + return Ciphersuite(); // some unknown ciphersuite + } + +Ciphersuite Ciphersuite::by_name(const std::string& name) + { + for(size_t i = 0; i != 65536; ++i) + { + Ciphersuite suite = Ciphersuite::by_id(i); + + if(!suite.valid()) + continue; // not a ciphersuite we know, skip + + if(suite.to_string() == name) + return suite; + } + + return Ciphersuite(); // some unknown ciphersuite + } + +std::string Ciphersuite::to_string() const + { + if(m_cipher_keylen == 0) + throw std::runtime_error("Ciphersuite::to_string - no value set"); + + std::ostringstream out; + + out << "TLS_"; + + if(kex_algo() != "RSA") + { + if(kex_algo() == "DH") + out << "DHE"; + else if(kex_algo() == "ECDH") + out << "ECDHE"; + else if(kex_algo() == "SRP") + out << "SRP_SHA"; + else + out << kex_algo(); + + out << '_'; + } + + if(sig_algo() == "DSA") + out << "DSS_"; + else if(sig_algo() != "") + out << sig_algo() << '_'; + + out << "WITH_"; + + if(cipher_algo() == "ARC4") + { + out << "RC4_128_"; + } + else + { + if(cipher_algo() == "3DES") + out << "3DES_EDE"; + else if(cipher_algo() == "Camellia") + out << "CAMELLIA_" << Botan::to_string(8*cipher_keylen()); + else + out << replace_char(cipher_algo(), '-', '_'); + + out << "_CBC_"; + } + + if(mac_algo() == "SHA-1") + out << "SHA"; + else if(mac_algo() == "SHA-256") + out << "SHA256"; + else if(mac_algo() == "SHA-384") + out << "SHA384"; + else + out << mac_algo(); + + return out.str(); + } + +Ciphersuite::Ciphersuite(const std::string& sig_algo, + const std::string& kex_algo, + const std::string& mac_algo, + const std::string& cipher_algo, + size_t cipher_algo_keylen) : + m_sig_algo(sig_algo), + m_kex_algo(kex_algo), + m_mac_algo(mac_algo), + m_cipher_algo(cipher_algo), + m_cipher_keylen(cipher_algo_keylen) + { + } + +} + +} diff --git a/src/tls/tls_ciphersuite.h b/src/tls/tls_ciphersuite.h new file mode 100644 index 000000000..e5d8c967b --- /dev/null +++ b/src/tls/tls_ciphersuite.h @@ -0,0 +1,59 @@ +/* +* TLS Cipher Suites +* (C) 2004-2011 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_CIPHER_SUITES_H__ +#define BOTAN_TLS_CIPHER_SUITES_H__ + +#include <botan/types.h> +#include <string> + +namespace Botan { + +namespace TLS { + +/** +* Ciphersuite Information +*/ +class BOTAN_DLL Ciphersuite + { + public: + static Ciphersuite by_id(u16bit suite); + + static Ciphersuite by_name(const std::string& name); + + /** + * Formats the ciphersuite back to an RFC-style ciphersuite string + */ + std::string to_string() const; + + std::string kex_algo() const { return m_kex_algo; } + std::string sig_algo() const { return m_sig_algo; } + + std::string cipher_algo() const { return m_cipher_algo; } + std::string mac_algo() const { return m_mac_algo; } + + size_t cipher_keylen() const { return m_cipher_keylen; } + + bool valid() const { return (m_cipher_keylen > 0); } + + Ciphersuite() : m_cipher_keylen(0) {} + + Ciphersuite(const std::string& sig_algo, + const std::string& kex_algo, + const std::string& mac_algo, + const std::string& cipher_algo, + size_t cipher_algo_keylen); + private: + std::string m_sig_algo, m_kex_algo, m_mac_algo, m_cipher_algo; + size_t m_cipher_keylen; + }; + +} + +} + +#endif diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 7abcdf644..02e24a1c9 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -6,36 +6,36 @@ */ #include <botan/tls_client.h> -#include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> #include <botan/internal/stl_util.h> -#include <botan/rsa.h> -#include <botan/dsa.h> -#include <botan/dh.h> +#include <memory> namespace Botan { +namespace TLS { + /* * TLS Client Constructor */ -TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_fn, - TLS_Session_Manager& session_manager, - Credentials_Manager& creds, - const TLS_Policy& policy, - RandomNumberGenerator& rng, - const std::string& hostname, - std::tr1::function<std::string (std::vector<std::string>)> next_protocol) : - TLS_Channel(output_fn, proc_fn, handshake_fn), +Client::Client(std::tr1::function<void (const byte[], size_t)> output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_fn, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::string& hostname, + std::tr1::function<std::string (std::vector<std::string>)> next_protocol) : + Channel(output_fn, proc_fn, handshake_fn), policy(policy), rng(rng), session_manager(session_manager), creds(creds) { - writer.set_version(SSL_V3); + writer.set_version(Protocol_Version::SSL_V3); - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(SERVER_HELLO); state->client_npn_cb = next_protocol; @@ -46,7 +46,7 @@ TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> output_fn if(hostname != "") { - TLS_Session session_info; + Session session_info; if(session_manager.load_from_host_info(hostname, 0, session_info)) { if(session_info.srp_identifier() == srp_identifier) @@ -82,12 +82,12 @@ TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> output_fn /* * Send a new client hello to renegotiate */ -void TLS_Client::renegotiate() +void Client::renegotiate() { if(state) return; // currently in handshake - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(SERVER_HELLO); state->client_hello = new Client_Hello(writer, state->hash, policy, rng, @@ -96,9 +96,9 @@ void TLS_Client::renegotiate() secure_renegotiation.update(state->client_hello); } -void TLS_Client::alert_notify(bool, Alert_Type type) +void Client::alert_notify(const Alert& alert) { - if(type == NO_RENEGOTIATION) + if(alert.type() == Alert::NO_RENEGOTIATION) { if(handshake_completed && state) { @@ -111,7 +111,7 @@ void TLS_Client::alert_notify(bool, Alert_Type type) /* * Process a handshake message */ -void TLS_Client::process_handshake_msg(Handshake_Type type, +void Client::process_handshake_msg(Handshake_Type type, const MemoryRegion<byte>& contents) { if(state == 0) @@ -131,7 +131,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, state = 0; // RFC 5746 section 4.2 - alert(WARNING, NO_RENEGOTIATION); + send_alert(Alert(Alert::NO_RENEGOTIATION)); return; } @@ -155,32 +155,32 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(!state->client_hello->offered_suite(state->server_hello->ciphersuite())) { - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server replied with ciphersuite we didn't send"); } if(!value_exists(state->client_hello->compression_methods(), state->server_hello->compression_method())) { - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server replied with compression method we didn't send"); } if(!state->client_hello->next_protocol_notification() && state->server_hello->next_protocol_notification()) { - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server sent next protocol but we didn't request it"); } - state->version = state->server_hello->version(); + state->set_version(state->server_hello->version()); - writer.set_version(state->version); - reader.set_version(state->version); + writer.set_version(state->version()); + reader.set_version(state->version()); secure_renegotiation.update(state->server_hello); - state->suite = TLS_Cipher_Suite(state->server_hello->ciphersuite()); + state->suite = Ciphersuite::by_id(state->server_hello->ciphersuite()); if(!state->server_hello->session_id().empty() && (state->server_hello->session_id() == state->client_hello->session_id())) @@ -188,18 +188,16 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, // successful resumption /* - * In this case, we offered the original session and the server - * must resume with it + * In this case, we offered the version used in the original + * session, and the server must resume with the same version. */ if(state->server_hello->version() != state->client_hello->version()) - throw TLS_Exception(HANDSHAKE_FAILURE, + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server resumed session but with wrong version"); - state->keys = SessionKeys(state->suite, state->version, - state->resume_master_secret, - state->client_hello->random(), - state->server_hello->random(), - true); + state->keys = Session_Keys(state, + state->resume_master_secret, + true); state->set_expected_next(HANDSHAKE_CCS); } @@ -207,23 +205,36 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, { // new session - if(state->version > state->client_hello->version()) + if(state->version() > state->client_hello->version()) { - throw TLS_Exception(HANDSHAKE_FAILURE, - "TLS_Client: Server replied with bad version"); + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Client: Server replied with bad version"); } - if(state->version < policy.min_version()) + if(state->version() < policy.min_version()) { - throw TLS_Exception(PROTOCOL_VERSION, - "TLS_Client: Server is too old for specified policy"); + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Client: Server is too old for specified policy"); } - if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) + if(state->suite.sig_algo() != "") { state->set_expected_next(CERTIFICATE); } - else if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + else if(state->suite.kex_algo() == "PSK") + { + /* PSK is anonymous so no certificate/cert req message is + ever sent. The server may or may not send a server kex, + depending on if it has an identity hint for us. + + DHE_PSK always sends a server key exchange for the DH + exchange portion. + */ + + state->set_expected_next(SERVER_KEX); + state->set_expected_next(SERVER_HELLO_DONE); + } + else if(state->suite.kex_algo() != "RSA") { state->set_expected_next(SERVER_KEX); } @@ -236,7 +247,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } else if(type == CERTIFICATE) { - if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + if(state->suite.kex_algo() != "RSA") { state->set_expected_next(SERVER_KEX); } @@ -250,28 +261,23 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, peer_certs = state->server_certs->cert_chain(); if(peer_certs.size() == 0) - throw TLS_Exception(HANDSHAKE_FAILURE, - "TLS_Client: No certificates sent by server"); + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Client: No certificates sent by server"); - if(!policy.check_cert(peer_certs)) - throw TLS_Exception(BAD_CERTIFICATE, - "TLS_Client: Server certificate is not valid"); - - state->kex_pub = peer_certs[0].subject_public_key(); + try + { + const std::string hostname = state->client_hello->sni_hostname(); + creds.verify_certificate_chain("tls-client", hostname, peer_certs); + } + catch(std::exception& e) + { + throw TLS_Exception(Alert::BAD_CERTIFICATE, e.what()); + } - bool is_dsa = false, is_rsa = false; + std::auto_ptr<Public_Key> peer_key(peer_certs[0].subject_public_key()); - if(dynamic_cast<DSA_PublicKey*>(state->kex_pub)) - is_dsa = true; - else if(dynamic_cast<RSA_PublicKey*>(state->kex_pub)) - is_rsa = true; - else - throw TLS_Exception(UNSUPPORTED_CERTIFICATE, - "Unknown key type received in server kex"); - - if((is_dsa && state->suite.sig_type() != TLS_ALGO_SIGNER_DSA) || - (is_rsa && state->suite.sig_type() != TLS_ALGO_SIGNER_RSA)) - throw TLS_Exception(ILLEGAL_PARAMETER, + if(peer_key->algo_name() != state->suite.sig_algo()) + throw TLS_Exception(Alert::ILLEGAL_PARAMETER, "Certificate key type did not match ciphersuite"); } else if(type == SERVER_KEX) @@ -279,33 +285,24 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, state->set_expected_next(CERTIFICATE_REQUEST); // optional state->set_expected_next(SERVER_HELLO_DONE); - state->server_kex = new Server_Key_Exchange(contents); - - if(state->kex_pub) - delete state->kex_pub; - - state->kex_pub = state->server_kex->key(); - - if(dynamic_cast<DH_PublicKey*>(state->kex_pub) && - state->suite.kex_type() != TLS_ALGO_KEYEXCH_DH) - { - throw TLS_Exception(HANDSHAKE_FAILURE, - "Server sent DH key but negotiated something else"); - } + state->server_kex = new Server_Key_Exchange(contents, + state->suite.kex_algo(), + state->suite.sig_algo(), + state->version()); - if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) + if(state->suite.sig_algo() != "") { - if(!state->server_kex->verify(peer_certs[0], - state->client_hello->random(), - state->server_hello->random())) - throw TLS_Exception(DECRYPT_ERROR, - "Bad signature on server key exchange"); + if(!state->server_kex->verify(peer_certs[0], state)) + { + throw TLS_Exception(Alert::DECRYPT_ERROR, + "Bad signature on server key exchange"); + } } } else if(type == CERTIFICATE_REQUEST) { state->set_expected_next(SERVER_HELLO_DONE); - state->cert_req = new Certificate_Req(contents); + state->cert_req = new Certificate_Req(contents, state->version()); } else if(type == SERVER_HELLO_DONE) { @@ -315,11 +312,11 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(state->received_handshake_msg(CERTIFICATE_REQUEST)) { - std::vector<Certificate_Type> types = - state->cert_req->acceptable_types(); + const std::vector<std::string>& types = + state->cert_req->acceptable_cert_types(); std::vector<X509_Certificate> client_certs = - creds.cert_chain("", // use types here + creds.cert_chain(types, "tls-client", state->client_hello->sni_hostname()); @@ -329,9 +326,15 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } state->client_kex = - new Client_Key_Exchange(writer, state->hash, rng, - state->kex_pub, state->version, - state->client_hello->version()); + new Client_Key_Exchange(writer, + state, + creds, + peer_certs, + rng); + + state->keys = Session_Keys(state, + state->client_kex->pre_master_secret(), + false); if(state->received_handshake_msg(CERTIFICATE_REQUEST) && !state->client_certs->empty()) @@ -341,18 +344,16 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, "tls-client", state->client_hello->sni_hostname()); - state->client_verify = new Certificate_Verify(writer, state->hash, - rng, private_key); + state->client_verify = new Certificate_Verify(writer, + state, + rng, + private_key); } - state->keys = SessionKeys(state->suite, state->version, - state->client_kex->pre_master_secret(), - state->client_hello->random(), - state->server_hello->random()); - writer.send(CHANGE_CIPHER_SPEC, 1); - writer.activate(state->suite, state->keys, CLIENT); + writer.activate(CLIENT, state->suite, state->keys, + state->server_hello->compression_method()); if(state->server_hello->next_protocol_notification()) { @@ -362,15 +363,14 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, state->next_protocol = new Next_Protocol(writer, state->hash, protocol); } - state->client_finished = new Finished(writer, state->hash, - state->version, CLIENT, - state->keys.master_secret()); + state->client_finished = new Finished(writer, state, CLIENT); } else if(type == HANDSHAKE_CCS) { state->set_expected_next(FINISHED); - reader.activate(state->suite, state->keys, CLIENT); + reader.activate(CLIENT, state->suite, state->keys, + state->server_hello->compression_method()); } else if(type == FINISHED) { @@ -378,9 +378,8 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, state->server_finished = new Finished(contents); - if(!state->server_finished->verify(state->keys.master_secret(), - state->version, state->hash, SERVER)) - throw TLS_Exception(DECRYPT_ERROR, + if(!state->server_finished->verify(state, SERVER)) + throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); state->hash.update(type, contents); @@ -389,14 +388,13 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, { writer.send(CHANGE_CIPHER_SPEC, 1); - writer.activate(state->suite, state->keys, CLIENT); + writer.activate(CLIENT, state->suite, state->keys, + state->server_hello->compression_method()); - state->client_finished = new Finished(writer, state->hash, - state->version, CLIENT, - state->keys.master_secret()); + state->client_finished = new Finished(writer, state, CLIENT); } - TLS_Session session_info( + Session session_info( state->server_hello->session_id(), state->keys.master_secret(), state->server_hello->version(), @@ -426,3 +424,5 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } } + +} diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index 95b5c8f61..5c2692cd6 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -15,10 +15,12 @@ namespace Botan { +namespace TLS { + /** * SSL/TLS Client */ -class BOTAN_DLL TLS_Client : public TLS_Channel +class BOTAN_DLL Client : public Channel { public: /** @@ -40,30 +42,32 @@ class BOTAN_DLL TLS_Client : public TLS_Channel * called with the list of protocols the server advertised; * the client should return the protocol it would like to use. */ - TLS_Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_complete, - TLS_Session_Manager& session_manager, - Credentials_Manager& creds, - const TLS_Policy& policy, - RandomNumberGenerator& rng, - const std::string& servername = "", - std::tr1::function<std::string (std::vector<std::string>)> next_protocol = - std::tr1::function<std::string (std::vector<std::string>)>()); + Client(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_complete, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::string& servername = "", + std::tr1::function<std::string (std::vector<std::string>)> next_protocol = + std::tr1::function<std::string (std::vector<std::string>)>()); void renegotiate(); private: void process_handshake_msg(Handshake_Type type, const MemoryRegion<byte>& contents); - void alert_notify(bool is_fatal, Alert_Type type); + void alert_notify(const Alert& alert); - const TLS_Policy& policy; + const Policy& policy; RandomNumberGenerator& rng; - TLS_Session_Manager& session_manager; + Session_Manager& session_manager; Credentials_Manager& creds; }; } +} + #endif diff --git a/src/tls/tls_exceptn.h b/src/tls/tls_exceptn.h index 37b9c0d27..ad19c6c9d 100644 --- a/src/tls/tls_exceptn.h +++ b/src/tls/tls_exceptn.h @@ -9,24 +9,26 @@ #define BOTAN_TLS_EXCEPTION_H__ #include <botan/exceptn.h> -#include <botan/tls_magic.h> +#include <botan/tls_alert.h> namespace Botan { +namespace TLS { + /** * Exception Base Class */ class BOTAN_DLL TLS_Exception : public Exception { public: - Alert_Type type() const throw() { return alert_type; } + Alert::Type type() const throw() { return alert_type; } - TLS_Exception(Alert_Type type, + TLS_Exception(Alert::Type type, const std::string& err_msg = "Unknown error") : Exception(err_msg), alert_type(type) {} private: - Alert_Type alert_type; + Alert::Type alert_type; }; /** @@ -35,9 +37,11 @@ class BOTAN_DLL TLS_Exception : public Exception struct BOTAN_DLL Unexpected_Message : public TLS_Exception { Unexpected_Message(const std::string& err) : - TLS_Exception(UNEXPECTED_MESSAGE, err) {} + TLS_Exception(Alert::UNEXPECTED_MESSAGE, err) {} }; } +} + #endif diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp index 8f60b5e9b..d0d5b2c62 100644 --- a/src/tls/tls_extensions.cpp +++ b/src/tls/tls_extensions.cpp @@ -11,31 +11,48 @@ namespace Botan { +namespace TLS { + namespace { -TLS_Extension* make_extension(TLS_Data_Reader& reader, - u16bit code, - u16bit size) +Extension* make_extension(TLS_Data_Reader& reader, + u16bit code, + u16bit size) { - if(code == TLSEXT_SERVER_NAME_INDICATION) - return new Server_Name_Indicator(reader, size); - else if(code == TLSEXT_MAX_FRAGMENT_LENGTH) - return new Maximum_Fragment_Length(reader, size); - else if(code == TLSEXT_SRP_IDENTIFIER) - return new SRP_Identifier(reader, size); - else if(code == TLSEXT_SAFE_RENEGOTIATION) - return new Renegotation_Extension(reader, size); - else if(code == TLSEXT_SESSION_TICKET) - return new Session_Ticket(reader, size); - else if(code == TLSEXT_NEXT_PROTOCOL) - return new Next_Protocol_Notification(reader, size); - else - return 0; // not known + switch(code) + { + case TLSEXT_SERVER_NAME_INDICATION: + return new Server_Name_Indicator(reader, size); + + case TLSEXT_MAX_FRAGMENT_LENGTH: + return new Maximum_Fragment_Length(reader, size); + + case TLSEXT_SRP_IDENTIFIER: + return new SRP_Identifier(reader, size); + + case TLSEXT_USABLE_ELLIPTIC_CURVES: + return new Supported_Elliptic_Curves(reader, size); + + case TLSEXT_SAFE_RENEGOTIATION: + return new Renegotation_Extension(reader, size); + + case TLSEXT_SIGNATURE_ALGORITHMS: + return new Signature_Algorithms(reader, size); + + case TLSEXT_NEXT_PROTOCOL: + return new Next_Protocol_Notification(reader, size); + + case TLSEXT_SESSION_TICKET: + return new Session_Ticket(reader, size); + + default: + return 0; // not known + } } } -TLS_Extensions::TLS_Extensions(TLS_Data_Reader& reader) +Extensions::Extensions(TLS_Data_Reader& reader) { if(reader.has_remaining()) { @@ -49,30 +66,31 @@ TLS_Extensions::TLS_Extensions(TLS_Data_Reader& reader) const u16bit extension_code = reader.get_u16bit(); const u16bit extension_size = reader.get_u16bit(); - TLS_Extension* extn = make_extension(reader, + Extension* extn = make_extension(reader, extension_code, extension_size); if(extn) - extensions.push_back(extn); + this->add(extn); else // unknown/unhandled extension reader.discard_next(extension_size); } } } -MemoryVector<byte> TLS_Extensions::serialize() const +MemoryVector<byte> Extensions::serialize() const { MemoryVector<byte> buf(2); // 2 bytes for length field - for(size_t i = 0; i != extensions.size(); ++i) + for(std::map<Handshake_Extension_Type, Extension*>::const_iterator i = extensions.begin(); + i != extensions.end(); ++i) { - if(extensions[i]->empty()) + if(i->second->empty()) continue; - const u16bit extn_code = extensions[i]->type(); + const u16bit extn_code = i->second->type(); - MemoryVector<byte> extn_val = extensions[i]->serialize(); + MemoryVector<byte> extn_val = i->second->serialize(); buf.push_back(get_byte(0, extn_code)); buf.push_back(get_byte(1, extn_code)); @@ -95,10 +113,14 @@ MemoryVector<byte> TLS_Extensions::serialize() const return buf; } -TLS_Extensions::~TLS_Extensions() +Extensions::~Extensions() { - for(size_t i = 0; i != extensions.size(); ++i) - delete extensions[i]; + for(std::map<Handshake_Extension_Type, Extension*>::const_iterator i = extensions.begin(); + i != extensions.end(); ++i) + { + delete i->second; + } + extensions.clear(); } @@ -204,7 +226,7 @@ size_t Maximum_Fragment_Length::fragment_size() const case 4: return 4096; default: - throw TLS_Exception(ILLEGAL_PARAMETER, + throw TLS_Exception(Alert::ILLEGAL_PARAMETER, "Bad value in maximum fragment extension"); } } @@ -271,4 +293,217 @@ MemoryVector<byte> Next_Protocol_Notification::serialize() const return buf; } +std::string Supported_Elliptic_Curves::curve_id_to_name(u16bit id) + { + switch(id) + { + case 15: + return "secp160k1"; + case 16: + return "secp160r1"; + case 17: + return "secp160r2"; + case 18: + return "secp192k1"; + case 19: + return "secp192r1"; + case 20: + return "secp224k1"; + case 21: + return "secp224r1"; + case 22: + return "secp256k1"; + case 23: + return "secp256r1"; + case 24: + return "secp384r1"; + case 25: + return "secp521r1"; + default: + return ""; // something we don't know or support + } + } + +u16bit Supported_Elliptic_Curves::name_to_curve_id(const std::string& name) + { + if(name == "secp160k1") + return 15; + if(name == "secp160r1") + return 16; + if(name == "secp160r2") + return 17; + if(name == "secp192k1") + return 18; + if(name == "secp192r1") + return 19; + if(name == "secp224k1") + return 20; + if(name == "secp224r1") + return 21; + if(name == "secp256k1") + return 22; + if(name == "secp256r1") + return 23; + if(name == "secp384r1") + return 24; + if(name == "secp521r1") + return 25; + + throw Invalid_Argument("name_to_curve_id unknown name " + name); + } + +MemoryVector<byte> Supported_Elliptic_Curves::serialize() const + { + MemoryVector<byte> buf(2); + + for(size_t i = 0; i != m_curves.size(); ++i) + { + const u16bit id = name_to_curve_id(m_curves[i]); + buf.push_back(get_byte(0, id)); + buf.push_back(get_byte(1, id)); + } + + buf[0] = get_byte<u16bit>(0, buf.size()-2); + buf[1] = get_byte<u16bit>(1, buf.size()-2); + + return buf; + } + +Supported_Elliptic_Curves::Supported_Elliptic_Curves(TLS_Data_Reader& reader, + u16bit extension_size) + { + u16bit len = reader.get_u16bit(); + + if(len + 2 != extension_size) + throw Decoding_Error("Inconsistent length field in elliptic curve list"); + + if(len % 2 == 1) + throw Decoding_Error("Elliptic curve list of strange size"); + + len /= 2; + + for(size_t i = 0; i != len; ++i) + { + const u16bit id = reader.get_u16bit(); + const std::string name = curve_id_to_name(id); + + if(name != "") + m_curves.push_back(name); + } + } + +std::string Signature_Algorithms::hash_algo_name(byte code) + { + switch(code) + { + // code 1 is MD5 - ignore it + + case 2: + return "SHA-1"; + case 3: + return "SHA-224"; + case 4: + return "SHA-256"; + case 5: + return "SHA-384"; + case 6: + return "SHA-512"; + default: + return ""; + } + } + +byte Signature_Algorithms::hash_algo_code(const std::string& name) + { + if(name == "SHA-1") + return 2; + + if(name == "SHA-224") + return 3; + + if(name == "SHA-256") + return 4; + + if(name == "SHA-384") + return 5; + + if(name == "SHA-512") + return 6; + + throw Internal_Error("Unknown hash ID " + name + " for signature_algorithms"); + } + +std::string Signature_Algorithms::sig_algo_name(byte code) + { + switch(code) + { + case 1: + return "RSA"; + case 2: + return "DSA"; + case 3: + return "ECDSA"; + default: + return ""; + } + } + +byte Signature_Algorithms::sig_algo_code(const std::string& name) + { + if(name == "RSA") + return 1; + + if(name == "DSA") + return 2; + + if(name == "ECDSA") + return 3; + + throw Internal_Error("Unknown sig ID " + name + " for signature_algorithms"); + } + +MemoryVector<byte> Signature_Algorithms::serialize() const + { + MemoryVector<byte> buf(2); + + for(size_t i = 0; i != m_supported_algos.size(); ++i) + { + if(m_supported_algos[i].second == "") + continue; + + buf.push_back(hash_algo_code(m_supported_algos[i].first)); + buf.push_back(sig_algo_code(m_supported_algos[i].second)); + } + + buf[0] = get_byte<u16bit>(0, buf.size()-2); + buf[1] = get_byte<u16bit>(1, buf.size()-2); + + return buf; + } + +Signature_Algorithms::Signature_Algorithms(TLS_Data_Reader& reader, + u16bit extension_size) + { + u16bit len = reader.get_u16bit(); + + if(len + 2 != extension_size) + throw Decoding_Error("Bad encoding on signature algorithms extension"); + + while(len) + { + const std::string hash_code = hash_algo_name(reader.get_byte()); + const std::string sig_code = sig_algo_name(reader.get_byte()); + + len -= 2; + + // If not something we know, ignore it completely + if(hash_code == "" || sig_code == "") + continue; + + m_supported_algos.push_back(std::make_pair(hash_code, sig_code)); + } + } + +} + } diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h index 6d4e40434..a9e85221e 100644 --- a/src/tls/tls_extensions.h +++ b/src/tls/tls_extensions.h @@ -12,35 +12,60 @@ #include <botan/tls_magic.h> #include <vector> #include <string> +#include <map> namespace Botan { -class TLS_Session; +namespace TLS { + class TLS_Data_Reader; +enum Handshake_Extension_Type { + TLSEXT_SERVER_NAME_INDICATION = 0, + TLSEXT_MAX_FRAGMENT_LENGTH = 1, + TLSEXT_CLIENT_CERT_URL = 2, + TLSEXT_TRUSTED_CA_KEYS = 3, + TLSEXT_TRUNCATED_HMAC = 4, + + TLSEXT_CERTIFICATE_TYPES = 9, + TLSEXT_USABLE_ELLIPTIC_CURVES = 10, + TLSEXT_EC_POINT_FORMATS = 11, + TLSEXT_SRP_IDENTIFIER = 12, + TLSEXT_SIGNATURE_ALGORITHMS = 13, + + TLSEXT_SESSION_TICKET = 35, + + TLSEXT_NEXT_PROTOCOL = 13172, + + TLSEXT_SAFE_RENEGOTIATION = 65281, +}; + /** * Base class representing a TLS extension of some kind */ -class TLS_Extension +class Extension { public: - virtual TLS_Handshake_Extension_Type type() const = 0; + virtual Handshake_Extension_Type type() const = 0; + virtual MemoryVector<byte> serialize() const = 0; virtual bool empty() const = 0; - virtual ~TLS_Extension() {} + virtual ~Extension() {} }; /** * Server Name Indicator extension (RFC 3546) */ -class Server_Name_Indicator : public TLS_Extension +class Server_Name_Indicator : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SERVER_NAME_INDICATION; } + Handshake_Extension_Type type() const { return static_type(); } + Server_Name_Indicator(const std::string& host_name) : sni_host_name(host_name) {} @@ -59,12 +84,14 @@ class Server_Name_Indicator : public TLS_Extension /** * SRP identifier extension (RFC 5054) */ -class SRP_Identifier : public TLS_Extension +class SRP_Identifier : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SRP_IDENTIFIER; } + Handshake_Extension_Type type() const { return static_type(); } + SRP_Identifier(const std::string& identifier) : srp_identifier(identifier) {} @@ -83,12 +110,14 @@ class SRP_Identifier : public TLS_Extension /** * Renegotiation Indication Extension (RFC 5746) */ -class Renegotation_Extension : public TLS_Extension +class Renegotation_Extension : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_SAFE_RENEGOTIATION; } + Handshake_Extension_Type type() const { return static_type(); } + Renegotation_Extension() {} Renegotation_Extension(const MemoryRegion<byte>& bits) : @@ -110,12 +139,14 @@ class Renegotation_Extension : public TLS_Extension /** * Maximum Fragment Length Negotiation Extension (RFC 4366 sec 3.2) */ -class Maximum_Fragment_Length : public TLS_Extension +class Maximum_Fragment_Length : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_MAX_FRAGMENT_LENGTH; } + Handshake_Extension_Type type() const { return static_type(); } + bool empty() const { return val != 0; } size_t fragment_size() const; @@ -147,12 +178,14 @@ class Maximum_Fragment_Length : public TLS_Extension * spec (implemented in Chromium); the internet draft leaves the format * unspecified. */ -class Next_Protocol_Notification : public TLS_Extension +class Next_Protocol_Notification : public Extension { public: - TLS_Handshake_Extension_Type type() const + static Handshake_Extension_Type static_type() { return TLSEXT_NEXT_PROTOCOL; } + Handshake_Extension_Type type() const { return static_type(); } + const std::vector<std::string>& protocols() const { return m_protocols; } @@ -209,32 +242,111 @@ class Session_Ticket : public TLS_Extension }; /** +* Supported Elliptic Curves Extension (RFC 4492) +*/ +class Supported_Elliptic_Curves : public Extension + { + public: + static Handshake_Extension_Type static_type() + { return TLSEXT_USABLE_ELLIPTIC_CURVES; } + + Handshake_Extension_Type type() const { return static_type(); } + + static std::string curve_id_to_name(u16bit id); + static u16bit name_to_curve_id(const std::string& name); + + const std::vector<std::string>& curves() const { return m_curves; } + + MemoryVector<byte> serialize() const; + + Supported_Elliptic_Curves(const std::vector<std::string>& curves) : + m_curves(curves) {} + + Supported_Elliptic_Curves(TLS_Data_Reader& reader, + u16bit extension_size); + + bool empty() const { return m_curves.empty(); } + private: + std::vector<std::string> m_curves; + }; + +/** +* Signature Algorithms Extension for TLS 1.2 (RFC 5246) +*/ +class Signature_Algorithms : public Extension + { + public: + static Handshake_Extension_Type static_type() + { return TLSEXT_SIGNATURE_ALGORITHMS; } + + Handshake_Extension_Type type() const { return static_type(); } + + static std::string hash_algo_name(byte code); + static byte hash_algo_code(const std::string& name); + + static std::string sig_algo_name(byte code); + static byte sig_algo_code(const std::string& name); + + std::vector<std::pair<std::string, std::string> > + supported_signature_algorthms() const + { + return m_supported_algos; + } + + MemoryVector<byte> serialize() const; + + bool empty() const { return false; } + + Signature_Algorithms(const std::vector<std::pair<std::string, std::string> >& algos) : + m_supported_algos(algos) {} + + Signature_Algorithms(TLS_Data_Reader& reader, + u16bit extension_size); + private: + std::vector<std::pair<std::string, std::string> > m_supported_algos; + }; + +/** * Represents a block of extensions in a hello message */ -class TLS_Extensions +class Extensions { public: - size_t count() const { return extensions.size(); } + template<typename T> + T* get() const + { + Handshake_Extension_Type type = T::static_type(); - TLS_Extension* at(size_t idx) { return extensions.at(idx); } + std::map<Handshake_Extension_Type, Extension*>::const_iterator i = + extensions.find(type); - void push_back(TLS_Extension* extn) - { extensions.push_back(extn); } + if(i != extensions.end()) + return dynamic_cast<T*>(i->second); + return 0; + } + + void add(Extension* extn) + { + delete extensions[extn->type()]; // or hard error if already exists? + extensions[extn->type()] = extn; + } MemoryVector<byte> serialize() const; - TLS_Extensions() {} + Extensions() {} - TLS_Extensions(TLS_Data_Reader& reader); // deserialize + Extensions(TLS_Data_Reader& reader); // deserialize - ~TLS_Extensions(); + ~Extensions(); private: - TLS_Extensions(const TLS_Extensions&) {} - TLS_Extensions& operator=(const TLS_Extensions&) { return (*this); } + Extensions(const Extensions&) {} + Extensions& operator=(const Extensions&) { return (*this); } - std::vector<TLS_Extension*> extensions; + std::map<Handshake_Extension_Type, Extension*> extensions; }; } +} + #endif diff --git a/src/tls/tls_handshake_hash.cpp b/src/tls/tls_handshake_hash.cpp index 9621af535..61295a95c 100644 --- a/src/tls/tls_handshake_hash.cpp +++ b/src/tls/tls_handshake_hash.cpp @@ -6,14 +6,17 @@ */ #include <botan/internal/tls_handshake_hash.h> -#include <botan/md5.h> -#include <botan/sha160.h> +#include <botan/tls_exceptn.h> +#include <botan/libstate.h> +#include <botan/hash.h> #include <memory> namespace Botan { -void TLS_Handshake_Hash::update(Handshake_Type handshake_type, - const MemoryRegion<byte>& handshake_msg) +namespace TLS { + +void Handshake_Hash::update(Handshake_Type handshake_type, + const MemoryRegion<byte>& handshake_msg) { update(static_cast<byte>(handshake_type)); @@ -27,56 +30,74 @@ void TLS_Handshake_Hash::update(Handshake_Type handshake_type, /** * Return a TLS Handshake Hash */ -SecureVector<byte> TLS_Handshake_Hash::final() +SecureVector<byte> Handshake_Hash::final(Protocol_Version version, + const std::string& mac_algo) { - MD5 md5; - SHA_160 sha1; - - md5.update(data); - sha1.update(data); - - SecureVector<byte> output; - output += md5.final(); - output += sha1.final(); - return output; + Algorithm_Factory& af = global_state().algorithm_factory(); + + std::auto_ptr<HashFunction> hash; + + if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11) + { + hash.reset(af.make_hash_function("TLS.Digest.0")); + } + else if(version == Protocol_Version::TLS_V12) + { + if(mac_algo == "SHA-1" || mac_algo == "SHA-256") + hash.reset(af.make_hash_function("SHA-256")); + else + hash.reset(af.make_hash_function(mac_algo)); + } + else + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Unknown version for handshake hashes"); + + hash->update(data); + return hash->final(); } /** * Return a SSLv3 Handshake Hash */ -SecureVector<byte> TLS_Handshake_Hash::final_ssl3(const MemoryRegion<byte>& secret) +SecureVector<byte> Handshake_Hash::final_ssl3(const MemoryRegion<byte>& secret) { const byte PAD_INNER = 0x36, PAD_OUTER = 0x5C; - MD5 md5; - SHA_160 sha1; + Algorithm_Factory& af = global_state().algorithm_factory(); - md5.update(data); - sha1.update(data); + std::auto_ptr<HashFunction> md5(af.make_hash_function("MD5")); + std::auto_ptr<HashFunction> sha1(af.make_hash_function("SHA-1")); - md5.update(secret); - sha1.update(secret); + md5->update(data); + sha1->update(data); + + md5->update(secret); + sha1->update(secret); for(size_t i = 0; i != 48; ++i) - md5.update(PAD_INNER); + md5->update(PAD_INNER); for(size_t i = 0; i != 40; ++i) - sha1.update(PAD_INNER); + sha1->update(PAD_INNER); + + SecureVector<byte> inner_md5 = md5->final(), inner_sha1 = sha1->final(); - SecureVector<byte> inner_md5 = md5.final(), inner_sha1 = sha1.final(); + md5->update(secret); + sha1->update(secret); - md5.update(secret); - sha1.update(secret); for(size_t i = 0; i != 48; ++i) - md5.update(PAD_OUTER); + md5->update(PAD_OUTER); for(size_t i = 0; i != 40; ++i) - sha1.update(PAD_OUTER); - md5.update(inner_md5); - sha1.update(inner_sha1); + sha1->update(PAD_OUTER); + + md5->update(inner_md5); + sha1->update(inner_sha1); SecureVector<byte> output; - output += md5.final(); - output += sha1.final(); + output += md5->final(); + output += sha1->final(); return output; } } + +} diff --git a/src/tls/tls_handshake_hash.h b/src/tls/tls_handshake_hash.h index 4ee1fc1b9..c13f97aa8 100644 --- a/src/tls/tls_handshake_hash.h +++ b/src/tls/tls_handshake_hash.h @@ -9,16 +9,19 @@ #define BOTAN_TLS_HANDSHAKE_HASH_H__ #include <botan/secmem.h> +#include <botan/tls_version.h> #include <botan/tls_magic.h> namespace Botan { +namespace TLS { + using namespace Botan; /** * TLS Handshake Hash */ -class TLS_Handshake_Hash +class Handshake_Hash { public: void update(const byte in[], size_t length) @@ -33,8 +36,10 @@ class TLS_Handshake_Hash void update(Handshake_Type handshake_type, const MemoryRegion<byte>& handshake_msg); - SecureVector<byte> final(); - SecureVector<byte> final_ssl3(const MemoryRegion<byte>&); + SecureVector<byte> final(Protocol_Version version, + const std::string& mac_algo); + + SecureVector<byte> final_ssl3(const MemoryRegion<byte>& master_secret); const SecureVector<byte>& get_contents() const { return data; } @@ -45,4 +50,6 @@ class TLS_Handshake_Hash } +} + #endif diff --git a/src/tls/tls_handshake_reader.cpp b/src/tls/tls_handshake_reader.cpp new file mode 100644 index 000000000..8278a2296 --- /dev/null +++ b/src/tls/tls_handshake_reader.cpp @@ -0,0 +1,66 @@ +/* +* TLS Handshake Reader +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_handshake_reader.h> +#include <botan/exceptn.h> + +namespace Botan { + +namespace TLS { + +void Stream_Handshake_Reader::add_input(const byte record[], + size_t record_size) + { + m_queue.write(record, record_size); + } + +bool Stream_Handshake_Reader::empty() const + { + return m_queue.empty(); + } + +bool Stream_Handshake_Reader::have_full_record() const + { + if(m_queue.size() >= 4) + { + byte head[4] = { 0 }; + m_queue.peek(head, 4); + + const size_t length = make_u32bit(0, head[1], head[2], head[3]); + + return (m_queue.size() >= length + 4); + } + + return false; + } + +std::pair<Handshake_Type, MemoryVector<byte> > Stream_Handshake_Reader::get_next_record() + { + if(m_queue.size() >= 4) + { + byte head[4] = { 0 }; + m_queue.peek(head, 4); + + const size_t length = make_u32bit(0, head[1], head[2], head[3]); + + if(m_queue.size() >= length + 4) + { + Handshake_Type type = static_cast<Handshake_Type>(head[0]); + MemoryVector<byte> contents(length); + m_queue.read(head, 4); // discard + m_queue.read(&contents[0], contents.size()); + + return std::make_pair(type, contents); + } + } + + throw Internal_Error("Stream_Handshake_Reader::get_next_record called without a full record"); + } + +} + +} diff --git a/src/tls/tls_handshake_reader.h b/src/tls/tls_handshake_reader.h new file mode 100644 index 000000000..06a273ced --- /dev/null +++ b/src/tls/tls_handshake_reader.h @@ -0,0 +1,58 @@ +/* +* TLS Handshake Reader +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_HANDSHAKE_READER_H__ +#define BOTAN_TLS_HANDSHAKE_READER_H__ + +#include <botan/tls_magic.h> +#include <botan/secqueue.h> +#include <botan/loadstor.h> +#include <utility> + +namespace Botan { + +namespace TLS { + +/** +* Handshake Reader Interface +*/ +class Handshake_Reader + { + public: + virtual void add_input(const byte record[], size_t record_size) = 0; + + virtual bool empty() const = 0; + + virtual bool have_full_record() const = 0; + + virtual std::pair<Handshake_Type, MemoryVector<byte> > get_next_record() = 0; + + virtual ~Handshake_Reader() {} + }; + +/** +* Reader of TLS handshake messages +*/ +class Stream_Handshake_Reader : public Handshake_Reader + { + public: + void add_input(const byte record[], size_t record_size); + + bool empty() const; + + bool have_full_record() const; + + std::pair<Handshake_Type, MemoryVector<byte> > get_next_record(); + private: + SecureQueue m_queue; + }; + +} + +} + +#endif diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index d6f215067..2db97db0a 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -6,9 +6,14 @@ */ #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> +#include <botan/internal/assert.h> +#include <botan/lookup.h> namespace Botan { +namespace TLS { + namespace { u32bit bitmask_for_handshake_type(Handshake_Type type) @@ -71,7 +76,7 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State() +Handshake_State::Handshake_State(Handshake_Reader* reader) { client_hello = 0; server_hello = 0; @@ -87,15 +92,21 @@ Handshake_State::Handshake_State() client_finished = 0; server_finished = 0; - kex_pub = 0; - kex_priv = 0; + m_handshake_reader = reader; + + server_rsa_kex_key = 0; - version = SSL_V3; + m_version = Protocol_Version::SSL_V3; hand_expecting_mask = 0; hand_received_mask = 0; } +void Handshake_State::set_version(const Protocol_Version& version) + { + m_version = version; + } + void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) { const u32bit mask = bitmask_for_handshake_type(handshake_msg); @@ -128,6 +139,142 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const return (hand_received_mask & mask); } +KDF* Handshake_State::protocol_specific_prf() + { + if(version() == Protocol_Version::SSL_V3) + { + return get_kdf("SSL3-PRF"); + } + else if(version() == Protocol_Version::TLS_V10 || version() == Protocol_Version::TLS_V11) + { + return get_kdf("TLS-PRF"); + } + else if(version() == Protocol_Version::TLS_V12) + { + if(suite.mac_algo() == "SHA-1" || suite.mac_algo() == "SHA-256") + return get_kdf("TLS-12-PRF(SHA-256)"); + + return get_kdf("TLS-12-PRF(" + suite.mac_algo() + ")"); + } + + throw Internal_Error("Unknown version code " + version().to_string()); + } + +std::pair<std::string, Signature_Format> +Handshake_State::choose_sig_format(const Private_Key* key, + std::string& hash_algo_out, + std::string& sig_algo_out, + bool for_client_auth) + { + const std::string sig_algo = key->algo_name(); + + const std::vector<std::pair<std::string, std::string> > supported_algos = + (for_client_auth) ? cert_req->supported_algos() : client_hello->supported_algos(); + + std::string hash_algo; + + for(size_t i = 0; i != supported_algos.size(); ++i) + { + if(supported_algos[i].second == sig_algo) + { + hash_algo = supported_algos[i].first; + break; + } + } + + if(for_client_auth && this->version() == Protocol_Version::SSL_V3) + hash_algo = "Raw"; + + if(hash_algo == "" && this->version() == Protocol_Version::TLS_V12) + hash_algo = "SHA-1"; // TLS 1.2 but no compatible hashes set (?) + + BOTAN_ASSERT(hash_algo != "", "Couldn't figure out hash to use"); + + if(this->version() >= Protocol_Version::TLS_V12) + { + hash_algo_out = hash_algo; + sig_algo_out = sig_algo; + } + + if(sig_algo == "RSA") + { + const std::string padding = "EMSA3(" + hash_algo + ")"; + + return std::make_pair(padding, IEEE_1363); + } + else if(sig_algo == "DSA" || sig_algo == "ECDSA") + { + const std::string padding = "EMSA1(" + hash_algo + ")"; + + return std::make_pair(padding, DER_SEQUENCE); + } + + throw Invalid_Argument(sig_algo + " is invalid/unknown for TLS signatures"); + } + +std::pair<std::string, Signature_Format> +Handshake_State::understand_sig_format(const Public_Key* key, + std::string hash_algo, + std::string sig_algo, + bool for_client_auth) + { + const std::string algo_name = key->algo_name(); + + /* + FIXME: This should check what was sent against the client hello + preferences, or the certificate request, to ensure it was allowed + by those restrictions. + + Or not? + */ + + if(this->version() < Protocol_Version::TLS_V12) + { + if(hash_algo != "" || sig_algo != "") + throw Decoding_Error("Counterparty sent hash/sig IDs with old version"); + } + else + { + if(hash_algo == "") + throw Decoding_Error("Counterparty did not send hash/sig IDS"); + + if(sig_algo != algo_name) + throw Decoding_Error("Counterparty sent inconsistent key and sig types"); + } + + if(algo_name == "RSA") + { + if(for_client_auth && this->version() == Protocol_Version::SSL_V3) + { + hash_algo = "Raw"; + } + else if(this->version() < Protocol_Version::TLS_V12) + { + hash_algo = "TLS.Digest.0"; + } + + const std::string padding = "EMSA3(" + hash_algo + ")"; + return std::make_pair(padding, IEEE_1363); + } + else if(algo_name == "DSA" || algo_name == "ECDSA") + { + if(algo_name == "DSA" && for_client_auth && this->version() == Protocol_Version::SSL_V3) + { + hash_algo = "Raw"; + } + else if(this->version() < Protocol_Version::TLS_V12) + { + hash_algo = "SHA-1"; + } + + const std::string padding = "EMSA1(" + hash_algo + ")"; + + return std::make_pair(padding, DER_SEQUENCE); + } + + throw Invalid_Argument(algo_name + " is invalid/unknown for TLS signatures"); + } + /* * Destroy the SSL/TLS Handshake State */ @@ -147,8 +294,9 @@ Handshake_State::~Handshake_State() delete client_finished; delete server_finished; - delete kex_pub; - delete kex_priv; + delete m_handshake_reader; } } + +} diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 7ca2dae94..206e19096 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -8,9 +8,13 @@ #ifndef BOTAN_TLS_HANDSHAKE_STATE_H__ #define BOTAN_TLS_HANDSHAKE_STATE_H__ -#include <botan/internal/tls_messages.h> +#include <botan/internal/tls_handshake_hash.h> +#include <botan/internal/tls_handshake_reader.h> #include <botan/internal/tls_session_key.h> -#include <botan/secqueue.h> +#include <botan/pk_keys.h> +#include <botan/pubkey.h> + +#include <utility> #if defined(BOTAN_USE_STD_TR1) @@ -28,13 +32,17 @@ namespace Botan { +class KDF; + +namespace TLS { + /** * SSL/TLS Handshake State */ class Handshake_State { public: - Handshake_State(); + Handshake_State(Handshake_Reader* reader); ~Handshake_State(); bool received_handshake_msg(Handshake_Type handshake_msg) const; @@ -42,32 +50,46 @@ class Handshake_State void confirm_transition_to(Handshake_Type handshake_msg); void set_expected_next(Handshake_Type handshake_msg); - Version_Code version; + std::pair<std::string, Signature_Format> + understand_sig_format(const Public_Key* key, + std::string hash_algo, + std::string sig_algo, + bool for_client_auth); + + std::pair<std::string, Signature_Format> + choose_sig_format(const Private_Key* key, + std::string& hash_algo, + std::string& sig_algo, + bool for_client_auth); - Client_Hello* client_hello; - Server_Hello* server_hello; - Certificate* server_certs; - Server_Key_Exchange* server_kex; - Certificate_Req* cert_req; - Server_Hello_Done* server_hello_done; + KDF* protocol_specific_prf(); - Certificate* client_certs; - Client_Key_Exchange* client_kex; - Certificate_Verify* client_verify; + Protocol_Version version() const { return m_version; } - Next_Protocol* next_protocol; + void set_version(const Protocol_Version& version); - Finished* client_finished; - Finished* server_finished; + class Client_Hello* client_hello; + class Server_Hello* server_hello; + class Certificate* server_certs; + class Server_Key_Exchange* server_kex; + class Certificate_Req* cert_req; + class Server_Hello_Done* server_hello_done; - Public_Key* kex_pub; - Private_Key* kex_priv; + class Certificate* client_certs; + class Client_Key_Exchange* client_kex; + class Certificate_Verify* client_verify; - TLS_Cipher_Suite suite; - SessionKeys keys; - TLS_Handshake_Hash hash; + class Next_Protocol* next_protocol; - SecureQueue queue; + class Finished* client_finished; + class Finished* server_finished; + + // Used by the server only, in case of RSA key exchange + Private_Key* server_rsa_kex_key; + + Ciphersuite suite; + Session_Keys keys; + Handshake_Hash hash; /* * Only used by clients for session resumption @@ -79,10 +101,15 @@ class Handshake_State */ std::tr1::function<std::string (std::vector<std::string>)> client_npn_cb; + Handshake_Reader* handshake_reader() { return m_handshake_reader; } private: + Handshake_Reader* m_handshake_reader; u32bit hand_expecting_mask, hand_received_mask; + Protocol_Version m_version; }; } +} + #endif diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h index 5a35d4c46..0e45407d3 100644 --- a/src/tls/tls_magic.h +++ b/src/tls/tls_magic.h @@ -10,6 +10,8 @@ namespace Botan { +namespace TLS { + /** * Protocol Constants for SSL/TLS */ @@ -22,13 +24,6 @@ enum Size_Limits { MAX_TLS_RECORD_SIZE = MAX_CIPHERTEXT_SIZE + TLS_HEADER_SIZE, }; -enum Version_Code { - NO_VERSION_SET = 0x0000, - SSL_V3 = 0x0300, - TLS_V10 = 0x0301, - TLS_V11 = 0x0302 -}; - enum Connection_Side { CLIENT = 1, SERVER = 2 }; enum Record_Type { @@ -41,69 +36,24 @@ enum Record_Type { }; enum Handshake_Type { - HELLO_REQUEST = 0, - CLIENT_HELLO = 1, - CLIENT_HELLO_SSLV2 = 200, // Not a wire value - SERVER_HELLO = 2, - NEW_SESSION_TICKET = 4, // RFC 5077 - CERTIFICATE = 11, - SERVER_KEX = 12, - CERTIFICATE_REQUEST = 13, - SERVER_HELLO_DONE = 14, - CERTIFICATE_VERIFY = 15, - CLIENT_KEX = 16, - FINISHED = 20, - - NEXT_PROTOCOL = 67, - - HANDSHAKE_CCS = 100, // Not a wire value - HANDSHAKE_NONE = 255 // Null value -}; - -enum Alert_Level { - WARNING = 1, - FATAL = 2 -}; - -enum Alert_Type { - CLOSE_NOTIFY = 0, - UNEXPECTED_MESSAGE = 10, - BAD_RECORD_MAC = 20, - DECRYPTION_FAILED = 21, - RECORD_OVERFLOW = 22, - DECOMPRESSION_FAILURE = 30, - HANDSHAKE_FAILURE = 40, - NO_CERTIFICATE = 41, // SSLv3 only - BAD_CERTIFICATE = 42, - UNSUPPORTED_CERTIFICATE = 43, - CERTIFICATE_REVOKED = 44, - CERTIFICATE_EXPIRED = 45, - CERTIFICATE_UNKNOWN = 46, - ILLEGAL_PARAMETER = 47, - UNKNOWN_CA = 48, - ACCESS_DENIED = 49, - DECODE_ERROR = 50, - DECRYPT_ERROR = 51, - EXPORT_RESTRICTION = 60, - PROTOCOL_VERSION = 70, - INSUFFICIENT_SECURITY = 71, - INTERNAL_ERROR = 80, - USER_CANCELED = 90, - NO_RENEGOTIATION = 100, - - UNSUPPORTED_EXTENSION = 110, - UNRECOGNIZED_NAME = 112, - - UNKNOWN_PSK_IDENTITY = 115, - - NULL_ALERT = 255 -}; - -enum Certificate_Type { - RSA_CERT = 1, - DSS_CERT = 2, - DH_RSA_CERT = 3, - DH_DSS_CERT = 4 + HELLO_REQUEST = 0, + CLIENT_HELLO = 1, + CLIENT_HELLO_SSLV2 = 253, // Not a wire value + SERVER_HELLO = 2, + HELLO_VERIFY_REQUEST = 3, + NEW_SESSION_TICKET = 4, // RFC 5077 + CERTIFICATE = 11, + SERVER_KEX = 12, + CERTIFICATE_REQUEST = 13, + SERVER_HELLO_DONE = 14, + CERTIFICATE_VERIFY = 15, + CLIENT_KEX = 16, + FINISHED = 20, + + NEXT_PROTOCOL = 67, + + HANDSHAKE_CCS = 254, // Not a wire value + HANDSHAKE_NONE = 255 // Null value }; enum Ciphersuite_Code { @@ -115,13 +65,18 @@ enum Ciphersuite_Code { TLS_RSA_WITH_AES_256_CBC_SHA = 0x0035, TLS_RSA_WITH_AES_128_CBC_SHA256 = 0x003C, TLS_RSA_WITH_AES_256_CBC_SHA256 = 0x003D, + TLS_RSA_WITH_CAMELLIA_128_CBC_SHA = 0x0041, + TLS_RSA_WITH_CAMELLIA_256_CBC_SHA = 0x0084, TLS_RSA_WITH_SEED_CBC_SHA = 0x0096, + TLS_RSA_WITH_IDEA_CBC_SHA = 0x0007, TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA = 0x0013, TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0x0032, TLS_DHE_DSS_WITH_AES_256_CBC_SHA = 0x0038, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 = 0x0040, TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 = 0x006A, + TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA = 0x0044, + TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA = 0x0087, TLS_DHE_DSS_WITH_SEED_CBC_SHA = 0x0099, TLS_DHE_DSS_WITH_RC4_128_SHA = 0x0066, @@ -130,15 +85,10 @@ enum Ciphersuite_Code { TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067, TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B, + TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA = 0x0045, + TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA = 0x0088, TLS_DHE_RSA_WITH_SEED_CBC_SHA = 0x009A, - TLS_SRP_SHA_RSA_WITH_3DES_EDE_SHA = 0xC01B, - TLS_SRP_SHA_DSS_WITH_3DES_EDE_SHA = 0xC01C, - TLS_SRP_SHA_RSA_WITH_AES_128_SHA = 0xC01E, - TLS_SRP_SHA_DSS_WITH_AES_128_SHA = 0xC01F, - TLS_SRP_SHA_RSA_WITH_AES_256_SHA = 0xC021, - TLS_SRP_SHA_DSS_WITH_AES_256_SHA = 0xC022, - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA = 0xC007, TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA = 0xC008, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009, @@ -153,65 +103,49 @@ enum Ciphersuite_Code { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028, + TLS_PSK_WITH_RC4_128_SHA = 0x008A, + TLS_PSK_WITH_3DES_EDE_CBC_SHA = 0x008B, + TLS_PSK_WITH_AES_128_CBC_SHA = 0x008C, + TLS_PSK_WITH_AES_256_CBC_SHA = 0x008D, + TLS_PSK_WITH_AES_128_CBC_SHA256 = 0x00AE, + TLS_PSK_WITH_AES_256_CBC_SHA384 = 0x00AF, + + TLS_DHE_PSK_WITH_RC4_128_SHA = 0x008E, + TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA = 0x008F, + TLS_DHE_PSK_WITH_AES_128_CBC_SHA = 0x0090, + TLS_DHE_PSK_WITH_AES_256_CBC_SHA = 0x0091, + TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 = 0x00B2, + TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 = 0x00B3, + + TLS_ECDHE_PSK_WITH_RC4_128_SHA = 0xC033, + TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA = 0xC034, + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA = 0xC035, + TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA = 0xC036, + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 = 0xC037, + TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 = 0xC038, + + TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA = 0xC01A, + TLS_SRP_SHA_WITH_AES_128_CBC_SHA = 0xC01D, + TLS_SRP_SHA_WITH_AES_256_CBC_SHA = 0xC020, + + TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA = 0xC01C, + TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA = 0xC01F, + TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA = 0xC022, + + TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA = 0xC01B, + TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA = 0xC01E, + TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA = 0xC021, + /* signalling values that cannot be negotiated */ TLS_EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF }; -/* -* Form of the ciphersuites broken down by field instead of -* being randomly assigned codepoints. -*/ -enum TLS_Ciphersuite_Algos { - TLS_ALGO_SIGNER_MASK = 0xFF000000, - TLS_ALGO_SIGNER_ANON = 0x01000000, - TLS_ALGO_SIGNER_RSA = 0x02000000, - TLS_ALGO_SIGNER_DSA = 0x03000000, - TLS_ALGO_SIGNER_ECDSA = 0x04000000, - - TLS_ALGO_KEYEXCH_MASK = 0x00FF0000, - TLS_ALGO_KEYEXCH_NOKEX = 0x00010000, // exchange via key in server cert - TLS_ALGO_KEYEXCH_DH = 0x00020000, - TLS_ALGO_KEYEXCH_ECDH = 0x00030000, - TLS_ALGO_KEYEXCH_SRP = 0x00040000, - - TLS_ALGO_MAC_MASK = 0x0000FF00, - TLS_ALGO_MAC_MD5 = 0x00000100, - TLS_ALGO_MAC_SHA1 = 0x00000200, - TLS_ALGO_MAC_SHA256 = 0x00000300, - TLS_ALGO_MAC_SHA384 = 0x00000400, - - TLS_ALGO_CIPHER_MASK = 0x000000FF, - TLS_ALGO_CIPHER_RC4_128 = 0x00000001, - TLS_ALGO_CIPHER_3DES_CBC = 0x00000002, - TLS_ALGO_CIPHER_AES128_CBC = 0x00000003, - TLS_ALGO_CIPHER_AES256_CBC = 0x00000004, - TLS_ALGO_CIPHER_SEED_CBC = 0x00000005 -}; - enum Compression_Method { NO_COMPRESSION = 0x00, DEFLATE_COMPRESSION = 0x01 }; -enum TLS_Handshake_Extension_Type { - TLSEXT_SERVER_NAME_INDICATION = 0, - TLSEXT_MAX_FRAGMENT_LENGTH = 1, - TLSEXT_CLIENT_CERT_URL = 2, - TLSEXT_TRUSTED_CA_KEYS = 3, - TLSEXT_TRUNCATED_HMAC = 4, - - TLSEXT_USABLE_ELLIPTIC_CURVES = 10, - TLSEXT_EC_POINT_FORMATS = 11, - - TLSEXT_SRP_IDENTIFIER = 12, - - TLSEXT_CERTIFICATE_TYPES = 9, - TLSEXT_SESSION_TICKET = 35, - - TLSEXT_NEXT_PROTOCOL = 13172, - - TLSEXT_SAFE_RENEGOTIATION = 65281, -}; +} } diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 94e17cb9b..617b03813 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -8,11 +8,11 @@ #ifndef BOTAN_TLS_MESSAGES_H__ #define BOTAN_TLS_MESSAGES_H__ -#include <botan/internal/tls_handshake_hash.h> +#include <botan/internal/tls_handshake_state.h> #include <botan/tls_session.h> #include <botan/tls_policy.h> #include <botan/tls_magic.h> -#include <botan/tls_suites.h> +#include <botan/tls_ciphersuite.h> #include <botan/bigint.h> #include <botan/pkcs8.h> #include <botan/x509cert.h> @@ -20,6 +20,10 @@ namespace Botan { +class Credentials_Manager; + +namespace TLS { + class Record_Writer; class Record_Reader; @@ -29,27 +33,46 @@ class Record_Reader; class Handshake_Message { public: - void send(Record_Writer& writer, TLS_Handshake_Hash& hash) const; - + virtual MemoryVector<byte> serialize() const = 0; virtual Handshake_Type type() const = 0; + Handshake_Message() {} virtual ~Handshake_Message() {} private: + Handshake_Message(const Handshake_Message&) {} Handshake_Message& operator=(const Handshake_Message&) { return (*this); } - virtual MemoryVector<byte> serialize() const = 0; - virtual void deserialize(const MemoryRegion<byte>&) = 0; }; MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng); /** +* DTLS Hello Verify Request +*/ +class Hello_Verify_Request : public Handshake_Message + { + public: + MemoryVector<byte> serialize() const; + Handshake_Type type() const { return HELLO_VERIFY_REQUEST; } + + MemoryVector<byte> cookie() const { return m_cookie; } + + Hello_Verify_Request(const MemoryRegion<byte>& buf); + + Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits, + const std::string& client_identity, + const SymmetricKey& secret_key); + private: + MemoryVector<byte> m_cookie; + }; + +/** * Client Hello Message */ class Client_Hello : public Handshake_Message { public: Handshake_Type type() const { return CLIENT_HELLO; } - Version_Code version() const { return m_version; } + Protocol_Version version() const { return m_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } std::vector<byte> session_id_vector() const @@ -59,6 +82,12 @@ class Client_Hello : public Handshake_Message return v; } + const std::vector<std::pair<std::string, std::string> >& supported_algos() const + { return m_supported_algos; } + + const std::vector<std::string>& supported_ecc_curves() const + { return m_supported_curves; } + std::vector<u16bit> ciphersuites() const { return m_suites; } std::vector<byte> compression_methods() const { return m_comp_methods; } @@ -85,8 +114,8 @@ class Client_Hello : public Handshake_Message { return m_session_ticket; } Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const TLS_Policy& policy, + Handshake_Hash& hash, + const Policy& policy, RandomNumberGenerator& rng, const MemoryRegion<byte>& reneg_info, bool next_protocol = false, @@ -94,26 +123,20 @@ class Client_Hello : public Handshake_Message const std::string& srp_identifier = ""); Client_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, RandomNumberGenerator& rng, - const TLS_Session& resumed_session, + const Session& resumed_session, bool next_protocol = false); Client_Hello(const MemoryRegion<byte>& buf, - Handshake_Type type) - { - if(type == CLIENT_HELLO) - deserialize(buf); - else - deserialize_sslv2(buf); - } + Handshake_Type type); private: MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>& buf); void deserialize_sslv2(const MemoryRegion<byte>& buf); - Version_Code m_version; + Protocol_Version m_version; MemoryVector<byte> m_session_id, m_random; std::vector<u16bit> m_suites; std::vector<byte> m_comp_methods; @@ -125,6 +148,9 @@ class Client_Hello : public Handshake_Message bool m_secure_renegotiation; MemoryVector<byte> m_renegotiation_info; + std::vector<std::pair<std::string, std::string> > m_supported_algos; + std::vector<std::string> m_supported_curves; + bool m_supports_session_ticket; MemoryVector<byte> m_session_ticket; }; @@ -136,7 +162,7 @@ class Server_Hello : public Handshake_Message { public: Handshake_Type type() const { return SERVER_HELLO; } - Version_Code version() { return s_version; } + Protocol_Version version() { return s_version; } const MemoryVector<byte>& session_id() const { return m_session_id; } u16bit ciphersuite() const { return suite; } byte compression_method() const { return comp_method; } @@ -163,11 +189,11 @@ class Server_Hello : public Handshake_Message const MemoryVector<byte>& random() const { return s_random; } Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, + Handshake_Hash& hash, + Protocol_Version version, const Client_Hello& other, - const std::vector<X509_Certificate>& certs, - const TLS_Policy& policies, + const std::vector<std::string>& available_cert_types, + const Policy& policies, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, bool client_has_npn, @@ -175,9 +201,9 @@ class Server_Hello : public Handshake_Message RandomNumberGenerator& rng); Server_Hello(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const MemoryRegion<byte>& session_id, - Version_Code ver, + Protocol_Version ver, u16bit ciphersuite, byte compression, size_t max_fragment_size, @@ -187,12 +213,11 @@ class Server_Hello : public Handshake_Message const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng); - Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - Version_Code s_version; + Protocol_Version s_version; MemoryVector<byte> m_session_id, s_random; u16bit suite; byte comp_method; @@ -216,26 +241,22 @@ class Client_Key_Exchange : public Handshake_Message const SecureVector<byte>& pre_master_secret() const { return pre_master; } - SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng, - const Private_Key* key, - Version_Code version); - Client_Key_Exchange(Record_Writer& output, - TLS_Handshake_Hash& hash, - RandomNumberGenerator& rng, - const Public_Key* my_key, - Version_Code using_version, - Version_Code pref_version); + Handshake_State* state, + Credentials_Manager& creds, + const std::vector<X509_Certificate>& peer_certs, + RandomNumberGenerator& rng); Client_Key_Exchange(const MemoryRegion<byte>& buf, - const TLS_Cipher_Suite& suite, - Version_Code using_version); + const Handshake_State* state, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng); + private: - MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + MemoryVector<byte> serialize() const { return key_material; } SecureVector<byte> key_material, pre_master; - bool include_length; }; /** @@ -245,20 +266,20 @@ class Certificate : public Handshake_Message { public: Handshake_Type type() const { return CERTIFICATE; } - const std::vector<X509_Certificate>& cert_chain() const { return certs; } + const std::vector<X509_Certificate>& cert_chain() const { return m_certs; } - size_t count() const { return certs.size(); } - bool empty() const { return certs.empty(); } + size_t count() const { return m_certs.size(); } + bool empty() const { return m_certs.empty(); } Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); - Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - std::vector<X509_Certificate> certs; + + std::vector<X509_Certificate> m_certs; }; /** @@ -269,22 +290,29 @@ class Certificate_Req : public Handshake_Message public: Handshake_Type type() const { return CERTIFICATE_REQUEST; } - std::vector<Certificate_Type> acceptable_types() const { return types; } + const std::vector<std::string>& acceptable_cert_types() const + { return cert_key_types; } + std::vector<X509_DN> acceptable_CAs() const { return names; } + std::vector<std::pair<std::string, std::string> > supported_algos() const + { return m_supported_algos; } + Certificate_Req(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, + const Policy& policy, const std::vector<X509_Certificate>& allowed_cas, - const std::vector<Certificate_Type>& types = - std::vector<Certificate_Type>()); + Protocol_Version version); - Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Req(const MemoryRegion<byte>& buf, + Protocol_Version version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::vector<X509_DN> names; - std::vector<Certificate_Type> types; + std::vector<std::string> cert_key_types; + + std::vector<std::pair<std::string, std::string> > m_supported_algos; }; /** @@ -298,25 +326,23 @@ class Certificate_Verify : public Handshake_Message /** * Check the signature on a certificate verify message * @param cert the purported certificate - * @param hash the running handshake message hash - * @param version the version number we negotiated - * @param master_secret the session key (only used if version is SSL_V3) + * @param state the handshake state */ bool verify(const X509_Certificate& cert, - TLS_Handshake_Hash& hash, - Version_Code version, - const SecureVector<byte>& master_secret); + Handshake_State* state); Certificate_Verify(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Verify(const MemoryRegion<byte>& buf, + Protocol_Version version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + std::string sig_algo; // sig algo used to create signature + std::string hash_algo; // hash used to create signature MemoryVector<byte> signature; }; @@ -331,26 +357,16 @@ class Finished : public Handshake_Message MemoryVector<byte> verify_data() const { return verification_data; } - bool verify(const MemoryRegion<byte>& buf, - Version_Code version, - const TLS_Handshake_Hash& hash, + bool verify(Handshake_State* state, Connection_Side side); Finished(Record_Writer& writer, - TLS_Handshake_Hash& hash, - Version_Code version, - Connection_Side side, - const MemoryRegion<byte>& master_secret); + Handshake_State* state, + Connection_Side side); - Finished(const MemoryRegion<byte>& buf) { deserialize(buf); } + Finished(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); - - MemoryVector<byte> compute_verify(const MemoryRegion<byte>& master_secret, - TLS_Handshake_Hash hash, - Connection_Side side, - Version_Code version); Connection_Side side; MemoryVector<byte> verification_data; @@ -365,10 +381,9 @@ class Hello_Request : public Handshake_Message Handshake_Type type() const { return HELLO_REQUEST; } Hello_Request(Record_Writer& writer); - Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); } + Hello_Request(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -378,28 +393,38 @@ class Server_Key_Exchange : public Handshake_Message { public: Handshake_Type type() const { return SERVER_KEX; } - Public_Key* key() const; + + const MemoryVector<byte>& params() const { return m_params; } bool verify(const X509_Certificate& cert, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random) const; + Handshake_State* state) const; + + // Only valid for certain kex types + const Private_Key& server_kex_key() const; Server_Key_Exchange(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_State* state, + const Policy& policy, + Credentials_Manager& creds, RandomNumberGenerator& rng, - const Public_Key* kex_key, - const Private_Key* priv_key, - const MemoryRegion<byte>& c_random, - const MemoryRegion<byte>& s_random); + const Private_Key* signing_key = 0); - Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Key_Exchange(const MemoryRegion<byte>& buf, + const std::string& kex_alg, + const std::string& sig_alg, + Protocol_Version version); + + ~Server_Key_Exchange() { delete m_kex_key; } private: MemoryVector<byte> serialize() const; - MemoryVector<byte> serialize_params() const; - void deserialize(const MemoryRegion<byte>&); - std::vector<BigInt> params; - MemoryVector<byte> signature; + Private_Key* m_kex_key; + + MemoryVector<byte> m_params; + + std::string m_sig_algo; // sig algo used to create signature + std::string m_hash_algo; // hash used to create signature + MemoryVector<byte> m_signature; }; /** @@ -410,11 +435,10 @@ class Server_Hello_Done : public Handshake_Message public: Handshake_Type type() const { return SERVER_HELLO_DONE; } - Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash); - Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello_Done(Record_Writer& writer, Handshake_Hash& hash); + Server_Hello_Done(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -428,13 +452,12 @@ class Next_Protocol : public Handshake_Message std::string protocol() const { return m_protocol; } Next_Protocol(Record_Writer& writer, - TLS_Handshake_Hash& hash, + Handshake_Hash& hash, const std::string& protocol); - Next_Protocol(const MemoryRegion<byte>& buf) { deserialize(buf); } + Next_Protocol(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::string m_protocol; }; @@ -465,4 +488,6 @@ class New_Session_Ticket : public Handshake_Message } +} + #endif diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp index 596f5e53e..1ab55f7c6 100644 --- a/src/tls/tls_policy.cpp +++ b/src/tls/tls_policy.cpp @@ -1,88 +1,230 @@ /* * Policies for TLS -* (C) 2004-2010 Jack Lloyd +* (C) 2004-2010,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/tls_policy.h> +#include <botan/tls_ciphersuite.h> +#include <botan/tls_magic.h> #include <botan/tls_exceptn.h> +#include <botan/internal/stl_util.h> + +#include <assert.h> namespace Botan { -/* -* Return allowed ciphersuites -*/ -std::vector<u16bit> TLS_Policy::ciphersuites(bool have_srp) const +namespace TLS { + +std::vector<std::string> Policy::allowed_ciphers() const { - return suite_list(allow_static_rsa(), allow_edh_rsa(), allow_edh_dsa(), - allow_srp() && have_srp); + std::vector<std::string> allowed; + + allowed.push_back("AES-256"); + allowed.push_back("AES-128"); + allowed.push_back("3DES"); + allowed.push_back("ARC4"); + + // Note that Camellia, SEED and IDEA are not included by default + + return allowed; } -/* -* Return allowed ciphersuites -*/ -std::vector<u16bit> TLS_Policy::suite_list(bool use_rsa, - bool use_edh_rsa, - bool use_edh_dsa, - bool use_srp) const +std::vector<std::string> Policy::allowed_hashes() const { - std::vector<u16bit> suites; + std::vector<std::string> allowed; - if(use_srp) - { - if(use_edh_rsa) - { - suites.push_back(TLS_SRP_SHA_DSS_WITH_AES_256_SHA); - suites.push_back(TLS_SRP_SHA_DSS_WITH_AES_128_SHA); - suites.push_back(TLS_SRP_SHA_DSS_WITH_3DES_EDE_SHA); - } + allowed.push_back("SHA-512"); + allowed.push_back("SHA-384"); + allowed.push_back("SHA-256"); + allowed.push_back("SHA-224"); + allowed.push_back("SHA-1"); + // Note that MD5 is not included by default + + return allowed; + } + +std::vector<std::string> Policy::allowed_key_exchange_methods() const + { + std::vector<std::string> allowed; + + //allowed.push_back("SRP"); + //allowed.push_back("ECDHE_PSK"); + //allowed.push_back("DHE_PSK"); + //allowed.push_back("PSK"); + allowed.push_back("ECDH"); + allowed.push_back("DH"); + allowed.push_back("RSA"); // RSA via server cert + + return allowed; + } - if(use_edh_dsa) +std::vector<std::string> Policy::allowed_signature_methods() const + { + std::vector<std::string> allowed; + + allowed.push_back("ECDSA"); + allowed.push_back("RSA"); + allowed.push_back("DSA"); + allowed.push_back(""); + + return allowed; + } + +std::vector<std::string> Policy::allowed_ecc_curves() const + { + std::vector<std::string> curves; + curves.push_back("secp521r1"); + curves.push_back("secp384r1"); + curves.push_back("secp256r1"); + curves.push_back("secp256k1"); + curves.push_back("secp224r1"); + curves.push_back("secp224k1"); + curves.push_back("secp192r1"); + curves.push_back("secp192k1"); + curves.push_back("secp160r2"); + curves.push_back("secp160r1"); + curves.push_back("secp160k1"); + return curves; + } + +Protocol_Version Policy::min_version() const + { + return Protocol_Version::SSL_V3; + } + +Protocol_Version Policy::pref_version() const + { + return Protocol_Version::TLS_V12; + } + +namespace { + +class Ciphersuite_Preference_Ordering + { + public: + Ciphersuite_Preference_Ordering(const std::vector<std::string>& ciphers, + const std::vector<std::string>& hashes, + const std::vector<std::string>& kex, + const std::vector<std::string>& sigs) : + m_ciphers(ciphers), m_hashes(hashes), m_kex(kex), m_sigs(sigs) {} + + bool operator()(const Ciphersuite& a, const Ciphersuite& b) const { - suites.push_back(TLS_SRP_SHA_RSA_WITH_AES_256_SHA); - suites.push_back(TLS_SRP_SHA_RSA_WITH_AES_128_SHA); - suites.push_back(TLS_SRP_SHA_RSA_WITH_3DES_EDE_SHA); + if(a.kex_algo() != b.kex_algo()) + { + for(size_t i = 0; i != m_kex.size(); ++i) + { + if(a.kex_algo() == m_kex[i]) + return true; + if(b.kex_algo() == m_kex[i]) + return false; + } + } + + if(a.cipher_algo() != b.cipher_algo()) + { + for(size_t i = 0; i != m_ciphers.size(); ++i) + { + if(a.cipher_algo() == m_ciphers[i]) + return true; + if(b.cipher_algo() == m_ciphers[i]) + return false; + } + } + + if(a.cipher_keylen() != b.cipher_keylen()) + { + if(a.cipher_keylen() < b.cipher_keylen()) + return false; + if(a.cipher_keylen() > b.cipher_keylen()) + return true; + } + + if(a.sig_algo() != b.sig_algo()) + { + for(size_t i = 0; i != m_sigs.size(); ++i) + { + if(a.sig_algo() == m_sigs[i]) + return true; + if(b.sig_algo() == m_sigs[i]) + return false; + } + } + + if(a.mac_algo() != b.mac_algo()) + { + for(size_t i = 0; i != m_hashes.size(); ++i) + { + if(a.mac_algo() == m_hashes[i]) + return true; + if(b.mac_algo() == m_hashes[i]) + return false; + } + } + + return false; // equal (?!?) } - } + private: + std::vector<std::string> m_ciphers, m_hashes, m_kex, m_sigs; + + }; - if(use_edh_dsa) +} + +std::vector<u16bit> Policy::ciphersuite_list(bool have_srp) const + { + std::vector<std::string> ciphers = allowed_ciphers(); + std::vector<std::string> hashes = allowed_hashes(); + std::vector<std::string> kex = allowed_key_exchange_methods(); + std::vector<std::string> sigs = allowed_signature_methods(); + + if(!have_srp) { - suites.push_back(TLS_DHE_DSS_WITH_AES_256_CBC_SHA); - suites.push_back(TLS_DHE_DSS_WITH_AES_128_CBC_SHA); - suites.push_back(TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA); - suites.push_back(TLS_DHE_DSS_WITH_SEED_CBC_SHA); + std::vector<std::string>::iterator i = + std::find(kex.begin(), kex.end(), "SRP"); + + if(i != kex.end()) + kex.erase(i); } - if(use_edh_rsa) + Ciphersuite_Preference_Ordering order(ciphers, hashes, kex, sigs); + + std::map<Ciphersuite, u16bit, Ciphersuite_Preference_Ordering> + ciphersuites(order); + + for(size_t i = 0; i != 65536; ++i) { - suites.push_back(TLS_DHE_RSA_WITH_AES_256_CBC_SHA); - suites.push_back(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); - suites.push_back(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA); - suites.push_back(TLS_DHE_RSA_WITH_SEED_CBC_SHA); + Ciphersuite suite = Ciphersuite::by_id(i); + + if(!suite.valid()) + continue; // not a ciphersuite we know, skip + + if(value_exists(ciphers, suite.cipher_algo()) && + value_exists(hashes, suite.mac_algo()) && + value_exists(kex, suite.kex_algo()) && + value_exists(sigs, suite.sig_algo())) + { + ciphersuites[suite] = i; + } } - if(use_rsa) + std::vector<u16bit> ciphersuite_codes; + + for(std::map<Ciphersuite, u16bit, Ciphersuite_Preference_Ordering>::iterator i = ciphersuites.begin(); + i != ciphersuites.end(); ++i) { - suites.push_back(TLS_RSA_WITH_AES_256_CBC_SHA); - suites.push_back(TLS_RSA_WITH_AES_128_CBC_SHA); - suites.push_back(TLS_RSA_WITH_3DES_EDE_CBC_SHA); - suites.push_back(TLS_RSA_WITH_SEED_CBC_SHA); - suites.push_back(TLS_RSA_WITH_RC4_128_SHA); - suites.push_back(TLS_RSA_WITH_RC4_128_MD5); + ciphersuite_codes.push_back(i->second); } - if(suites.size() == 0) - throw TLS_Exception(INTERNAL_ERROR, - "TLS_Policy error: All ciphersuites disabled"); - - return suites; + return ciphersuite_codes; } /* * Return allowed compression algorithms */ -std::vector<byte> TLS_Policy::compression() const +std::vector<byte> Policy::compression() const { std::vector<byte> algs; algs.push_back(NO_COMPRESSION); @@ -90,33 +232,57 @@ std::vector<byte> TLS_Policy::compression() const } /* +* Choose an ECC curve to use +*/ +std::string Policy::choose_curve(const std::vector<std::string>& curve_names) const + { + std::vector<std::string> our_curves = allowed_ecc_curves(); + + for(size_t i = 0; i != our_curves.size(); ++i) + if(value_exists(curve_names, our_curves[i])) + return our_curves[i]; + + return ""; // no shared curve + } + +/* * Choose which ciphersuite to use */ -u16bit TLS_Policy::choose_suite(const std::vector<u16bit>& c_suites, - bool have_rsa, - bool have_dsa, - bool have_srp) const +u16bit Policy::choose_suite(const std::vector<u16bit>& client_suites, + const std::vector<std::string>& available_cert_types, + bool have_shared_ecc_curve, + bool have_srp) const { - const bool use_static_rsa = allow_static_rsa() && have_rsa; - const bool use_edh_rsa = allow_edh_rsa() && have_rsa; - const bool use_edh_dsa = allow_edh_dsa() && have_dsa; - const bool use_srp = allow_srp() && have_srp; + std::vector<u16bit> ciphersuites = ciphersuite_list(have_srp); - std::vector<u16bit> s_suites = suite_list(use_static_rsa, use_edh_rsa, - use_edh_dsa, use_srp); + for(size_t i = 0; i != ciphersuites.size(); ++i) + { + const u16bit suite_id = ciphersuites[i]; + Ciphersuite suite = Ciphersuite::by_id(suite_id); - for(size_t i = 0; i != s_suites.size(); ++i) - for(size_t j = 0; j != c_suites.size(); ++j) - if(s_suites[i] == c_suites[j]) - return s_suites[i]; + if(!have_shared_ecc_curve) + { + if(suite.kex_algo() == "ECDH" || suite.sig_algo() == "ECDSA") + continue; + } - return 0; + if(suite.sig_algo() != "" && + !value_exists(available_cert_types, suite.sig_algo())) + { + continue; + } + + if(value_exists(client_suites, suite_id)) + return suite_id; + } + + return 0; // no shared cipersuite } /* * Choose which compression algorithm to use */ -byte TLS_Policy::choose_compression(const std::vector<byte>& c_comp) const +byte Policy::choose_compression(const std::vector<byte>& c_comp) const { std::vector<byte> s_comp = compression(); @@ -128,12 +294,6 @@ byte TLS_Policy::choose_compression(const std::vector<byte>& c_comp) const return NO_COMPRESSION; } -/* -* Return the group to use for empheral DH -*/ -DL_Group TLS_Policy::dh_group() const - { - return DL_Group("modp/ietf/1024"); - } +} } diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h index 48ff9185e..f53b9bab6 100644 --- a/src/tls/tls_policy.h +++ b/src/tls/tls_policy.h @@ -8,62 +8,119 @@ #ifndef BOTAN_TLS_POLICY_H__ #define BOTAN_TLS_POLICY_H__ -#include <botan/tls_magic.h> +#include <botan/tls_version.h> #include <botan/x509cert.h> #include <botan/dl_group.h> #include <vector> namespace Botan { +namespace TLS { + /** * TLS Policy Base Class -* Inherit and overload as desired to suite local policy concerns +* Inherit and overload as desired to suit local policy concerns */ -class BOTAN_DLL TLS_Policy +class BOTAN_DLL Policy { public: - std::vector<u16bit> ciphersuites(bool have_srp) const; - virtual std::vector<byte> compression() const; - virtual u16bit choose_suite(const std::vector<u16bit>& client_suites, - bool rsa_ok, - bool dsa_ok, - bool srp_ok) const; + /** + * Returns a list of ciphers we are willing to negotiate, in + * order of preference. Allowed values: any block cipher name, or + * ARC4. + */ + virtual std::vector<std::string> allowed_ciphers() const; + + /** + * Returns a list of hash algorithms we are willing to use, in + * order of preference. This is used for both MACs and signatures. + * Allowed values: any hash name, though currently only MD5, + * SHA-1, and the SHA-2 variants are used. + */ + virtual std::vector<std::string> allowed_hashes() const; + + /** + * Returns a list of key exchange algorithms we are willing to + * use, in order of preference. Allowed values: DH, empty string + * (representing RSA using server certificate key) + */ + virtual std::vector<std::string> allowed_key_exchange_methods() const; - virtual byte choose_compression(const std::vector<byte>& client) const; + /** + * Returns a list of signature algorithms we are willing to + * use, in order of preference. Allowed values RSA and DSA. + */ + virtual std::vector<std::string> allowed_signature_methods() const; - virtual bool allow_static_rsa() const { return true; } - virtual bool allow_edh_rsa() const { return true; } - virtual bool allow_edh_dsa() const { return true; } - virtual bool allow_srp() const { return true; } + /** + * Return list of ECC curves we are willing to use in order of preference + */ + virtual std::vector<std::string> allowed_ecc_curves() const; - virtual bool require_client_auth() const { return false; } + /** + * Returns a list of signature algorithms we are willing to use, + * in order of preference. Allowed values any value of + * Compression_Method. + */ + virtual std::vector<byte> compression() const; + /** + * Choose an elliptic curve to use + */ + virtual std::string choose_curve(const std::vector<std::string>& curve_names) const; + + /** + * Require support for RFC 5746 extensions to enable + * renegotiation. + * + * @warning Changing this to false exposes you to injected + * plaintext attacks. Read the RFC for background. + */ virtual bool require_secure_renegotiation() const { return true; } - virtual DL_Group dh_group() const; - virtual size_t rsa_export_keysize() const { return 512; } + /** + * Return the group to use for ephemeral Diffie-Hellman key agreement + */ + virtual DL_Group dh_group() const { return DL_Group("modp/ietf/1536"); } + + /** + * If this function returns false, unknown SRP/PSK identifiers + * will be rejected with an unknown_psk_identifier alert as soon + * as the non-existence is identified. Otherwise, a false + * identifier value will be used and the protocol allowed to + * proceed, causing the login to eventually fail without + * revealing that the username does not exist on this system. + */ + virtual bool hide_unknown_users() const { return false; } - /* - * @return the minimum version that we will negotiate + /** + * @return the minimum version that we are willing to negotiate */ - virtual Version_Code min_version() const { return SSL_V3; } + virtual Protocol_Version min_version() const; - /* + /** * @return the version we would prefer to negotiate */ - virtual Version_Code pref_version() const { return TLS_V11; } + virtual Protocol_Version pref_version() const; - virtual bool check_cert(const std::vector<X509_Certificate>& cert_chain) const = 0; + /** + * Return allowed ciphersuites, in order of preference + */ + std::vector<u16bit> ciphersuite_list(bool have_srp) const; + + u16bit choose_suite(const std::vector<u16bit>& client_suites, + const std::vector<std::string>& available_cert_types, + bool have_shared_ecc_curve, + bool have_srp) const; + + byte choose_compression(const std::vector<byte>& client_algos) const; - virtual ~TLS_Policy() {} - private: - virtual std::vector<u16bit> suite_list(bool use_rsa, - bool use_edh_rsa, - bool use_edh_dsa, - bool use_srp) const; + virtual ~Policy() {} }; } +} + #endif diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index ef36912d3..162f691aa 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -17,6 +17,8 @@ namespace Botan { +namespace TLS { + /** * Helper class for decoding TLS protocol messages */ @@ -26,13 +28,10 @@ class TLS_Data_Reader TLS_Data_Reader(const MemoryRegion<byte>& buf_in) : buf(buf_in), offset(0) {} - ~TLS_Data_Reader() + void assert_done() const { if(has_remaining()) - { - abort(); throw Decoding_Error("Extra bytes at end of message"); - } } size_t remaining_bytes() const @@ -155,8 +154,9 @@ class TLS_Data_Reader { if(buf.size() - offset < n) { - abort(); - throw Decoding_Error("TLS_Data_Reader: Corrupt packet"); + throw Decoding_Error("TLS_Data_Reader: Expected " + to_string(n) + + " bytes remaining, only " + to_string(buf.size()-offset) + + " left"); } } @@ -207,6 +207,18 @@ void append_tls_length_value(MemoryRegion<byte>& buf, append_tls_length_value(buf, &vals[0], vals.size(), tag_size); } +inline void append_tls_length_value(MemoryRegion<byte>& buf, + const std::string& str, + size_t tag_size) + { + append_tls_length_value(buf, + reinterpret_cast<const byte*>(&str[0]), + str.size(), + tag_size); + } + +} + } #endif diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index c7d2d0018..fb27db5e2 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -8,7 +8,10 @@ #ifndef BOTAN_TLS_RECORDS_H__ #define BOTAN_TLS_RECORDS_H__ -#include <botan/tls_suites.h> +#include <botan/tls_ciphersuite.h> +#include <botan/tls_alert.h> +#include <botan/tls_magic.h> +#include <botan/tls_version.h> #include <botan/pipe.h> #include <botan/mac.h> #include <botan/secqueue.h> @@ -30,7 +33,9 @@ namespace Botan { -class SessionKeys; +namespace TLS { + +class Session_Keys; /** * TLS Record Writer @@ -41,13 +46,16 @@ class BOTAN_DLL Record_Writer void send(byte type, const byte input[], size_t length); void send(byte type, byte val) { send(type, &val, 1); } - void alert(Alert_Level level, Alert_Type type); + MemoryVector<byte> send(class Handshake_Message& msg); + + void send_alert(const Alert& alert); - void activate(const TLS_Cipher_Suite& suite, - const SessionKeys& keys, - Connection_Side side); + void activate(Connection_Side side, + const Ciphersuite& suite, + const Session_Keys& keys, + byte compression_method); - void set_version(Version_Code version); + void set_version(Protocol_Version version); void reset(); @@ -72,7 +80,7 @@ class BOTAN_DLL Record_Writer size_t m_block_size, m_mac_size, m_iv_size, m_max_fragment; u64bit m_seq_no; - byte m_major, m_minor; + Protocol_Version m_version; }; /** @@ -97,11 +105,12 @@ class BOTAN_DLL Record_Reader byte& msg_type, MemoryVector<byte>& msg); - void activate(const TLS_Cipher_Suite& suite, - const SessionKeys& keys, - Connection_Side side); + void activate(Connection_Side side, + const Ciphersuite& suite, + const Session_Keys& keys, + byte compression_method); - void set_version(Version_Code version); + void set_version(Protocol_Version version); void reset(); @@ -127,9 +136,11 @@ class BOTAN_DLL Record_Reader MessageAuthenticationCode* m_mac; size_t m_block_size, m_iv_size, m_max_fragment; u64bit m_seq_no; - byte m_major, m_minor; + Protocol_Version m_version; }; } +} + #endif diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index ccba16629..eacbc02e0 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -1,37 +1,25 @@ /* * TLS Server -* (C) 2004-2011 Jack Lloyd +* (C) 2004-2011,2012 Jack Lloyd * * Released under the terms of the Botan license */ #include <botan/tls_server.h> -#include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> #include <botan/internal/stl_util.h> -#include <botan/rsa.h> -#include <botan/dh.h> +#include <botan/internal/assert.h> +#include <memory> namespace Botan { -namespace { +namespace TLS { -/* -* Choose what version to respond with -*/ -Version_Code choose_version(Version_Code client, Version_Code minimum) - { - if(client < minimum) - throw TLS_Exception(PROTOCOL_VERSION, - "Client version is unacceptable by policy"); - - if(client == SSL_V3 || client == TLS_V10 || client == TLS_V11) - return client; - return TLS_V11; - } +namespace { -bool check_for_resume(TLS_Session& session_info, - TLS_Session_Manager& session_manager, +bool check_for_resume(Session& session_info, + Session_Manager& session_manager, Client_Hello* client_hello) { MemoryVector<byte> client_session_id = client_hello->session_id(); @@ -49,7 +37,7 @@ bool check_for_resume(TLS_Session& session_info, // client didn't send original ciphersuite if(!value_exists(client_hello->ciphersuites(), - session_info.ciphersuite())) + session_info.ciphersuite_code())) return false; // client didn't send original compression method @@ -74,20 +62,40 @@ bool check_for_resume(TLS_Session& session_info, return true; } +std::map<std::string, std::vector<X509_Certificate> > +get_server_certs(const std::string& hostname, + Credentials_Manager& creds) + { + const char* cert_types[] = { "RSA", "DSA", "ECDSA", 0 }; + + std::map<std::string, std::vector<X509_Certificate> > cert_chains; + + for(size_t i = 0; cert_types[i]; ++i) + { + std::vector<X509_Certificate> certs = + creds.cert_chain_single_type(cert_types[i], "tls-server", hostname); + + if(!certs.empty()) + cert_chains[cert_types[i]] = certs; + } + + return cert_chains; + } + } /* * TLS Server Constructor */ -TLS_Server::TLS_Server(std::tr1::function<void (const byte[], size_t)> output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_fn, - TLS_Session_Manager& session_manager, - Credentials_Manager& creds, - const TLS_Policy& policy, - RandomNumberGenerator& rng, - const std::vector<std::string>& next_protocols) : - TLS_Channel(output_fn, proc_fn, handshake_fn), +Server::Server(std::tr1::function<void (const byte[], size_t)> output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_fn, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::vector<std::string>& next_protocols) : + Channel(output_fn, proc_fn, handshake_fn), policy(policy), rng(rng), session_manager(session_manager), @@ -99,19 +107,19 @@ TLS_Server::TLS_Server(std::tr1::function<void (const byte[], size_t)> output_fn /* * Send a hello request to the client */ -void TLS_Server::renegotiate() +void Server::renegotiate() { if(state) return; // currently in handshake - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(CLIENT_HELLO); Hello_Request hello_req(writer); } -void TLS_Server::alert_notify(bool, Alert_Type type) +void Server::alert_notify(const Alert& alert) { - if(type == NO_RENEGOTIATION) + if(alert.type() == Alert::NO_RENEGOTIATION) { if(handshake_completed && state) { @@ -124,23 +132,23 @@ void TLS_Server::alert_notify(bool, Alert_Type type) /* * Split up and process handshake messages */ -void TLS_Server::read_handshake(byte rec_type, - const MemoryRegion<byte>& rec_buf) +void Server::read_handshake(byte rec_type, + const MemoryRegion<byte>& rec_buf) { if(rec_type == HANDSHAKE && !state) { - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(CLIENT_HELLO); } - TLS_Channel::read_handshake(rec_type, rec_buf); + Channel::read_handshake(rec_type, rec_buf); } /* * Process a handshake message */ -void TLS_Server::process_handshake_msg(Handshake_Type type, - const MemoryRegion<byte>& contents) +void Server::process_handshake_msg(Handshake_Type type, + const MemoryRegion<byte>& contents) { if(state == 0) throw Unexpected_Message("Unexpected handshake message from client"); @@ -168,15 +176,23 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, m_hostname = state->client_hello->sni_hostname(); - state->version = choose_version(state->client_hello->version(), - policy.min_version()); + Protocol_Version client_version = state->client_hello->version(); + + if(client_version < policy.min_version()) + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Client version is unacceptable by policy"); + + if(client_version <= policy.pref_version()) + state->set_version(client_version); + else + state->set_version(policy.pref_version()); secure_renegotiation.update(state->client_hello); - writer.set_version(state->version); - reader.set_version(state->version); + writer.set_version(state->version()); + reader.set_version(state->version()); - TLS_Session session_info; + Session session_info; const bool resuming = check_for_resume(session_info, session_manager, state->client_hello); @@ -189,8 +205,8 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, writer, state->hash, session_info.session_id(), - Version_Code(session_info.version()), - session_info.ciphersuite(), + Protocol_Version(session_info.version()), + session_info.ciphersuite_code(), session_info.compression_method(), session_info.fragment_size(), secure_renegotiation.supported(), @@ -205,21 +221,16 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, writer.set_maximum_fragment_size(session_info.fragment_size()); } - state->suite = TLS_Cipher_Suite(state->server_hello->ciphersuite()); + state->suite = Ciphersuite::by_id(state->server_hello->ciphersuite()); - state->keys = SessionKeys(state->suite, state->version, - session_info.master_secret(), - state->client_hello->random(), - state->server_hello->random(), - true); + state->keys = Session_Keys(state, session_info.master_secret(), true); writer.send(CHANGE_CIPHER_SPEC, 1); - writer.activate(state->suite, state->keys, SERVER); + writer.activate(SERVER, state->suite, state->keys, + state->server_hello->compression_method()); - state->server_finished = new Finished(writer, state->hash, - state->version, SERVER, - state->keys.master_secret()); + state->server_finished = new Finished(writer, state, SERVER); if(!handshake_fn(session_info)) session_manager.remove_entry(session_info.session_id()); @@ -228,23 +239,31 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, } else // new session { - std::vector<X509_Certificate> server_certs = - creds.cert_chain("", - "tls-server", - m_hostname); + std::map<std::string, std::vector<X509_Certificate> > cert_chains; + + cert_chains = get_server_certs(m_hostname, creds); - Private_Key* private_key = - server_certs.empty() ? 0 : - (creds.private_key_for(server_certs[0], - "tls-server", - m_hostname)); + if(m_hostname != "" && cert_chains.empty()) + { + send_alert(Alert(Alert::UNRECOGNIZED_NAME)); + cert_chains = get_server_certs("", creds); + } + + std::vector<std::string> available_cert_types; + + for(std::map<std::string, std::vector<X509_Certificate> >::const_iterator i = cert_chains.begin(); + i != cert_chains.end(); ++i) + { + if(!i->second.empty()) + available_cert_types.push_back(i->first); + } state->server_hello = new Server_Hello( writer, state->hash, - state->version, + state->version(), *(state->client_hello), - server_certs, + available_cert_types, policy, secure_renegotiation.supported(), secure_renegotiation.for_server_hello(), @@ -258,37 +277,53 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, writer.set_maximum_fragment_size(state->client_hello->fragment_size()); } - state->suite = TLS_Cipher_Suite(state->server_hello->ciphersuite()); + state->suite = Ciphersuite::by_id(state->server_hello->ciphersuite()); - if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) + const std::string sig_algo = state->suite.sig_algo(); + const std::string kex_algo = state->suite.kex_algo(); + + if(sig_algo != "") { + BOTAN_ASSERT(!cert_chains[sig_algo].empty(), + "Attempting to send empty certificate chain"); + state->server_certs = new Certificate(writer, state->hash, - server_certs); + cert_chains[sig_algo]); } - if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + Private_Key* private_key = 0; + + if(kex_algo == "RSA" || sig_algo != "") { - if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_DH) - state->kex_priv = new DH_PrivateKey(rng, policy.dh_group()); - else - throw Internal_Error("TLS_Server: Unknown ciphersuite kex type"); + private_key = creds.private_key_for(state->server_certs->cert_chain()[0], + "tls-server", + m_hostname); - state->server_kex = - new Server_Key_Exchange(writer, state->hash, rng, - state->kex_priv, private_key, - state->client_hello->random(), - state->server_hello->random()); + if(!private_key) + throw Internal_Error("No private key located for associated server cert"); } - else - state->kex_priv = PKCS8::copy_key(*private_key, rng); - if(policy.require_client_auth()) + if(kex_algo == "RSA") { - // FIXME: figure out the allowed CAs/cert types + state->server_rsa_kex_key = private_key; + } + else + { + state->server_kex = + new Server_Key_Exchange(writer, state, policy, creds, rng, private_key); + } - state->cert_req = new Certificate_Req(writer, state->hash, - std::vector<X509_Certificate>()); + std::vector<X509_Certificate> client_auth_CAs = + creds.trusted_certificate_authorities("tls-server", m_hostname); + + if(!client_auth_CAs.empty() && state->suite.sig_algo() != "") + { + state->cert_req = new Certificate_Req(writer, + state->hash, + policy, + client_auth_CAs, + state->version()); state->set_expected_next(CERTIFICATE); } @@ -311,7 +346,7 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, // Is this allowed by the protocol? if(state->client_certs->count() > 1) - throw TLS_Exception(CERTIFICATE_UNKNOWN, + throw TLS_Exception(Alert::CERTIFICATE_UNKNOWN, "Client sent more than one certificate"); state->set_expected_next(CLIENT_KEX); @@ -323,29 +358,19 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, else state->set_expected_next(HANDSHAKE_CCS); - state->client_kex = new Client_Key_Exchange(contents, state->suite, - state->version); + state->client_kex = new Client_Key_Exchange(contents, state, creds, policy, rng); - SecureVector<byte> pre_master = - state->client_kex->pre_master_secret(rng, state->kex_priv, - state->client_hello->version()); - - state->keys = SessionKeys(state->suite, state->version, pre_master, - state->client_hello->random(), - state->server_hello->random()); + state->keys = Session_Keys(state, state->client_kex->pre_master_secret(), false); } else if(type == CERTIFICATE_VERIFY) { - state->client_verify = new Certificate_Verify(contents); + state->client_verify = new Certificate_Verify(contents, state->version()); const std::vector<X509_Certificate>& client_certs = state->client_certs->cert_chain(); const bool sig_valid = - state->client_verify->verify(client_certs[0], - state->hash, - state->server_hello->version(), - state->keys.master_secret()); + state->client_verify->verify(client_certs[0], state); state->hash.update(type, contents); @@ -355,9 +380,16 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, * unable to correctly verify a signature, ..." */ if(!sig_valid) - throw TLS_Exception(DECRYPT_ERROR, "Client cert verify failed"); + throw TLS_Exception(Alert::DECRYPT_ERROR, "Client cert verify failed"); - // FIXME: check cert was issued by a CA we requested, signatures, etc. + try + { + creds.verify_certificate_chain("tls-server", "", client_certs); + } + catch(std::exception& e) + { + throw TLS_Exception(Alert::BAD_CERTIFICATE, e.what()); + } state->set_expected_next(HANDSHAKE_CCS); } @@ -368,7 +400,8 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, else state->set_expected_next(FINISHED); - reader.activate(state->suite, state->keys, SERVER); + reader.activate(SERVER, state->suite, state->keys, + state->server_hello->compression_method()); } else if(type == NEXT_PROTOCOL) { @@ -384,46 +417,43 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, state->client_finished = new Finished(contents); - if(!state->client_finished->verify(state->keys.master_secret(), - state->version, state->hash, CLIENT)) - throw TLS_Exception(DECRYPT_ERROR, + if(!state->client_finished->verify(state, CLIENT)) + throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); - // already sent it if resuming if(!state->server_finished) { state->hash.update(type, contents); writer.send(CHANGE_CIPHER_SPEC, 1); - writer.activate(state->suite, state->keys, SERVER); + writer.activate(SERVER, state->suite, state->keys, + state->server_hello->compression_method()); - state->server_finished = new Finished(writer, state->hash, - state->version, SERVER, - state->keys.master_secret()); + state->server_finished = new Finished(writer, state, SERVER); if(state->client_certs && state->client_verify) peer_certs = state->client_certs->cert_chain(); - } - TLS_Session session_info( - state->server_hello->session_id(), - state->keys.master_secret(), - state->server_hello->version(), - state->server_hello->ciphersuite(), - state->server_hello->compression_method(), - SERVER, - secure_renegotiation.supported(), - state->server_hello->fragment_size(), - peer_certs, - m_hostname, - "" - ); - - if(handshake_fn(session_info)) - session_manager.save(session_info); - else - session_manager.remove_entry(session_info.session_id()); + // already sent finished if resuming, so this is a new session + + Session session_info( + state->server_hello->session_id(), + state->keys.master_secret(), + state->server_hello->version(), + state->server_hello->ciphersuite(), + state->server_hello->compression_method(), + SERVER, + secure_renegotiation.supported(), + state->server_hello->fragment_size(), + peer_certs, + m_hostname, + "" + ); + + if(handshake_fn(session_info)) + session_manager.save(session_info); + } secure_renegotiation.update(state->client_finished, state->server_finished); @@ -437,3 +467,5 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, } } + +} diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index f8c3a8563..025bbf3ec 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -15,24 +15,26 @@ namespace Botan { +namespace TLS { + /** * TLS Server */ -class BOTAN_DLL TLS_Server : public TLS_Channel +class BOTAN_DLL Server : public Channel { public: /** - * TLS_Server initialization + * Server initialization */ - TLS_Server(std::tr1::function<void (const byte[], size_t)> socket_output_fn, - std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, - std::tr1::function<bool (const TLS_Session&)> handshake_complete, - TLS_Session_Manager& session_manager, - Credentials_Manager& creds, - const TLS_Policy& policy, - RandomNumberGenerator& rng, - const std::vector<std::string>& protocols = - std::vector<std::string>()); + Server(std::tr1::function<void (const byte[], size_t)> socket_output_fn, + std::tr1::function<void (const byte[], size_t, Alert)> proc_fn, + std::tr1::function<bool (const Session&)> handshake_complete, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::vector<std::string>& protocols = + std::vector<std::string>()); void renegotiate(); @@ -53,11 +55,11 @@ class BOTAN_DLL TLS_Server : public TLS_Channel void process_handshake_msg(Handshake_Type, const MemoryRegion<byte>&); - void alert_notify(bool is_fatal, Alert_Type type); + void alert_notify(const Alert& alert); - const TLS_Policy& policy; + const Policy& policy; RandomNumberGenerator& rng; - TLS_Session_Manager& session_manager; + Session_Manager& session_manager; Credentials_Manager& creds; std::vector<std::string> m_possible_protocols; @@ -67,4 +69,6 @@ class BOTAN_DLL TLS_Server : public TLS_Channel } +} + #endif diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp index c40e9d79e..f8e686a4a 100644 --- a/src/tls/tls_session.cpp +++ b/src/tls/tls_session.cpp @@ -9,23 +9,26 @@ #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/asn1_str.h> +#include <botan/pem.h> #include <botan/time.h> #include <botan/lookup.h> #include <memory> namespace Botan { -TLS_Session::TLS_Session(const MemoryRegion<byte>& session_identifier, - const MemoryRegion<byte>& master_secret, - Version_Code version, - u16bit ciphersuite, - byte compression_method, - Connection_Side side, - bool secure_renegotiation_supported, - size_t fragment_size, - const std::vector<X509_Certificate>& certs, - const std::string& sni_hostname, - const std::string& srp_identifier) : +namespace TLS { + +Session::Session(const MemoryRegion<byte>& session_identifier, + const MemoryRegion<byte>& master_secret, + Protocol_Version version, + u16bit ciphersuite, + byte compression_method, + Connection_Side side, + bool secure_renegotiation_supported, + size_t fragment_size, + const std::vector<X509_Certificate>& certs, + const std::string& sni_hostname, + const std::string& srp_identifier) : m_start_time(system_time()), m_identifier(session_identifier), m_master_secret(master_secret), @@ -35,15 +38,13 @@ TLS_Session::TLS_Session(const MemoryRegion<byte>& session_identifier, m_connection_side(side), m_secure_renegotiation_supported(secure_renegotiation_supported), m_fragment_size(fragment_size), + m_peer_certs(certs), m_sni_hostname(sni_hostname), m_srp_identifier(srp_identifier) { - // FIXME: encode all of them? - if(certs.size()) - m_peer_certificate = certs[0].BER_encode(); } -TLS_Session::TLS_Session(const byte ber[], size_t ber_len) +Session::Session(const byte ber[], size_t ber_len) { BER_Decoder decoder(ber, ber_len); @@ -51,48 +52,84 @@ TLS_Session::TLS_Session(const byte ber[], size_t ber_len) ASN1_String sni_hostname_str; ASN1_String srp_identifier_str; + byte major_version = 0, minor_version = 0; + + MemoryVector<byte> peer_cert_bits; + BER_Decoder(ber, ber_len) - .decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION), - "Unknown version in session structure") - .decode(m_identifier, OCTET_STRING) - .decode_integer_type(m_start_time) - .decode_integer_type(m_version) - .decode_integer_type(m_ciphersuite) - .decode_integer_type(m_compression_method) - .decode_integer_type(side_code) - .decode_integer_type(m_fragment_size) - .decode(m_secure_renegotiation_supported) - .decode(m_master_secret, OCTET_STRING) - .decode(m_peer_certificate, OCTET_STRING) - .decode(sni_hostname_str) - .decode(srp_identifier_str); + .start_cons(SEQUENCE) + .decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION), + "Unknown version in session structure") + .decode(m_identifier, OCTET_STRING) + .decode_integer_type(m_start_time) + .decode_integer_type(major_version) + .decode_integer_type(minor_version) + .decode_integer_type(m_ciphersuite) + .decode_integer_type(m_compression_method) + .decode_integer_type(side_code) + .decode_integer_type(m_fragment_size) + .decode(m_secure_renegotiation_supported) + .decode(m_master_secret, OCTET_STRING) + .decode(peer_cert_bits, OCTET_STRING) + .decode(sni_hostname_str) + .decode(srp_identifier_str) + .end_cons() + .verify_end(); + m_version = Protocol_Version(major_version, minor_version); m_sni_hostname = sni_hostname_str.value(); m_srp_identifier = srp_identifier_str.value(); m_connection_side = static_cast<Connection_Side>(side_code); + + if(!peer_cert_bits.empty()) + { + DataSource_Memory certs(peer_cert_bits); + + while(!certs.end_of_data()) + m_peer_certs.push_back(X509_Certificate(certs)); + } } -SecureVector<byte> TLS_Session::BER_encode() const +Session::Session(const std::string& pem) { + SecureVector<byte> der = PEM_Code::decode_check_label(pem, "SSL SESSION"); + + *this = Session(&der[0], der.size()); + } + +SecureVector<byte> Session::DER_encode() const + { + MemoryVector<byte> peer_cert_bits; + for(size_t i = 0; i != m_peer_certs.size(); ++i) + peer_cert_bits += m_peer_certs[i].BER_encode(); + return DER_Encoder() .start_cons(SEQUENCE) .encode(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION)) .encode(m_identifier, OCTET_STRING) .encode(static_cast<size_t>(m_start_time)) - .encode(static_cast<size_t>(m_version)) + .encode(static_cast<size_t>(m_version.major_version())) + .encode(static_cast<size_t>(m_version.minor_version())) .encode(static_cast<size_t>(m_ciphersuite)) .encode(static_cast<size_t>(m_compression_method)) .encode(static_cast<size_t>(m_connection_side)) .encode(static_cast<size_t>(m_fragment_size)) .encode(m_secure_renegotiation_supported) .encode(m_master_secret, OCTET_STRING) - .encode(m_peer_certificate, OCTET_STRING) + .encode(peer_cert_bits, OCTET_STRING) .encode(ASN1_String(m_sni_hostname, UTF8_STRING)) .encode(ASN1_String(m_srp_identifier, UTF8_STRING)) .end_cons() .get_contents(); } +std::string Session::PEM_encode() const + { + return PEM_Code::encode(this->DER_encode(), "SSL SESSION"); + } + +} + MemoryVector<byte> TLS_Session::encrypt(const SymmetricKey& master_key, const MemoryRegion<byte>& key_name, @@ -180,5 +217,4 @@ TLS_Session TLS_Session::decrypt(const MemoryRegion<byte>& buf, } } - } diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h index 4a8d50e26..40aaee278 100644 --- a/src/tls/tls_session.h +++ b/src/tls/tls_session.h @@ -9,25 +9,29 @@ #define TLS_SESSION_STATE_H__ #include <botan/x509cert.h> +#include <botan/tls_version.h> +#include <botan/tls_ciphersuite.h> #include <botan/tls_magic.h> #include <botan/secmem.h> #include <botan/symkey.h> namespace Botan { +namespace TLS { + /** * Class representing a TLS session state */ -class BOTAN_DLL TLS_Session +class BOTAN_DLL Session { public: /** * Uninitialized session */ - TLS_Session() : + Session() : m_start_time(0), - m_version(0), + m_version(), m_ciphersuite(0), m_compression_method(0), m_connection_side(static_cast<Connection_Side>(0)), @@ -38,22 +42,34 @@ class BOTAN_DLL TLS_Session /** * New session (sets session start time) */ - TLS_Session(const MemoryRegion<byte>& session_id, - const MemoryRegion<byte>& master_secret, - Version_Code version, - u16bit ciphersuite, - byte compression_method, - Connection_Side side, - bool secure_renegotiation_supported, - size_t fragment_size, - const std::vector<X509_Certificate>& peer_certs, - const std::string& sni_hostname = "", - const std::string& srp_identifier = ""); + Session(const MemoryRegion<byte>& session_id, + const MemoryRegion<byte>& master_secret, + Protocol_Version version, + u16bit ciphersuite, + byte compression_method, + Connection_Side side, + bool secure_renegotiation_supported, + size_t fragment_size, + const std::vector<X509_Certificate>& peer_certs, + const std::string& sni_hostname = "", + const std::string& srp_identifier = ""); /** - * Load a session from BER (created by BER_encode) + * Load a session from DER representation (created by DER_encode) + */ + Session(const byte ber[], size_t ber_len); + + /** + * Load a session from PEM representation (created by PEM_encode) + */ + Session(const std::string& pem); + + /** + * Encode this session data for storage + * @warning if the master secret is compromised so is the + * session traffic */ - TLS_Session(const byte ber[], size_t ber_len); + SecureVector<byte> DER_encode() const; /** * Encrypt a session (useful for serialization or session tickets) @@ -71,18 +87,22 @@ class BOTAN_DLL TLS_Session * @warning if the master secret is compromised so is the * session traffic */ - SecureVector<byte> BER_encode() const; + std::string PEM_encode() const; /** * Get the version of the saved session */ - Version_Code version() const - { return static_cast<Version_Code>(m_version); } + Protocol_Version version() const { return m_version; } /** - * Get the ciphersuite of the saved session + * Get the ciphersuite code of the saved session */ - u16bit ciphersuite() const { return m_ciphersuite; } + u16bit ciphersuite_code() const { return m_ciphersuite; } + + /** + * Get the ciphersuite info of the saved session + */ + Ciphersuite ciphersuite() const { return Ciphersuite::by_id(m_ciphersuite); } /** * Get the compression method used in the saved session @@ -129,6 +149,11 @@ class BOTAN_DLL TLS_Session { return m_secure_renegotiation_supported; } /** + * Return the certificate chain of the peer (possibly empty) + */ + std::vector<X509_Certificate> peer_certs() const { return m_peer_certs; } + + /** * Get the time this session began (seconds since Epoch) */ u64bit start_time() const { return m_start_time; } @@ -141,7 +166,7 @@ class BOTAN_DLL TLS_Session MemoryVector<byte> m_identifier; SecureVector<byte> m_master_secret; - u16bit m_version; + Protocol_Version m_version; u16bit m_ciphersuite; byte m_compression_method; Connection_Side m_connection_side; @@ -149,11 +174,13 @@ class BOTAN_DLL TLS_Session bool m_secure_renegotiation_supported; size_t m_fragment_size; - MemoryVector<byte> m_peer_certificate; // optional + std::vector<X509_Certificate> m_peer_certs; std::string m_sni_hostname; // optional std::string m_srp_identifier; // optional }; } +} + #endif diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp index f0ddc4493..edd0617bc 100644 --- a/src/tls/tls_session_key.cpp +++ b/src/tls/tls_session_key.cpp @@ -6,43 +6,28 @@ */ #include <botan/internal/tls_session_key.h> +#include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> #include <botan/lookup.h> #include <memory> namespace Botan { -namespace { - -std::string lookup_prf_name(Version_Code version) - { - if(version == SSL_V3) - return "SSL3-PRF"; - else if(version == TLS_V10 || version == TLS_V11) - return "TLS-PRF"; - else - throw Invalid_Argument("SessionKeys: Unknown version code"); - } - -} +namespace TLS { /** -* SessionKeys Constructor +* Session_Keys Constructor */ -SessionKeys::SessionKeys(const TLS_Cipher_Suite& suite, - Version_Code version, - const MemoryRegion<byte>& pre_master_secret, - const MemoryRegion<byte>& client_random, - const MemoryRegion<byte>& server_random, - bool resuming) +Session_Keys::Session_Keys(Handshake_State* state, + const MemoryRegion<byte>& pre_master_secret, + bool resuming) { - const std::string prf_name = lookup_prf_name(version); - - const size_t mac_keylen = output_length_of(suite.mac_algo()); - const size_t cipher_keylen = suite.cipher_keylen(); + const size_t mac_keylen = output_length_of(state->suite.mac_algo()); + const size_t cipher_keylen = state->suite.cipher_keylen(); size_t cipher_ivlen = 0; - if(have_block_cipher(suite.cipher_algo())) - cipher_ivlen = block_size_of(suite.cipher_algo()); + if(have_block_cipher(state->suite.cipher_algo())) + cipher_ivlen = block_size_of(state->suite.cipher_algo()); const size_t prf_gen = 2 * (mac_keylen + cipher_keylen + cipher_ivlen); @@ -52,7 +37,7 @@ SessionKeys::SessionKeys(const TLS_Cipher_Suite& suite, const byte KEY_GEN_MAGIC[] = { 0x6B, 0x65, 0x79, 0x20, 0x65, 0x78, 0x70, 0x61, 0x6E, 0x73, 0x69, 0x6F, 0x6E }; - std::auto_ptr<KDF> prf(get_kdf(prf_name)); + std::auto_ptr<KDF> prf(state->protocol_specific_prf()); if(resuming) { @@ -62,20 +47,20 @@ SessionKeys::SessionKeys(const TLS_Cipher_Suite& suite, { SecureVector<byte> salt; - if(version != SSL_V3) + if(state->version() != Protocol_Version::SSL_V3) salt += std::make_pair(MASTER_SECRET_MAGIC, sizeof(MASTER_SECRET_MAGIC)); - salt += client_random; - salt += server_random; + salt += state->client_hello->random(); + salt += state->server_hello->random(); master_sec = prf->derive_key(48, pre_master_secret, salt); } SecureVector<byte> salt; - if(version != SSL_V3) + if(state->version() != Protocol_Version::SSL_V3) salt += std::make_pair(KEY_GEN_MAGIC, sizeof(KEY_GEN_MAGIC)); - salt += server_random; - salt += client_random; + salt += state->server_hello->random(); + salt += state->client_hello->random(); SymmetricKey keyblock = prf->derive_key(prf_gen, master_sec, salt); @@ -100,3 +85,5 @@ SessionKeys::SessionKeys(const TLS_Cipher_Suite& suite, } } + +} diff --git a/src/tls/tls_session_key.h b/src/tls/tls_session_key.h index a698dfcfc..25de56aea 100644 --- a/src/tls/tls_session_key.h +++ b/src/tls/tls_session_key.h @@ -8,16 +8,18 @@ #ifndef BOTAN_TLS_SESSION_KEYS_H__ #define BOTAN_TLS_SESSION_KEYS_H__ -#include <botan/tls_suites.h> +#include <botan/tls_ciphersuite.h> #include <botan/tls_exceptn.h> #include <botan/symkey.h> namespace Botan { +namespace TLS { + /** * TLS Session Keys */ -class SessionKeys +class Session_Keys { public: SymmetricKey client_cipher_key() const { return c_cipher; } @@ -31,14 +33,11 @@ class SessionKeys const SecureVector<byte>& master_secret() const { return master_sec; } - SessionKeys() {} + Session_Keys() {} - SessionKeys(const TLS_Cipher_Suite& suite, - Version_Code version, - const MemoryRegion<byte>& pre_master, - const MemoryRegion<byte>& client_random, - const MemoryRegion<byte>& server_random, - bool resuming = false); + Session_Keys(class Handshake_State* state, + const MemoryRegion<byte>& pre_master, + bool resuming); private: SecureVector<byte> master_sec; @@ -48,4 +47,6 @@ class SessionKeys } +} + #endif diff --git a/src/tls/tls_session_manager.cpp b/src/tls/tls_session_manager.cpp index e5ec75c88..59fc75b9f 100644 --- a/src/tls/tls_session_manager.cpp +++ b/src/tls/tls_session_manager.cpp @@ -11,10 +11,12 @@ namespace Botan { -bool TLS_Session_Manager_In_Memory::load_from_session_str( - const std::string& session_str, TLS_Session& session) +namespace TLS { + +bool Session_Manager_In_Memory::load_from_session_str( + const std::string& session_str, Session& session) { - std::map<std::string, TLS_Session>::iterator i = sessions.find(session_str); + std::map<std::string, Session>::iterator i = sessions.find(session_str); if(i == sessions.end()) return false; @@ -31,14 +33,14 @@ bool TLS_Session_Manager_In_Memory::load_from_session_str( return true; } -bool TLS_Session_Manager_In_Memory::load_from_session_id( - const MemoryRegion<byte>& session_id, TLS_Session& session) +bool Session_Manager_In_Memory::load_from_session_id( + const MemoryRegion<byte>& session_id, Session& session) { return load_from_session_str(hex_encode(session_id), session); } -bool TLS_Session_Manager_In_Memory::load_from_host_info( - const std::string& hostname, u16bit port, TLS_Session& session) +bool Session_Manager_In_Memory::load_from_host_info( + const std::string& hostname, u16bit port, Session& session) { std::map<std::string, std::string>::iterator i; @@ -59,17 +61,17 @@ bool TLS_Session_Manager_In_Memory::load_from_host_info( return false; } -void TLS_Session_Manager_In_Memory::remove_entry( +void Session_Manager_In_Memory::remove_entry( const MemoryRegion<byte>& session_id) { - std::map<std::string, TLS_Session>::iterator i = + std::map<std::string, Session>::iterator i = sessions.find(hex_encode(session_id)); if(i != sessions.end()) sessions.erase(i); } -void TLS_Session_Manager_In_Memory::save(const TLS_Session& session) +void Session_Manager_In_Memory::save(const Session& session) { if(max_sessions != 0) { @@ -90,3 +92,5 @@ void TLS_Session_Manager_In_Memory::save(const TLS_Session& session) } } + +} diff --git a/src/tls/tls_session_manager.h b/src/tls/tls_session_manager.h index 289b76a3b..c0a9996e3 100644 --- a/src/tls/tls_session_manager.h +++ b/src/tls/tls_session_manager.h @@ -13,8 +13,10 @@ namespace Botan { +namespace TLS { + /** -* TLS_Session_Manager is an interface to systems which can save +* Session_Manager is an interface to systems which can save * session parameters for supporting session resumption. * * Saving sessions is done on a best-effort basis; an implementation is @@ -22,7 +24,7 @@ namespace Botan { * * Implementations should strive to be thread safe */ -class BOTAN_DLL TLS_Session_Manager +class BOTAN_DLL Session_Manager { public: /** @@ -33,7 +35,7 @@ class BOTAN_DLL TLS_Session_Manager * @return true if session was modified */ virtual bool load_from_session_id(const MemoryRegion<byte>& session_id, - TLS_Session& session) = 0; + Session& session) = 0; /** * Try to load a saved session (client side) @@ -44,7 +46,7 @@ class BOTAN_DLL TLS_Session_Manager * @return true if session was modified */ virtual bool load_from_host_info(const std::string& hostname, u16bit port, - TLS_Session& session) = 0; + Session& session) = 0; /** * Remove this session id from the cache, if it exists @@ -59,18 +61,18 @@ class BOTAN_DLL TLS_Session_Manager * * @param session to save */ - virtual void save(const TLS_Session& session) = 0; + virtual void save(const Session& session) = 0; - virtual ~TLS_Session_Manager() {} + virtual ~Session_Manager() {} }; /** -* A simple implementation of TLS_Session_Manager that just saves +* A simple implementation of Session_Manager that just saves * values in memory, with no persistance abilities * * @todo add locking */ -class BOTAN_DLL TLS_Session_Manager_In_Memory : public TLS_Session_Manager +class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager { public: /** @@ -79,32 +81,34 @@ class BOTAN_DLL TLS_Session_Manager_In_Memory : public TLS_Session_Manager * @param session_lifetime sessions are expired after this many * seconds have elapsed from initial handshake. */ - TLS_Session_Manager_In_Memory(size_t max_sessions = 1000, - size_t session_lifetime = 7200) : + Session_Manager_In_Memory(size_t max_sessions = 1000, + size_t session_lifetime = 7200) : max_sessions(max_sessions), session_lifetime(session_lifetime) {} bool load_from_session_id(const MemoryRegion<byte>& session_id, - TLS_Session& session); + Session& session); bool load_from_host_info(const std::string& hostname, u16bit port, - TLS_Session& session); + Session& session); void remove_entry(const MemoryRegion<byte>& session_id); - void save(const TLS_Session& session_data); + void save(const Session& session_data); private: bool load_from_session_str(const std::string& session_str, - TLS_Session& session); + Session& session); size_t max_sessions, session_lifetime; - std::map<std::string, TLS_Session> sessions; // hex(session_id) -> session + std::map<std::string, Session> sessions; // hex(session_id) -> session std::map<std::string, std::string> host_sessions; }; } +} + #endif diff --git a/src/tls/tls_suites.cpp b/src/tls/tls_suites.cpp deleted file mode 100644 index aff15d68f..000000000 --- a/src/tls/tls_suites.cpp +++ /dev/null @@ -1,325 +0,0 @@ -/* -* TLS Cipher Suites -* (C) 2004-2010 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#include <botan/tls_suites.h> -#include <botan/tls_exceptn.h> - -namespace Botan { - -/** -* Convert an SSL/TLS ciphersuite to algorithm fields -*/ -TLS_Ciphersuite_Algos TLS_Cipher_Suite::lookup_ciphersuite(u16bit suite) - { - if(suite == TLS_RSA_WITH_RC4_128_MD5) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_MD5 | - TLS_ALGO_CIPHER_RC4_128); - - if(suite == TLS_RSA_WITH_RC4_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_RC4_128); - - if(suite == TLS_RSA_WITH_3DES_EDE_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_RSA_WITH_AES_128_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_RSA_WITH_AES_256_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_RSA_WITH_SEED_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_SEED_CBC); - - if(suite == TLS_RSA_WITH_AES_128_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_RSA_WITH_AES_256_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_NOKEX | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_DHE_DSS_WITH_AES_128_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_DHE_DSS_WITH_SEED_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_SEED_CBC); - - if(suite == TLS_DHE_DSS_WITH_RC4_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_RC4_128); - - if(suite == TLS_DHE_DSS_WITH_AES_256_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_DHE_DSS_WITH_AES_128_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_DHE_DSS_WITH_AES_256_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_DHE_RSA_WITH_AES_128_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_DHE_DSS_WITH_SEED_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_SEED_CBC); - - if(suite == TLS_DHE_RSA_WITH_AES_256_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_DHE_RSA_WITH_AES_128_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_DHE_RSA_WITH_AES_256_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_DH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES256_CBC); - - // SRP ciphersuites - if(suite == TLS_SRP_SHA_RSA_WITH_3DES_EDE_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_SRP_SHA_DSS_WITH_3DES_EDE_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_SRP_SHA_RSA_WITH_AES_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_SRP_SHA_DSS_WITH_AES_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_SRP_SHA_RSA_WITH_AES_256_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_SRP_SHA_DSS_WITH_AES_256_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_DSA | - TLS_ALGO_KEYEXCH_SRP | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - // ECC ciphersuites - if(suite == TLS_ECDHE_ECDSA_WITH_RC4_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_RC4_128); - - if(suite == TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA384 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_ECDHE_RSA_WITH_RC4_128_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_RC4_128); - - if(suite == TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_3DES_CBC); - - if(suite == TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_RSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA1 | - TLS_ALGO_CIPHER_AES256_CBC); - - if(suite == TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA256 | - TLS_ALGO_CIPHER_AES128_CBC); - - if(suite == TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384) - return TLS_Ciphersuite_Algos(TLS_ALGO_SIGNER_ECDSA | - TLS_ALGO_KEYEXCH_ECDH | - TLS_ALGO_MAC_SHA384 | - TLS_ALGO_CIPHER_AES256_CBC); - - return TLS_Ciphersuite_Algos(0); - } - -namespace { - -std::pair<std::string, size_t> cipher_code_to_name(TLS_Ciphersuite_Algos algo) - { - if((algo & TLS_ALGO_CIPHER_MASK) == TLS_ALGO_CIPHER_RC4_128) - return std::make_pair("ARC4", 16); - - if((algo & TLS_ALGO_CIPHER_MASK) == TLS_ALGO_CIPHER_3DES_CBC) - return std::make_pair("3DES", 24); - - if((algo & TLS_ALGO_CIPHER_MASK) == TLS_ALGO_CIPHER_AES128_CBC) - return std::make_pair("AES-128", 16); - - if((algo & TLS_ALGO_CIPHER_MASK) == TLS_ALGO_CIPHER_AES256_CBC) - return std::make_pair("AES-256", 32); - - if((algo & TLS_ALGO_CIPHER_MASK) == TLS_ALGO_CIPHER_SEED_CBC) - return std::make_pair("SEED", 16); - - throw TLS_Exception(INTERNAL_ERROR, - "TLS_Cipher_Suite: Unknown cipher type " + to_string(algo)); - } - -std::string mac_code_to_name(TLS_Ciphersuite_Algos algo) - { - if((algo & TLS_ALGO_MAC_MASK) == TLS_ALGO_MAC_MD5) - return "MD5"; - - if((algo & TLS_ALGO_MAC_MASK) == TLS_ALGO_MAC_SHA1) - return "SHA-1"; - - if((algo & TLS_ALGO_MAC_MASK) == TLS_ALGO_MAC_SHA256) - return "SHA-256"; - - if((algo & TLS_ALGO_MAC_MASK) == TLS_ALGO_MAC_SHA384) - return "SHA-384"; - - throw TLS_Exception(INTERNAL_ERROR, - "TLS_Cipher_Suite: Unknown MAC type " + to_string(algo)); - } - -} - -/** -* TLS_Cipher_Suite Constructor -*/ -TLS_Cipher_Suite::TLS_Cipher_Suite(u16bit suite_code) - { - if(suite_code == 0) - return; - - TLS_Ciphersuite_Algos algos = lookup_ciphersuite(suite_code); - - if(algos == 0) - throw Invalid_Argument("Unknown ciphersuite: " + to_string(suite_code)); - - sig_algo = TLS_Ciphersuite_Algos(algos & TLS_ALGO_SIGNER_MASK); - - kex_algo = TLS_Ciphersuite_Algos(algos & TLS_ALGO_KEYEXCH_MASK); - - std::pair<std::string, size_t> cipher_info = cipher_code_to_name(algos); - - cipher = cipher_info.first; - cipher_key_length = cipher_info.second; - - mac = mac_code_to_name(algos); - } - -} diff --git a/src/tls/tls_suites.h b/src/tls/tls_suites.h deleted file mode 100644 index 3256dc198..000000000 --- a/src/tls/tls_suites.h +++ /dev/null @@ -1,42 +0,0 @@ -/* -* TLS Cipher Suites -* (C) 2004-2011 Jack Lloyd -* -* Released under the terms of the Botan license -*/ - -#ifndef BOTAN_TLS_CIPHER_SUITES_H__ -#define BOTAN_TLS_CIPHER_SUITES_H__ - -#include <botan/types.h> -#include <botan/tls_magic.h> -#include <string> - -namespace Botan { - -/** -* Ciphersuite Information -*/ -class BOTAN_DLL TLS_Cipher_Suite - { - public: - static TLS_Ciphersuite_Algos lookup_ciphersuite(u16bit suite); - - std::string cipher_algo() const { return cipher; } - std::string mac_algo() const { return mac; } - - size_t cipher_keylen() const { return cipher_key_length; } - - TLS_Ciphersuite_Algos kex_type() const { return kex_algo; } - TLS_Ciphersuite_Algos sig_type() const { return sig_algo; } - - TLS_Cipher_Suite(u16bit ciphersuite_code = 0); - private: - TLS_Ciphersuite_Algos kex_algo, sig_algo; - std::string cipher, mac; - size_t cipher_key_length; - }; - -} - -#endif diff --git a/src/tls/tls_version.cpp b/src/tls/tls_version.cpp new file mode 100644 index 000000000..4445998eb --- /dev/null +++ b/src/tls/tls_version.cpp @@ -0,0 +1,33 @@ +/* +* TLS Protocol Version Management +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_version.h> +#include <botan/parsing.h> + +namespace Botan { + +namespace TLS { + +std::string Protocol_Version::to_string() const + { + const byte maj = major_version(); + const byte min = minor_version(); + + // Some very new or very old protocol? + if(maj != 3) + return "Protocol " + Botan::to_string(maj) + "." + Botan::to_string(min); + + if(maj == 3 && min == 0) + return "SSL v3"; + + // The TLS v1.[0123...] case + return "TLS v1." + Botan::to_string(min-1); + } + +} + +} diff --git a/src/tls/tls_version.h b/src/tls/tls_version.h new file mode 100644 index 000000000..aa689b300 --- /dev/null +++ b/src/tls/tls_version.h @@ -0,0 +1,87 @@ +/* +* TLS Protocol Version Management +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_PROTOCOL_VERSION_H__ +#define BOTAN_TLS_PROTOCOL_VERSION_H__ + +#include <botan/get_byte.h> +#include <string> + +namespace Botan { + +namespace TLS { + +class BOTAN_DLL Protocol_Version + { + public: + enum Version_Code { + SSL_V3 = 0x0300, + TLS_V10 = 0x0301, + TLS_V11 = 0x0302, + TLS_V12 = 0x0303 + }; + + Protocol_Version() : m_version(0) {} + + Protocol_Version(Version_Code named_version) : + m_version(static_cast<u16bit>(named_version)) {} + + Protocol_Version(byte major, byte minor) : + m_version((static_cast<u16bit>(major) << 8) | minor) {} + + /** + * Get the major version of the protocol version + */ + byte major_version() const { return get_byte(0, m_version); } + + /** + * Get the minor version of the protocol version + */ + byte minor_version() const { return get_byte(1, m_version); } + + bool operator==(const Protocol_Version& other) const + { + return (m_version == other.m_version); + } + + bool operator!=(const Protocol_Version& other) const + { + return (m_version != other.m_version); + } + + bool operator>=(const Protocol_Version& other) const + { + return (m_version >= other.m_version); + } + + bool operator>(const Protocol_Version& other) const + { + return (m_version > other.m_version); + } + + bool operator<=(const Protocol_Version& other) const + { + return (m_version <= other.m_version); + } + + bool operator<(const Protocol_Version& other) const + { + return (m_version < other.m_version); + } + + std::string to_string() const; + + private: + u16bit m_version; + }; + +} + +} + +#endif + |