diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_client.cpp | 31 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 10 |
2 files changed, 20 insertions, 21 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index ec1c40549..a3b817c32 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -94,17 +94,19 @@ void Client::initiate_handshake(Handshake_State& state, current_protocol_version()); } -void Client::initiate_handshake(Handshake_State& state, +void Client::initiate_handshake(Handshake_State& state_base, bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier, std::function<std::string (std::vector<std::string>)> next_protocol) { + Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); + if(state.version().is_datagram_protocol()) state.set_expected_next(HELLO_VERIFY_REQUEST); state.set_expected_next(SERVER_HELLO); - dynamic_cast<Client_Handshake_State&>(state).client_npn_cb = next_protocol; + state.client_npn_cb = next_protocol; const bool send_npn_request = static_cast<bool>(next_protocol); @@ -124,8 +126,7 @@ void Client::initiate_handshake(Handshake_State& state, session_info, send_npn_request)); - dynamic_cast<Client_Handshake_State&>(state).resume_master_secret = - session_info.master_secret(); + state.resume_master_secret = session_info.master_secret(); } } } @@ -153,10 +154,12 @@ void Client::initiate_handshake(Handshake_State& state, * Process a handshake message */ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, - Handshake_State& state, + Handshake_State& state_base, Handshake_Type type, const std::vector<byte>& contents) { + Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); + if(type == HELLO_REQUEST) { Hello_Request hello_request(contents); @@ -251,9 +254,7 @@ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Server resumed session but with wrong version"); - state.compute_session_keys( - dynamic_cast<Client_Handshake_State&>(state).resume_master_secret - ); + state.compute_session_keys(state.resume_master_secret); if(state.server_hello()->supports_session_ticket()) state.set_expected_next(NEW_SESSION_TICKET); @@ -338,8 +339,7 @@ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, throw TLS_Exception(Alert::ILLEGAL_PARAMETER, "Certificate key type did not match ciphersuite"); - dynamic_cast<Client_Handshake_State&>(state). - server_public_key.reset(peer_key.release()); + state.server_public_key.reset(peer_key.release()); } else if(type == SERVER_KEX) { @@ -355,8 +355,7 @@ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, if(state.ciphersuite().sig_algo() != "") { - const Public_Key& server_key = - dynamic_cast<Client_Handshake_State&>(state).get_server_public_Key(); + const Public_Key& server_key = state.get_server_public_Key(); if(!state.server_kex()->verify(server_key, state)) { @@ -430,9 +429,8 @@ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, if(state.server_hello()->next_protocol_notification()) { - const std::string protocol = - dynamic_cast<Client_Handshake_State&>(state).client_npn_cb( - state.server_hello()->next_protocols()); + const std::string protocol = state.client_npn_cb( + state.server_hello()->next_protocols()); state.next_protocol( new Next_Protocol(state.handshake_io(), state.hash(), protocol) @@ -480,8 +478,7 @@ void Client::process_handshake_msg(const Handshake_State* /*active_state*/, if(state.server_hello()->next_protocol_notification()) { - const std::string protocol = - dynamic_cast<Client_Handshake_State&>(state).client_npn_cb( + const std::string protocol = state.client_npn_cb( state.server_hello()->next_protocols()); state.next_protocol( diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index cda53d6b2..d6d9bedc0 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -252,10 +252,12 @@ void Server::initiate_handshake(Handshake_State& state, * Process a handshake message */ void Server::process_handshake_msg(const Handshake_State* active_state, - Handshake_State& state, + Handshake_State& state_base, Handshake_Type type, const std::vector<byte>& contents) { + Server_Handshake_State& state = dynamic_cast<Server_Handshake_State&>(state_base); + state.confirm_transition_to(type); /* @@ -346,7 +348,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, Session session_info; const bool resuming = - dynamic_cast<Server_Handshake_State&>(state).allow_session_resumption && + state.allow_session_resumption && check_for_resume(session_info, m_session_manager, m_creds, @@ -520,7 +522,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, if(kex_algo == "RSA") { - dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key = private_key; + state.server_rsa_kex_key = private_key; } else { @@ -577,7 +579,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(contents, state, - dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key, + state.server_rsa_kex_key, m_creds, m_policy, m_rng) ); |