diff options
Diffstat (limited to 'lib/tls/tls_handshake_state.cpp')
-rw-r--r-- | lib/tls/tls_handshake_state.cpp | 442 |
1 files changed, 442 insertions, 0 deletions
diff --git a/lib/tls/tls_handshake_state.cpp b/lib/tls/tls_handshake_state.cpp new file mode 100644 index 000000000..84b22cc09 --- /dev/null +++ b/lib/tls/tls_handshake_state.cpp @@ -0,0 +1,442 @@ +/* +* TLS Handshaking +* (C) 2004-2006,2011,2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_handshake_state.h> +#include <botan/internal/tls_messages.h> +#include <botan/internal/tls_record.h> +#include <botan/lookup.h> + +namespace Botan { + +namespace TLS { + +namespace { + +u32bit bitmask_for_handshake_type(Handshake_Type type) + { + switch(type) + { + case HELLO_VERIFY_REQUEST: + return (1 << 0); + + case HELLO_REQUEST: + return (1 << 1); + + /* + * Same code point for both client hello styles + */ + case CLIENT_HELLO: + case CLIENT_HELLO_SSLV2: + return (1 << 2); + + case SERVER_HELLO: + return (1 << 3); + + case CERTIFICATE: + return (1 << 4); + + case CERTIFICATE_URL: + return (1 << 5); + + case CERTIFICATE_STATUS: + return (1 << 6); + + case SERVER_KEX: + return (1 << 7); + + case CERTIFICATE_REQUEST: + return (1 << 8); + + case SERVER_HELLO_DONE: + return (1 << 9); + + case CERTIFICATE_VERIFY: + return (1 << 10); + + case CLIENT_KEX: + return (1 << 11); + + case NEXT_PROTOCOL: + return (1 << 12); + + case NEW_SESSION_TICKET: + return (1 << 13); + + case HANDSHAKE_CCS: + return (1 << 14); + + case FINISHED: + return (1 << 15); + + // allow explicitly disabling new handshakes + case HANDSHAKE_NONE: + return 0; + } + + throw Internal_Error("Unknown handshake type " + std::to_string(type)); + } + +} + +/* +* Initialize the SSL/TLS Handshake State +*/ +Handshake_State::Handshake_State(Handshake_IO* io, + std::function<void (const Handshake_Message&)> msg_callback) : + m_msg_callback(msg_callback), + m_handshake_io(io), + m_version(m_handshake_io->initial_record_version()) + { + } + +Handshake_State::~Handshake_State() {} + +void Handshake_State::hello_verify_request(const Hello_Verify_Request& hello_verify) + { + note_message(hello_verify); + + m_client_hello->update_hello_cookie(hello_verify); + hash().reset(); + hash().update(handshake_io().send(*m_client_hello)); + note_message(*m_client_hello); + } + +void Handshake_State::client_hello(Client_Hello* client_hello) + { + m_client_hello.reset(client_hello); + note_message(*m_client_hello); + } + +void Handshake_State::server_hello(Server_Hello* server_hello) + { + m_server_hello.reset(server_hello); + m_ciphersuite = Ciphersuite::by_id(m_server_hello->ciphersuite()); + note_message(*m_server_hello); + } + +void Handshake_State::server_certs(Certificate* server_certs) + { + m_server_certs.reset(server_certs); + note_message(*m_server_certs); + } + +void Handshake_State::server_kex(Server_Key_Exchange* server_kex) + { + m_server_kex.reset(server_kex); + note_message(*m_server_kex); + } + +void Handshake_State::cert_req(Certificate_Req* cert_req) + { + m_cert_req.reset(cert_req); + note_message(*m_cert_req); + } + +void Handshake_State::server_hello_done(Server_Hello_Done* server_hello_done) + { + m_server_hello_done.reset(server_hello_done); + note_message(*m_server_hello_done); + } + +void Handshake_State::client_certs(Certificate* client_certs) + { + m_client_certs.reset(client_certs); + note_message(*m_client_certs); + } + +void Handshake_State::client_kex(Client_Key_Exchange* client_kex) + { + m_client_kex.reset(client_kex); + note_message(*m_client_kex); + } + +void Handshake_State::client_verify(Certificate_Verify* client_verify) + { + m_client_verify.reset(client_verify); + note_message(*m_client_verify); + } + +void Handshake_State::next_protocol(Next_Protocol* next_protocol) + { + m_next_protocol.reset(next_protocol); + note_message(*m_next_protocol); + } + +void Handshake_State::new_session_ticket(New_Session_Ticket* new_session_ticket) + { + m_new_session_ticket.reset(new_session_ticket); + note_message(*m_new_session_ticket); + } + +void Handshake_State::server_finished(Finished* server_finished) + { + m_server_finished.reset(server_finished); + note_message(*m_server_finished); + } + +void Handshake_State::client_finished(Finished* client_finished) + { + m_client_finished.reset(client_finished); + note_message(*m_client_finished); + } + +void Handshake_State::set_version(const Protocol_Version& version) + { + m_version = version; + } + +void Handshake_State::compute_session_keys() + { + m_session_keys = Session_Keys(this, client_kex()->pre_master_secret(), false); + } + +void Handshake_State::compute_session_keys(const secure_vector<byte>& resume_master_secret) + { + m_session_keys = Session_Keys(this, resume_master_secret, true); + } + +void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) + { + const u32bit mask = bitmask_for_handshake_type(handshake_msg); + + m_hand_received_mask |= mask; + + const bool ok = (m_hand_expecting_mask & mask); // overlap? + + if(!ok) + throw Unexpected_Message("Unexpected state transition in handshake, got " + + std::to_string(handshake_msg) + + " expected " + std::to_string(m_hand_expecting_mask) + + " received " + std::to_string(m_hand_received_mask)); + + /* We don't know what to expect next, so force a call to + set_expected_next; if it doesn't happen, the next transition + check will always fail which is what we want. + */ + m_hand_expecting_mask = 0; + } + +void Handshake_State::set_expected_next(Handshake_Type handshake_msg) + { + m_hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg); + } + +bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const + { + const u32bit mask = bitmask_for_handshake_type(handshake_msg); + + return (m_hand_received_mask & mask); + } + +std::pair<Handshake_Type, std::vector<byte>> +Handshake_State::get_next_handshake_msg() + { + const bool expecting_ccs = + (bitmask_for_handshake_type(HANDSHAKE_CCS) & m_hand_expecting_mask); + + return m_handshake_io->get_next_record(expecting_ccs); + } + +std::string Handshake_State::srp_identifier() const + { + if(ciphersuite().valid() && ciphersuite().kex_algo() == "SRP_SHA") + return client_hello()->srp_identifier(); + + return ""; + } + +std::vector<byte> Handshake_State::session_ticket() const + { + if(new_session_ticket() && !new_session_ticket()->ticket().empty()) + return new_session_ticket()->ticket(); + + return client_hello()->session_ticket(); + } + +KDF* Handshake_State::protocol_specific_prf() const + { + if(version() == Protocol_Version::SSL_V3) + { + return get_kdf("SSL3-PRF"); + } + else if(version().supports_ciphersuite_specific_prf()) + { + const std::string prf_algo = ciphersuite().prf_algo(); + + if(prf_algo == "MD5" || prf_algo == "SHA-1") + return get_kdf("TLS-12-PRF(SHA-256)"); + + return get_kdf("TLS-12-PRF(" + prf_algo + ")"); + } + else + { + // TLS v1.0, v1.1 and DTLS v1.0 + return get_kdf("TLS-PRF"); + } + + throw Internal_Error("Unknown version code " + version().to_string()); + } + +namespace { + +std::string choose_hash(const std::string& sig_algo, + Protocol_Version negotiated_version, + const Policy& policy, + bool for_client_auth, + const Client_Hello* client_hello, + const Certificate_Req* cert_req) + { + if(!negotiated_version.supports_negotiable_signature_algorithms()) + { + if(for_client_auth && negotiated_version == Protocol_Version::SSL_V3) + return "Raw"; + + if(sig_algo == "RSA") + return "Parallel(MD5,SHA-160)"; + + if(sig_algo == "DSA") + return "SHA-1"; + + if(sig_algo == "ECDSA") + return "SHA-1"; + + throw Internal_Error("Unknown TLS signature algo " + sig_algo); + } + + const auto supported_algos = for_client_auth ? + cert_req->supported_algos() : + client_hello->supported_algos(); + + if(!supported_algos.empty()) + { + const auto hashes = policy.allowed_signature_hashes(); + + /* + * Choose our most preferred hash that the counterparty supports + * in pairing with the signature algorithm we want to use. + */ + for(auto hash : hashes) + { + for(auto algo : supported_algos) + { + if(algo.first == hash && algo.second == sig_algo) + return hash; + } + } + } + + // TLS v1.2 default hash if the counterparty sent nothing + return "SHA-1"; + } + +} + +std::pair<std::string, Signature_Format> +Handshake_State::choose_sig_format(const Private_Key& key, + std::string& hash_algo_out, + std::string& sig_algo_out, + bool for_client_auth, + const Policy& policy) const + { + const std::string sig_algo = key.algo_name(); + + const std::string hash_algo = + choose_hash(sig_algo, + this->version(), + policy, + for_client_auth, + client_hello(), + cert_req()); + + if(this->version().supports_negotiable_signature_algorithms()) + { + hash_algo_out = hash_algo; + sig_algo_out = sig_algo; + } + + if(sig_algo == "RSA") + { + const std::string padding = "EMSA3(" + hash_algo + ")"; + + return std::make_pair(padding, IEEE_1363); + } + else if(sig_algo == "DSA" || sig_algo == "ECDSA") + { + const std::string padding = "EMSA1(" + hash_algo + ")"; + + return std::make_pair(padding, DER_SEQUENCE); + } + + throw Invalid_Argument(sig_algo + " is invalid/unknown for TLS signatures"); + } + +std::pair<std::string, Signature_Format> +Handshake_State::understand_sig_format(const Public_Key& key, + std::string hash_algo, + std::string sig_algo, + bool for_client_auth) const + { + const std::string algo_name = key.algo_name(); + + /* + FIXME: This should check what was sent against the client hello + preferences, or the certificate request, to ensure it was allowed + by those restrictions. + + Or not? + */ + + if(this->version().supports_negotiable_signature_algorithms()) + { + if(hash_algo == "") + throw Decoding_Error("Counterparty did not send hash/sig IDS"); + + if(sig_algo != algo_name) + throw Decoding_Error("Counterparty sent inconsistent key and sig types"); + } + else + { + if(hash_algo != "" || sig_algo != "") + throw Decoding_Error("Counterparty sent hash/sig IDs with old version"); + } + + if(algo_name == "RSA") + { + if(for_client_auth && this->version() == Protocol_Version::SSL_V3) + { + hash_algo = "Raw"; + } + else if(!this->version().supports_negotiable_signature_algorithms()) + { + hash_algo = "Parallel(MD5,SHA-160)"; + } + + const std::string padding = "EMSA3(" + hash_algo + ")"; + return std::make_pair(padding, IEEE_1363); + } + else if(algo_name == "DSA" || algo_name == "ECDSA") + { + if(algo_name == "DSA" && for_client_auth && this->version() == Protocol_Version::SSL_V3) + { + hash_algo = "Raw"; + } + else if(!this->version().supports_negotiable_signature_algorithms()) + { + hash_algo = "SHA-1"; + } + + const std::string padding = "EMSA1(" + hash_algo + ")"; + + return std::make_pair(padding, DER_SEQUENCE); + } + + throw Invalid_Argument(algo_name + " is invalid/unknown for TLS signatures"); + } + +} + +} |