diff options
author | lloyd <[email protected]> | 2012-09-06 17:51:06 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-06 17:51:06 +0000 |
commit | e5ab9f9ecf9f76d132815e7e89814727c81d7294 (patch) | |
tree | 46b8091307cc51f1d7bbbe8d526d272c7b7f7d2d | |
parent | e0029df2f95364f21a9538b493e24661e54efa21 (diff) |
Pass process_handshake_msg a reference to the Handshake_State
-rw-r--r-- | src/tls/msg_cert_verify.cpp | 26 | ||||
-rw-r--r-- | src/tls/msg_client_kex.cpp | 38 | ||||
-rw-r--r-- | src/tls/msg_finished.cpp | 20 | ||||
-rw-r--r-- | src/tls/msg_server_kex.cpp | 28 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 5 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 11 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 228 | ||||
-rw-r--r-- | src/tls/tls_client.h | 3 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 16 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 235 | ||||
-rw-r--r-- | src/tls/tls_server.h | 4 |
11 files changed, 307 insertions, 307 deletions
diff --git a/src/tls/msg_cert_verify.cpp b/src/tls/msg_cert_verify.cpp index 2d283edca..18d851e53 100644 --- a/src/tls/msg_cert_verify.cpp +++ b/src/tls/msg_cert_verify.cpp @@ -20,7 +20,7 @@ namespace TLS { * Create a new Certificate Verify message */ Certificate_Verify::Certificate_Verify(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, RandomNumberGenerator& rng, const Private_Key* priv_key) @@ -28,14 +28,14 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io, BOTAN_ASSERT_NONNULL(priv_key); std::pair<std::string, Signature_Format> format = - state->choose_sig_format(priv_key, m_hash_algo, m_sig_algo, true, policy); + state.choose_sig_format(priv_key, m_hash_algo, m_sig_algo, true, policy); PK_Signer signer(*priv_key, format.first, format.second); - if(state->version() == Protocol_Version::SSL_V3) + if(state.version() == Protocol_Version::SSL_V3) { - secure_vector<byte> md5_sha = state->hash().final_ssl3( - state->session_keys().master_secret()); + secure_vector<byte> md5_sha = state.hash().final_ssl3( + state.session_keys().master_secret()); if(priv_key->algo_name() == "DSA") m_signature = signer.sign_message(&md5_sha[16], md5_sha.size()-16, rng); @@ -44,10 +44,10 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io, } else { - m_signature = signer.sign_message(state->hash().get_contents(), rng); + m_signature = signer.sign_message(state.hash().get_contents(), rng); } - state->hash().update(io.send(*this)); + state.hash().update(io.send(*this)); } /* @@ -92,25 +92,25 @@ std::vector<byte> Certificate_Verify::serialize() const * Verify a Certificate Verify message */ bool Certificate_Verify::verify(const X509_Certificate& cert, - const Handshake_State* state) const + const Handshake_State& state) const { std::unique_ptr<Public_Key> key(cert.subject_public_key()); std::pair<std::string, Signature_Format> format = - state->understand_sig_format(key.get(), m_hash_algo, m_sig_algo, true); + state.understand_sig_format(key.get(), m_hash_algo, m_sig_algo, true); PK_Verifier verifier(*key, format.first, format.second); - if(state->version() == Protocol_Version::SSL_V3) + if(state.version() == Protocol_Version::SSL_V3) { - secure_vector<byte> md5_sha = state->hash().final_ssl3( - state->session_keys().master_secret()); + secure_vector<byte> md5_sha = state.hash().final_ssl3( + state.session_keys().master_secret()); return verifier.verify_message(&md5_sha[16], md5_sha.size()-16, &m_signature[0], m_signature.size()); } - return verifier.verify_message(state->hash().get_contents(), m_signature); + return verifier.verify_message(state.hash().get_contents(), m_signature); } } diff --git a/src/tls/msg_client_kex.cpp b/src/tls/msg_client_kex.cpp index 22cad4e5c..d129969a9 100644 --- a/src/tls/msg_client_kex.cpp +++ b/src/tls/msg_client_kex.cpp @@ -48,26 +48,26 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input) * Create a new Client Key Exchange message */ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, Credentials_Manager& creds, const std::vector<X509_Certificate>& peer_certs, const std::string& hostname, RandomNumberGenerator& rng) { - const std::string kex_algo = state->ciphersuite().kex_algo(); + const std::string kex_algo = state.ciphersuite().kex_algo(); if(kex_algo == "PSK") { std::string identity_hint = ""; - if(state->server_kex()) + if(state.server_kex()) { - TLS_Data_Reader reader(state->server_kex()->params()); + TLS_Data_Reader reader(state.server_kex()->params()); identity_hint = reader.get_string(2, 0, 65535); } - const std::string hostname = state->client_hello()->sni_hostname(); + const std::string hostname = state.client_hello()->sni_hostname(); const std::string psk_identity = creds.psk_identity("tls-client", hostname, @@ -82,9 +82,9 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, append_tls_length_value(m_pre_master, zeros, 2); append_tls_length_value(m_pre_master, psk.bits_of(), 2); } - else if(state->server_kex()) + else if(state.server_kex()) { - TLS_Data_Reader reader(state->server_kex()->params()); + TLS_Data_Reader reader(state.server_kex()->params()); SymmetricKey psk; @@ -92,7 +92,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, { std::string identity_hint = reader.get_string(2, 0, 65535); - const std::string hostname = state->client_hello()->sni_hostname(); + const std::string hostname = state.client_hello()->sni_hostname(); const std::string psk_identity = creds.psk_identity("tls-client", hostname, @@ -239,7 +239,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key.get())) { - const Protocol_Version offered_version = state->client_hello()->version(); + const Protocol_Version offered_version = state.client_hello()->version(); m_pre_master = rng.random_vec(48); m_pre_master[0] = offered_version.major_version(); @@ -249,7 +249,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, std::vector<byte> encrypted_key = encryptor.encrypt(m_pre_master, rng); - if(state->version() == Protocol_Version::SSL_V3) + if(state.version() == Protocol_Version::SSL_V3) m_key_material = encrypted_key; // no length field else append_tls_length_value(m_key_material, encrypted_key, 2); @@ -260,24 +260,24 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io, pub_key->algo_name()); } - state->hash().update(io.send(*this)); + state.hash().update(io.send(*this)); } /* * Read a Client Key Exchange message */ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents, - const Handshake_State* state, + const Handshake_State& state, const Private_Key* server_rsa_kex_key, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng) { - const std::string kex_algo = state->ciphersuite().kex_algo(); + const std::string kex_algo = state.ciphersuite().kex_algo(); if(kex_algo == "RSA") { - BOTAN_ASSERT(state->server_certs() && !state->server_certs()->cert_chain().empty(), + BOTAN_ASSERT(state.server_certs() && !state.server_certs()->cert_chain().empty(), "RSA key exchange negotiated so server sent a certificate"); if(!server_rsa_kex_key) @@ -288,11 +288,11 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents, PK_Decryptor_EME decryptor(*server_rsa_kex_key, "PKCS1v15"); - Protocol_Version client_version = state->client_hello()->version(); + Protocol_Version client_version = state.client_hello()->version(); try { - if(state->version() == Protocol_Version::SSL_V3) + if(state.version() == Protocol_Version::SSL_V3) { m_pre_master = decryptor.decrypt(contents); } @@ -328,7 +328,7 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents, const std::string psk_identity = reader.get_string(2, 0, 65535); psk = creds.psk("tls-server", - state->client_hello()->sni_hostname(), + state.client_hello()->sni_hostname(), psk_identity); if(psk.length() == 0) @@ -349,14 +349,14 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents, } else if(kex_algo == "SRP_SHA") { - SRP6_Server_Session& srp = state->server_kex()->server_srp_params(); + SRP6_Server_Session& srp = state.server_kex()->server_srp_params(); m_pre_master = srp.step2(BigInt::decode(reader.get_range<byte>(2, 0, 65535))).bits_of(); } else if(kex_algo == "DH" || kex_algo == "DHE_PSK" || kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") { - const Private_Key& private_key = state->server_kex()->server_kex_key(); + const Private_Key& private_key = state.server_kex()->server_kex_key(); const PK_Key_Agreement_Key* ka_key = dynamic_cast<const PK_Key_Agreement_Key*>(&private_key); diff --git a/src/tls/msg_finished.cpp b/src/tls/msg_finished.cpp index 390f05300..059ed8363 100644 --- a/src/tls/msg_finished.cpp +++ b/src/tls/msg_finished.cpp @@ -18,15 +18,15 @@ namespace { /* * Compute the verify_data */ -std::vector<byte> finished_compute_verify(const Handshake_State* state, +std::vector<byte> finished_compute_verify(const Handshake_State& state, Connection_Side side) { - if(state->version() == Protocol_Version::SSL_V3) + if(state.version() == Protocol_Version::SSL_V3) { const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 }; const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 }; - Handshake_Hash hash = state->hash(); // don't modify state + Handshake_Hash hash = state.hash(); // don't modify state std::vector<byte> ssl3_finished; @@ -35,7 +35,7 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state, else hash.update(SSL_SERVER_LABEL, sizeof(SSL_SERVER_LABEL)); - return unlock(hash.final_ssl3(state->session_keys().master_secret())); + return unlock(hash.final_ssl3(state.session_keys().master_secret())); } else { @@ -47,7 +47,7 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x66, 0x69, 0x6E, 0x69, 0x73, 0x68, 0x65, 0x64 }; - std::unique_ptr<KDF> prf(state->protocol_specific_prf()); + std::unique_ptr<KDF> prf(state.protocol_specific_prf()); std::vector<byte> input; if(side == CLIENT) @@ -55,9 +55,9 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state, else input += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL)); - input += state->hash().final(state->version(), state->ciphersuite().mac_algo()); + input += state.hash().final(state.version(), state.ciphersuite().mac_algo()); - return unlock(prf->derive_key(12, state->session_keys().master_secret(), input)); + return unlock(prf->derive_key(12, state.session_keys().master_secret(), input)); } } @@ -67,11 +67,11 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state, * Create a new Finished message */ Finished::Finished(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, Connection_Side side) { m_verification_data = finished_compute_verify(state, side); - state->hash().update(io.send(*this)); + state.hash().update(io.send(*this)); } /* @@ -93,7 +93,7 @@ Finished::Finished(const std::vector<byte>& buf) /* * Verify a Finished message */ -bool Finished::verify(const Handshake_State* state, +bool Finished::verify(const Handshake_State& state, Connection_Side side) const { return (m_verification_data == finished_compute_verify(state, side)); diff --git a/src/tls/msg_server_kex.cpp b/src/tls/msg_server_kex.cpp index b3c4e9017..250a8c126 100644 --- a/src/tls/msg_server_kex.cpp +++ b/src/tls/msg_server_kex.cpp @@ -28,14 +28,14 @@ namespace TLS { * Create a new Server Key Exchange message */ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, Credentials_Manager& creds, RandomNumberGenerator& rng, const Private_Key* signing_key) { - const std::string hostname = state->client_hello()->sni_hostname(); - const std::string kex_algo = state->ciphersuite().kex_algo(); + const std::string hostname = state.client_hello()->sni_hostname(); + const std::string kex_algo = state.ciphersuite().kex_algo(); if(kex_algo == "PSK" || kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK") { @@ -57,7 +57,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, else if(kex_algo == "ECDH" || kex_algo == "ECDHE_PSK") { const std::vector<std::string>& curves = - state->client_hello()->supported_ecc_curves(); + state.client_hello()->supported_ecc_curves(); if(curves.empty()) throw Internal_Error("Client sent no ECC extension but we negotiated ECDH"); @@ -90,7 +90,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, } else if(kex_algo == "SRP_SHA") { - const std::string srp_identifier = state->client_hello()->srp_identifier(); + const std::string srp_identifier = state.client_hello()->srp_identifier(); std::string group_id; BigInt v; @@ -120,22 +120,22 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io, else if(kex_algo != "PSK") throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo); - if(state->ciphersuite().sig_algo() != "") + if(state.ciphersuite().sig_algo() != "") { BOTAN_ASSERT(signing_key, "Signing key was set"); std::pair<std::string, Signature_Format> format = - state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false, policy); + state.choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false, policy); PK_Signer signer(*signing_key, format.first, format.second); - signer.update(state->client_hello()->random()); - signer.update(state->server_hello()->random()); + signer.update(state.client_hello()->random()); + signer.update(state.server_hello()->random()); signer.update(params()); m_signature = signer.signature(rng); } - state->hash().update(io.send(*this)); + state.hash().update(io.send(*this)); } /** @@ -255,17 +255,17 @@ std::vector<byte> Server_Key_Exchange::serialize() const * Verify a Server Key Exchange message */ bool Server_Key_Exchange::verify(const X509_Certificate& cert, - const Handshake_State* state) const + const Handshake_State& state) const { std::unique_ptr<Public_Key> key(cert.subject_public_key()); std::pair<std::string, Signature_Format> format = - state->understand_sig_format(key.get(), m_hash_algo, m_sig_algo, false); + state.understand_sig_format(key.get(), m_hash_algo, m_sig_algo, false); PK_Verifier verifier(*key, format.first, format.second); - verifier.update(state->client_hello()->random()); - verifier.update(state->server_hello()->random()); + verifier.update(state.client_hello()->random()); + verifier.update(state.server_hello()->random()); verifier.update(params()); return verifier.check_signature(m_signature); diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 024aec099..a2260e448 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -175,7 +175,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(msg.first == HANDSHAKE_NONE) // no full handshake yet break; - process_handshake_msg(msg.first, msg.second); + process_handshake_msg(*m_state.get(), msg.first, msg.second); } } else if(rec_type == HEARTBEAT && m_peer_supports_heartbeats) @@ -383,6 +383,9 @@ void Channel::send_alert(const Alert& alert) catch(...) { /* swallow it */ } } + if(alert.type() == Alert::NO_RENEGOTIATION) + m_state.reset(); + if(alert.is_fatal() && !m_active_session.empty()) { m_session_manager.remove_entry(m_active_session); diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index 1730be792..d1a6e8b00 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -96,6 +96,12 @@ class BOTAN_DLL Channel virtual ~Channel(); protected: + virtual void process_handshake_msg(class Handshake_State& state, + Handshake_Type type, + const std::vector<byte>& contents) = 0; + + virtual class Handshake_State* new_handshake_state() = 0; + /** * Send a TLS alert message. If the alert is fatal, the internal * state (keys, etc) will be reset. @@ -105,11 +111,6 @@ class BOTAN_DLL Channel void activate_session(const std::vector<byte>& session_id); - virtual void process_handshake_msg(Handshake_Type type, - const std::vector<byte>& contents) = 0; - - virtual class Handshake_State* new_handshake_state() = 0; - void heartbeat_support(bool peer_supports, bool allowed_to_send); void set_protocol_version(Protocol_Version version); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index d2508d579..8cb89ffb5 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -146,25 +146,21 @@ void Client::initiate_handshake(bool force_full_renegotiation, /* * Process a handshake message */ -void Client::process_handshake_msg(Handshake_Type type, +void Client::process_handshake_msg(Handshake_State& state, + Handshake_Type type, const std::vector<byte>& contents) { - if(!m_state) - throw Unexpected_Message("Unexpected handshake message from server"); - if(type == HELLO_REQUEST) { Hello_Request hello_request(contents); // Ignore request entirely if we are currently negotiating a handshake - if(m_state->client_hello()) + if(state.client_hello()) return; if(!m_policy.allow_server_initiated_renegotiation() || (!m_policy.allow_insecure_renegotiation() && !m_secure_renegotiation.supported())) { - m_state.reset(); - // RFC 5746 section 4.2 send_alert(Alert(Alert::NO_RENEGOTIATION)); return; @@ -175,69 +171,69 @@ void Client::process_handshake_msg(Handshake_Type type, return; } - m_state->confirm_transition_to(type); + state.confirm_transition_to(type); if(type != HANDSHAKE_CCS && type != FINISHED && type != HELLO_VERIFY_REQUEST) - m_state->hash().update(m_state->handshake_io().format(contents, type)); + state.hash().update(state.handshake_io().format(contents, type)); if(type == HELLO_VERIFY_REQUEST) { - m_state->set_expected_next(SERVER_HELLO); - m_state->set_expected_next(HELLO_VERIFY_REQUEST); // might get it again + state.set_expected_next(SERVER_HELLO); + state.set_expected_next(HELLO_VERIFY_REQUEST); // might get it again Hello_Verify_Request hello_verify_request(contents); - m_state->note_message(hello_verify_request); + state.note_message(hello_verify_request); std::unique_ptr<Client_Hello> client_hello_w_cookie( - new Client_Hello(m_state->handshake_io(), - m_state->hash(), - *m_state->client_hello(), + new Client_Hello(state.handshake_io(), + state.hash(), + *state.client_hello(), hello_verify_request)); - m_state->client_hello(client_hello_w_cookie.release()); + state.client_hello(client_hello_w_cookie.release()); } else if(type == SERVER_HELLO) { - m_state->server_hello(new Server_Hello(contents)); + state.server_hello(new Server_Hello(contents)); - if(!m_state->client_hello()->offered_suite(m_state->server_hello()->ciphersuite())) + if(!state.client_hello()->offered_suite(state.server_hello()->ciphersuite())) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server replied with ciphersuite we didn't send"); } - if(!value_exists(m_state->client_hello()->compression_methods(), - m_state->server_hello()->compression_method())) + if(!value_exists(state.client_hello()->compression_methods(), + state.server_hello()->compression_method())) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server replied with compression method we didn't send"); } - if(!m_state->client_hello()->next_protocol_notification() && - m_state->server_hello()->next_protocol_notification()) + if(!state.client_hello()->next_protocol_notification() && + state.server_hello()->next_protocol_notification()) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server sent next protocol but we didn't request it"); } - if(m_state->server_hello()->supports_session_ticket()) + if(state.server_hello()->supports_session_ticket()) { - if(!m_state->client_hello()->supports_session_ticket()) + if(!state.client_hello()->supports_session_ticket()) throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server sent session ticket extension but we did not"); } - set_protocol_version(m_state->server_hello()->version()); + set_protocol_version(state.server_hello()->version()); - m_secure_renegotiation.update(m_state->server_hello()); + m_secure_renegotiation.update(state.server_hello()); - heartbeat_support(m_state->server_hello()->supports_heartbeats(), - m_state->server_hello()->peer_can_send_heartbeats()); + heartbeat_support(state.server_hello()->supports_heartbeats(), + state.server_hello()->peer_can_send_heartbeats()); const bool server_returned_same_session_id = - !m_state->server_hello()->session_id().empty() && - (m_state->server_hello()->session_id() == m_state->client_hello()->session_id()); + !state.server_hello()->session_id().empty() && + (state.server_hello()->session_id() == state.client_hello()->session_id()); if(server_returned_same_session_id) { @@ -247,40 +243,40 @@ void Client::process_handshake_msg(Handshake_Type type, * In this case, we offered the version used in the original * session, and the server must resume with the same version. */ - if(m_state->server_hello()->version() != m_state->client_hello()->version()) + if(state.server_hello()->version() != state.client_hello()->version()) throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server resumed session but with wrong version"); - m_state->compute_session_keys( - dynamic_cast<Client_Handshake_State&>(*m_state).resume_master_secret + state.compute_session_keys( + dynamic_cast<Client_Handshake_State&>(state).resume_master_secret ); - if(m_state->server_hello()->supports_session_ticket()) - m_state->set_expected_next(NEW_SESSION_TICKET); + if(state.server_hello()->supports_session_ticket()) + state.set_expected_next(NEW_SESSION_TICKET); else - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); } else { // new session - if(m_state->version() > m_state->client_hello()->version()) + if(state.version() > state.client_hello()->version()) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server replied with later version than in hello"); } - if(!m_policy.acceptable_protocol_version(m_state->version())) + if(!m_policy.acceptable_protocol_version(state.version())) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Server version is unacceptable by policy"); } - if(m_state->ciphersuite().sig_algo() != "") + if(state.ciphersuite().sig_algo() != "") { - m_state->set_expected_next(CERTIFICATE); + state.set_expected_next(CERTIFICATE); } - else if(m_state->ciphersuite().kex_algo() == "PSK") + else if(state.ciphersuite().kex_algo() == "PSK") { /* PSK is anonymous so no certificate/cert req message is ever sent. The server may or may not send a server kex, @@ -290,35 +286,35 @@ void Client::process_handshake_msg(Handshake_Type type, DH exchange portion. */ - m_state->set_expected_next(SERVER_KEX); - m_state->set_expected_next(SERVER_HELLO_DONE); + state.set_expected_next(SERVER_KEX); + state.set_expected_next(SERVER_HELLO_DONE); } - else if(m_state->ciphersuite().kex_algo() != "RSA") + else if(state.ciphersuite().kex_algo() != "RSA") { - m_state->set_expected_next(SERVER_KEX); + state.set_expected_next(SERVER_KEX); } else { - m_state->set_expected_next(CERTIFICATE_REQUEST); // optional - m_state->set_expected_next(SERVER_HELLO_DONE); + state.set_expected_next(CERTIFICATE_REQUEST); // optional + state.set_expected_next(SERVER_HELLO_DONE); } } } else if(type == CERTIFICATE) { - if(m_state->ciphersuite().kex_algo() != "RSA") + if(state.ciphersuite().kex_algo() != "RSA") { - m_state->set_expected_next(SERVER_KEX); + state.set_expected_next(SERVER_KEX); } else { - m_state->set_expected_next(CERTIFICATE_REQUEST); // optional - m_state->set_expected_next(SERVER_HELLO_DONE); + state.set_expected_next(CERTIFICATE_REQUEST); // optional + state.set_expected_next(SERVER_HELLO_DONE); } - m_state->server_certs(new Certificate(contents)); + state.server_certs(new Certificate(contents)); - m_peer_certs = m_state->server_certs()->cert_chain(); + m_peer_certs = state.server_certs()->cert_chain(); if(m_peer_certs.empty()) throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client: No certificates sent by server"); @@ -334,25 +330,25 @@ void Client::process_handshake_msg(Handshake_Type type, std::unique_ptr<Public_Key> peer_key(m_peer_certs[0].subject_public_key()); - if(peer_key->algo_name() != m_state->ciphersuite().sig_algo()) + if(peer_key->algo_name() != state.ciphersuite().sig_algo()) throw TLS_Exception(Alert::ILLEGAL_PARAMETER, "Certificate key type did not match ciphersuite"); } else if(type == SERVER_KEX) { - m_state->set_expected_next(CERTIFICATE_REQUEST); // optional - m_state->set_expected_next(SERVER_HELLO_DONE); + state.set_expected_next(CERTIFICATE_REQUEST); // optional + state.set_expected_next(SERVER_HELLO_DONE); - m_state->server_kex( + state.server_kex( new Server_Key_Exchange(contents, - m_state->ciphersuite().kex_algo(), - m_state->ciphersuite().sig_algo(), - m_state->version()) + state.ciphersuite().kex_algo(), + state.ciphersuite().sig_algo(), + state.version()) ); - if(m_state->ciphersuite().sig_algo() != "") + if(state.ciphersuite().sig_algo() != "") { - if(!m_state->server_kex()->verify(m_peer_certs[0], m_state.get())) + if(!state.server_kex()->verify(m_peer_certs[0], state)) { throw TLS_Exception(Alert::DECRYPT_ERROR, "Bad signature on server key exchange"); @@ -361,37 +357,37 @@ void Client::process_handshake_msg(Handshake_Type type, } else if(type == CERTIFICATE_REQUEST) { - m_state->set_expected_next(SERVER_HELLO_DONE); - m_state->cert_req( - new Certificate_Req(contents, m_state->version()) + state.set_expected_next(SERVER_HELLO_DONE); + state.cert_req( + new Certificate_Req(contents, state.version()) ); } else if(type == SERVER_HELLO_DONE) { - m_state->server_hello_done( + state.server_hello_done( new Server_Hello_Done(contents) ); - if(m_state->received_handshake_msg(CERTIFICATE_REQUEST)) + if(state.received_handshake_msg(CERTIFICATE_REQUEST)) { const std::vector<std::string>& types = - m_state->cert_req()->acceptable_cert_types(); + state.cert_req()->acceptable_cert_types(); std::vector<X509_Certificate> client_certs = m_creds.cert_chain(types, "tls-client", m_hostname); - m_state->client_certs( - new Certificate(m_state->handshake_io(), - m_state->hash(), + state.client_certs( + new Certificate(state.handshake_io(), + state.hash(), client_certs) ); } - m_state->client_kex( - new Client_Key_Exchange(m_state->handshake_io(), - m_state.get(), + state.client_kex( + new Client_Key_Exchange(state.handshake_io(), + state, m_policy, m_creds, m_peer_certs, @@ -399,114 +395,114 @@ void Client::process_handshake_msg(Handshake_Type type, m_rng) ); - m_state->compute_session_keys(); + state.compute_session_keys(); - if(m_state->received_handshake_msg(CERTIFICATE_REQUEST) && - !m_state->client_certs()->empty()) + if(state.received_handshake_msg(CERTIFICATE_REQUEST) && + !state.client_certs()->empty()) { Private_Key* private_key = - m_creds.private_key_for(m_state->client_certs()->cert_chain()[0], + m_creds.private_key_for(state.client_certs()->cert_chain()[0], "tls-client", m_hostname); - m_state->client_verify( - new Certificate_Verify(m_state->handshake_io(), - m_state.get(), + state.client_verify( + new Certificate_Verify(state.handshake_io(), + state, m_policy, m_rng, private_key) ); } - m_state->handshake_io().send(Change_Cipher_Spec()); + state.handshake_io().send(Change_Cipher_Spec()); change_cipher_spec_writer(CLIENT); - if(m_state->server_hello()->next_protocol_notification()) + if(state.server_hello()->next_protocol_notification()) { const std::string protocol = - dynamic_cast<Client_Handshake_State&>(*m_state).client_npn_cb( - m_state->server_hello()->next_protocols()); + dynamic_cast<Client_Handshake_State&>(state).client_npn_cb( + state.server_hello()->next_protocols()); - m_state->next_protocol( - new Next_Protocol(m_state->handshake_io(), m_state->hash(), protocol) + state.next_protocol( + new Next_Protocol(state.handshake_io(), state.hash(), protocol) ); } - m_state->client_finished( - new Finished(m_state->handshake_io(), m_state.get(), CLIENT) + state.client_finished( + new Finished(state.handshake_io(), state, CLIENT) ); - if(m_state->server_hello()->supports_session_ticket()) - m_state->set_expected_next(NEW_SESSION_TICKET); + if(state.server_hello()->supports_session_ticket()) + state.set_expected_next(NEW_SESSION_TICKET); else - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); } else if(type == NEW_SESSION_TICKET) { - m_state->new_session_ticket(new New_Session_Ticket(contents)); + state.new_session_ticket(new New_Session_Ticket(contents)); - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); } else if(type == HANDSHAKE_CCS) { - m_state->set_expected_next(FINISHED); + state.set_expected_next(FINISHED); change_cipher_spec_reader(CLIENT); } else if(type == FINISHED) { - m_state->set_expected_next(HELLO_REQUEST); + state.set_expected_next(HELLO_REQUEST); - m_state->server_finished(new Finished(contents)); + state.server_finished(new Finished(contents)); - if(!m_state->server_finished()->verify(m_state.get(), SERVER)) + if(!state.server_finished()->verify(state, SERVER)) throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); - m_state->hash().update(m_state->handshake_io().format(contents, type)); + state.hash().update(state.handshake_io().format(contents, type)); - if(!m_state->client_finished()) // session resume case + if(!state.client_finished()) // session resume case { - m_state->handshake_io().send(Change_Cipher_Spec()); + state.handshake_io().send(Change_Cipher_Spec()); change_cipher_spec_writer(CLIENT); - if(m_state->server_hello()->next_protocol_notification()) + if(state.server_hello()->next_protocol_notification()) { const std::string protocol = dynamic_cast<Client_Handshake_State&>(*m_state).client_npn_cb( - m_state->server_hello()->next_protocols()); + state.server_hello()->next_protocols()); - m_state->next_protocol( - new Next_Protocol(m_state->handshake_io(), m_state->hash(), protocol) + state.next_protocol( + new Next_Protocol(state.handshake_io(), state.hash(), protocol) ); } - m_state->client_finished( - new Finished(m_state->handshake_io(), m_state.get(), CLIENT) + state.client_finished( + new Finished(state.handshake_io(), state, CLIENT) ); } - m_secure_renegotiation.update(m_state->client_finished(), - m_state->server_finished()); + m_secure_renegotiation.update(state.client_finished(), + state.server_finished()); - std::vector<byte> session_id = m_state->server_hello()->session_id(); + std::vector<byte> session_id = state.server_hello()->session_id(); - const std::vector<byte>& session_ticket = m_state->session_ticket(); + const std::vector<byte>& session_ticket = state.session_ticket(); if(session_id.empty() && !session_ticket.empty()) session_id = make_hello_random(m_rng); Session session_info( session_id, - m_state->session_keys().master_secret(), - m_state->server_hello()->version(), - m_state->server_hello()->ciphersuite(), - m_state->server_hello()->compression_method(), + state.session_keys().master_secret(), + state.server_hello()->version(), + state.server_hello()->ciphersuite(), + state.server_hello()->compression_method(), CLIENT, m_secure_renegotiation.supported(), - m_state->server_hello()->fragment_size(), + state.server_hello()->fragment_size(), m_peer_certs, session_ticket, m_hostname, diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index b256a75e9..d7c7598b5 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -74,7 +74,8 @@ class BOTAN_DLL Client : public Channel std::function<std::string (std::vector<std::string>)> next_protocol = std::function<std::string (std::vector<std::string>)>()); - void process_handshake_msg(Handshake_Type type, + void process_handshake_msg(Handshake_State& state, + Handshake_Type type, const std::vector<byte>& contents) override; class Handshake_State* new_handshake_state() override; diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 555520073..6fd7f675c 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -237,7 +237,7 @@ class Client_Key_Exchange : public Handshake_Message { return m_pre_master; } Client_Key_Exchange(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, Credentials_Manager& creds, const std::vector<X509_Certificate>& peer_certs, @@ -245,7 +245,7 @@ class Client_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng); Client_Key_Exchange(const std::vector<byte>& buf, - const Handshake_State* state, + const Handshake_State& state, const Private_Key* server_rsa_kex_key, Credentials_Manager& creds, const Policy& policy, @@ -329,10 +329,10 @@ class Certificate_Verify : public Handshake_Message * @param state the handshake state */ bool verify(const X509_Certificate& cert, - const Handshake_State* state) const; + const Handshake_State& state) const; Certificate_Verify(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, RandomNumberGenerator& rng, const Private_Key* key); @@ -358,11 +358,11 @@ class Finished : public Handshake_Message std::vector<byte> verify_data() const { return m_verification_data; } - bool verify(const Handshake_State* state, + bool verify(const Handshake_State& state, Connection_Side side) const; Finished(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, Connection_Side side); Finished(const std::vector<byte>& buf); @@ -398,7 +398,7 @@ class Server_Key_Exchange : public Handshake_Message const std::vector<byte>& params() const { return m_params; } bool verify(const X509_Certificate& cert, - const Handshake_State* state) const; + const Handshake_State& state) const; // Only valid for certain kex types const Private_Key& server_kex_key() const; @@ -407,7 +407,7 @@ class Server_Key_Exchange : public Handshake_Message SRP6_Server_Session& server_srp_params() const; Server_Key_Exchange(Handshake_IO& io, - Handshake_State* state, + Handshake_State& state, const Policy& policy, Credentials_Manager& creds, RandomNumberGenerator& rng, diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 21ea568e4..d858592d4 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -254,13 +254,11 @@ void Server::renegotiate(bool force_full_renegotiation) /* * Process a handshake message */ -void Server::process_handshake_msg(Handshake_Type type, +void Server::process_handshake_msg(Handshake_State& state, + Handshake_Type type, const std::vector<byte>& contents) { - if(!m_state) - throw Unexpected_Message("Unexpected handshake message from client"); - - m_state->confirm_transition_to(type); + state.confirm_transition_to(type); /* * The change cipher spec message isn't technically a handshake @@ -272,9 +270,9 @@ void Server::process_handshake_msg(Handshake_Type type, if(type != HANDSHAKE_CCS && type != FINISHED && type != CERTIFICATE_VERIFY) { if(type == CLIENT_HELLO_SSLV2) - m_state->hash().update(contents); + state.hash().update(contents); else - m_state->hash().update(m_state->handshake_io().format(contents, type)); + state.hash().update(state.handshake_io().format(contents, type)); } if(type == CLIENT_HELLO || type == CLIENT_HELLO_SSLV2) @@ -282,17 +280,16 @@ void Server::process_handshake_msg(Handshake_Type type, if(!m_policy.allow_insecure_renegotiation() && !(m_secure_renegotiation.initial_handshake() || m_secure_renegotiation.supported())) { - m_state.reset(); send_alert(Alert(Alert::NO_RENEGOTIATION)); return; } - m_state->client_hello(new Client_Hello(contents, type)); + state.client_hello(new Client_Hello(contents, type)); - if(m_state->client_hello()->sni_hostname() != "") - m_hostname = m_state->client_hello()->sni_hostname(); + if(state.client_hello()->sni_hostname() != "") + m_hostname = state.client_hello()->sni_hostname(); - Protocol_Version client_version = m_state->client_hello()->version(); + Protocol_Version client_version = state.client_hello()->version(); const Protocol_Version prev_version = current_protocol_version(); const bool is_renegotiation = prev_version.valid(); @@ -346,20 +343,20 @@ void Server::process_handshake_msg(Handshake_Type type, "Client version is unacceptable by policy"); } - m_secure_renegotiation.update(m_state->client_hello()); + m_secure_renegotiation.update(state.client_hello()); set_protocol_version(negotiated_version); - heartbeat_support(m_state->client_hello()->supports_heartbeats(), - m_state->client_hello()->peer_can_send_heartbeats()); + heartbeat_support(state.client_hello()->supports_heartbeats(), + state.client_hello()->peer_can_send_heartbeats()); Session session_info; const bool resuming = - dynamic_cast<Server_Handshake_State&>(*m_state).allow_session_resumption && + dynamic_cast<Server_Handshake_State&>(state).allow_session_resumption && check_for_resume(session_info, m_session_manager, m_creds, - m_state->client_hello(), + state.client_hello(), std::chrono::seconds(m_policy.session_ticket_lifetime())); bool have_session_ticket_key = false; @@ -376,15 +373,15 @@ void Server::process_handshake_msg(Handshake_Type type, // resume session const bool offer_new_session_ticket = - (m_state->client_hello()->supports_session_ticket() && - m_state->client_hello()->session_ticket().empty() && + (state.client_hello()->supports_session_ticket() && + state.client_hello()->session_ticket().empty() && have_session_ticket_key); - m_state->server_hello( + state.server_hello( new Server_Hello( - m_state->handshake_io(), - m_state->hash(), - m_state->client_hello()->session_id(), + state.handshake_io(), + state.hash(), + state.client_hello()->session_id(), Protocol_Version(session_info.version()), session_info.ciphersuite_code(), session_info.compression_method(), @@ -392,64 +389,64 @@ void Server::process_handshake_msg(Handshake_Type type, m_secure_renegotiation.supported(), m_secure_renegotiation.for_server_hello(), offer_new_session_ticket, - m_state->client_hello()->next_protocol_notification(), + state.client_hello()->next_protocol_notification(), m_possible_protocols, - m_state->client_hello()->supports_heartbeats(), + state.client_hello()->supports_heartbeats(), m_rng) ); - m_secure_renegotiation.update(m_state->server_hello()); + m_secure_renegotiation.update(state.server_hello()); if(session_info.fragment_size()) set_maximum_fragment_size(session_info.fragment_size()); - m_state->compute_session_keys(session_info.master_secret()); + state.compute_session_keys(session_info.master_secret()); if(!m_handshake_fn(session_info)) { m_session_manager.remove_entry(session_info.session_id()); - if(m_state->server_hello()->supports_session_ticket()) // send an empty ticket + if(state.server_hello()->supports_session_ticket()) // send an empty ticket { - m_state->new_session_ticket( - new New_Session_Ticket(m_state->handshake_io(), - m_state->hash()) + state.new_session_ticket( + new New_Session_Ticket(state.handshake_io(), + state.hash()) ); } } - if(m_state->server_hello()->supports_session_ticket() && !m_state->new_session_ticket()) + if(state.server_hello()->supports_session_ticket() && !state.new_session_ticket()) { try { const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); - m_state->new_session_ticket( - new New_Session_Ticket(m_state->handshake_io(), - m_state->hash(), + state.new_session_ticket( + new New_Session_Ticket(state.handshake_io(), + state.hash(), session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()) ); } catch(...) {} - if(!m_state->new_session_ticket()) + if(!state.new_session_ticket()) { - m_state->new_session_ticket( - new New_Session_Ticket(m_state->handshake_io(), m_state->hash()) + state.new_session_ticket( + new New_Session_Ticket(state.handshake_io(), state.hash()) ); } } - m_state->handshake_io().send(Change_Cipher_Spec()); + state.handshake_io().send(Change_Cipher_Spec()); change_cipher_spec_writer(SERVER); - m_state->server_finished( - new Finished(m_state->handshake_io(), m_state.get(), SERVER) + state.server_finished( + new Finished(state.handshake_io(), state, SERVER) ); - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); } else // new session { @@ -472,44 +469,44 @@ void Server::process_handshake_msg(Handshake_Type type, send_alert(Alert(Alert::UNRECOGNIZED_NAME)); } - m_state->server_hello( + state.server_hello( new Server_Hello( - m_state->handshake_io(), - m_state->hash(), + state.handshake_io(), + state.hash(), make_hello_random(m_rng), // new session ID - m_state->version(), + state.version(), choose_ciphersuite(m_policy, - m_state->version(), + state.version(), m_creds, cert_chains, - m_state->client_hello()), - choose_compression(m_policy, m_state->client_hello()->compression_methods()), - m_state->client_hello()->fragment_size(), + state.client_hello()), + choose_compression(m_policy, state.client_hello()->compression_methods()), + state.client_hello()->fragment_size(), m_secure_renegotiation.supported(), m_secure_renegotiation.for_server_hello(), - m_state->client_hello()->supports_session_ticket() && have_session_ticket_key, - m_state->client_hello()->next_protocol_notification(), + state.client_hello()->supports_session_ticket() && have_session_ticket_key, + state.client_hello()->next_protocol_notification(), m_possible_protocols, - m_state->client_hello()->supports_heartbeats(), + state.client_hello()->supports_heartbeats(), m_rng) ); - m_secure_renegotiation.update(m_state->server_hello()); + m_secure_renegotiation.update(state.server_hello()); - if(m_state->client_hello()->fragment_size()) - set_maximum_fragment_size(m_state->client_hello()->fragment_size()); + if(state.client_hello()->fragment_size()) + set_maximum_fragment_size(state.client_hello()->fragment_size()); - const std::string sig_algo = m_state->ciphersuite().sig_algo(); - const std::string kex_algo = m_state->ciphersuite().kex_algo(); + const std::string sig_algo = state.ciphersuite().sig_algo(); + const std::string kex_algo = state.ciphersuite().kex_algo(); if(sig_algo != "") { BOTAN_ASSERT(!cert_chains[sig_algo].empty(), "Attempting to send empty certificate chain"); - m_state->server_certs( - new Certificate(m_state->handshake_io(), - m_state->hash(), + state.server_certs( + new Certificate(state.handshake_io(), + state.hash(), cert_chains[sig_algo]) ); } @@ -519,7 +516,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(kex_algo == "RSA" || sig_algo != "") { private_key = m_creds.private_key_for( - m_state->server_certs()->cert_chain()[0], + state.server_certs()->cert_chain()[0], "tls-server", m_hostname); @@ -529,13 +526,13 @@ void Server::process_handshake_msg(Handshake_Type type, if(kex_algo == "RSA") { - dynamic_cast<Server_Handshake_State&>(*m_state).server_rsa_kex_key = private_key; + dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key = private_key; } else { - m_state->server_kex( - new Server_Key_Exchange(m_state->handshake_io(), - m_state.get(), + state.server_kex( + new Server_Key_Exchange(state.handshake_io(), + state, m_policy, m_creds, m_rng, @@ -546,17 +543,17 @@ void Server::process_handshake_msg(Handshake_Type type, std::vector<X509_Certificate> client_auth_CAs = m_creds.trusted_certificate_authorities("tls-server", m_hostname); - if(!client_auth_CAs.empty() && m_state->ciphersuite().sig_algo() != "") + if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "") { - m_state->cert_req( - new Certificate_Req(m_state->handshake_io(), - m_state->hash(), + state.cert_req( + new Certificate_Req(state.handshake_io(), + state.hash(), m_policy, client_auth_CAs, - m_state->version()) + state.version()) ); - m_state->set_expected_next(CERTIFICATE); + state.set_expected_next(CERTIFICATE); } /* @@ -564,44 +561,44 @@ void Server::process_handshake_msg(Handshake_Type type, * allowed to send either an empty cert message or proceed * directly to the client key exchange, so allow either case. */ - m_state->set_expected_next(CLIENT_KEX); + state.set_expected_next(CLIENT_KEX); - m_state->server_hello_done( - new Server_Hello_Done(m_state->handshake_io(), m_state->hash()) + state.server_hello_done( + new Server_Hello_Done(state.handshake_io(), state.hash()) ); } } else if(type == CERTIFICATE) { - m_state->client_certs(new Certificate(contents)); + state.client_certs(new Certificate(contents)); - m_state->set_expected_next(CLIENT_KEX); + state.set_expected_next(CLIENT_KEX); } else if(type == CLIENT_KEX) { - if(m_state->received_handshake_msg(CERTIFICATE) && !m_state->client_certs()->empty()) - m_state->set_expected_next(CERTIFICATE_VERIFY); + if(state.received_handshake_msg(CERTIFICATE) && !state.client_certs()->empty()) + state.set_expected_next(CERTIFICATE_VERIFY); else - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); - m_state->client_kex( - new Client_Key_Exchange(contents, m_state.get(), - dynamic_cast<Server_Handshake_State&>(*m_state).server_rsa_kex_key, + state.client_kex( + new Client_Key_Exchange(contents, state, + dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key, m_creds, m_policy, m_rng) ); - m_state->compute_session_keys(); + state.compute_session_keys(); } else if(type == CERTIFICATE_VERIFY) { - m_state->client_verify(new Certificate_Verify(contents, m_state->version())); + state.client_verify(new Certificate_Verify(contents, state.version())); - m_peer_certs = m_state->client_certs()->cert_chain(); + m_peer_certs = state.client_certs()->cert_chain(); const bool sig_valid = - m_state->client_verify()->verify(m_peer_certs[0], m_state.get()); + state.client_verify()->verify(m_peer_certs[0], state); - m_state->hash().update(m_state->handshake_io().format(contents, type)); + state.hash().update(state.handshake_io().format(contents, type)); /* * Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for @@ -620,68 +617,68 @@ void Server::process_handshake_msg(Handshake_Type type, throw TLS_Exception(Alert::BAD_CERTIFICATE, e.what()); } - m_state->set_expected_next(HANDSHAKE_CCS); + state.set_expected_next(HANDSHAKE_CCS); } else if(type == HANDSHAKE_CCS) { - if(m_state->server_hello()->next_protocol_notification()) - m_state->set_expected_next(NEXT_PROTOCOL); + if(state.server_hello()->next_protocol_notification()) + state.set_expected_next(NEXT_PROTOCOL); else - m_state->set_expected_next(FINISHED); + state.set_expected_next(FINISHED); change_cipher_spec_reader(SERVER); } else if(type == NEXT_PROTOCOL) { - m_state->set_expected_next(FINISHED); + state.set_expected_next(FINISHED); - m_state->next_protocol(new Next_Protocol(contents)); + state.next_protocol(new Next_Protocol(contents)); // should this be a callback? - m_next_protocol = m_state->next_protocol()->protocol(); + m_next_protocol = state.next_protocol()->protocol(); } else if(type == FINISHED) { - m_state->set_expected_next(HANDSHAKE_NONE); + state.set_expected_next(HANDSHAKE_NONE); - m_state->client_finished(new Finished(contents)); + state.client_finished(new Finished(contents)); - if(!m_state->client_finished()->verify(m_state.get(), CLIENT)) + if(!state.client_finished()->verify(state, CLIENT)) throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); - if(!m_state->server_finished()) + if(!state.server_finished()) { // already sent finished if resuming, so this is a new session - m_state->hash().update(m_state->handshake_io().format(contents, type)); + state.hash().update(state.handshake_io().format(contents, type)); Session session_info( - m_state->server_hello()->session_id(), - m_state->session_keys().master_secret(), - m_state->server_hello()->version(), - m_state->server_hello()->ciphersuite(), - m_state->server_hello()->compression_method(), + state.server_hello()->session_id(), + state.session_keys().master_secret(), + state.server_hello()->version(), + state.server_hello()->ciphersuite(), + state.server_hello()->compression_method(), SERVER, m_secure_renegotiation.supported(), - m_state->server_hello()->fragment_size(), + state.server_hello()->fragment_size(), m_peer_certs, std::vector<byte>(), m_hostname, - m_state->srp_identifier() + state.srp_identifier() ); if(m_handshake_fn(session_info)) { - if(m_state->server_hello()->supports_session_ticket()) + if(state.server_hello()->supports_session_ticket()) { try { const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); - m_state->new_session_ticket( - new New_Session_Ticket(m_state->handshake_io(), - m_state->hash(), + state.new_session_ticket( + new New_Session_Ticket(state.handshake_io(), + state.hash(), session_info.encrypt(ticket_key, m_rng), m_policy.session_ticket_lifetime()) ); @@ -692,24 +689,24 @@ void Server::process_handshake_msg(Handshake_Type type, m_session_manager.save(session_info); } - if(!m_state->new_session_ticket() && - m_state->server_hello()->supports_session_ticket()) + if(!state.new_session_ticket() && + state.server_hello()->supports_session_ticket()) { - m_state->new_session_ticket( - new New_Session_Ticket(m_state->handshake_io(), m_state->hash()) + state.new_session_ticket( + new New_Session_Ticket(state.handshake_io(), state.hash()) ); } - m_state->handshake_io().send(Change_Cipher_Spec()); + state.handshake_io().send(Change_Cipher_Spec()); change_cipher_spec_writer(SERVER); - m_state->server_finished( - new Finished(m_state->handshake_io(), m_state.get(), SERVER) + state.server_finished( + new Finished(state.handshake_io(), state, SERVER) ); } - activate_session(m_state->server_hello()->session_id()); + activate_session(state.server_hello()->session_id()); } else throw Unexpected_Message("Unknown handshake message received"); diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 1add1ff40..e70989ad0 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -50,7 +50,9 @@ class BOTAN_DLL Server : public Channel { return m_next_protocol; } private: - void process_handshake_msg(Handshake_Type, const std::vector<byte>&) override; + void process_handshake_msg(Handshake_State& state, + Handshake_Type type, + const std::vector<byte>& contents) override; class Handshake_State* new_handshake_state() override; |