diff options
Diffstat (limited to 'src/lib/tls/tls_server.cpp')
-rw-r--r-- | src/lib/tls/tls_server.cpp | 58 |
1 files changed, 38 insertions, 20 deletions
diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 330135e63..774827346 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -21,8 +21,7 @@ class Server_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Server_Handshake_State(Handshake_IO* io, hs_msg_cb cb = hs_msg_cb()) : - Handshake_State(io, cb) {} + Server_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} // Used by the server only, in case of RSA key exchange. Not owned Private_Key* server_rsa_kex_key = nullptr; @@ -215,9 +214,26 @@ Server::Server(output_fn output, next_protocol_fn next_proto, bool is_datagram, size_t io_buf_sz) : - Channel(output, data_cb, alert_cb, handshake_cb, - session_manager, rng, is_datagram, io_buf_sz), - m_policy(policy), + Channel(output, data_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), + session_manager, rng, policy, is_datagram, io_buf_sz), + m_creds(creds), + m_choose_next_protocol(next_proto) + { + } + +Server::Server(output_fn output, + data_cb data_cb, + alert_cb alert_cb, + handshake_cb handshake_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + next_protocol_fn next_proto, + bool is_datagram) : + Channel(output, data_cb, alert_cb, handshake_cb, hs_msg_cb, + session_manager, rng, policy, is_datagram), m_creds(creds), m_choose_next_protocol(next_proto) { @@ -225,7 +241,9 @@ Server::Server(output_fn output, Handshake_State* Server::new_handshake_state(Handshake_IO* io) { - std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io)); + std::unique_ptr<Handshake_State> state( + new Server_Handshake_State(io, get_handshake_msg_cb())); + state->set_expected_next(CLIENT_HELLO); return state.release(); } @@ -278,7 +296,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { const bool initial_handshake = !active_state; - if(!m_policy.allow_insecure_renegotiation() && + if(!policy().allow_insecure_renegotiation() && !(initial_handshake || secure_renegotiation_supported())) { send_warning_alert(Alert::NO_RENEGOTIATION); @@ -292,7 +310,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, Protocol_Version negotiated_version; const Protocol_Version latest_supported = - m_policy.latest_supported_version(client_version.is_datagram_protocol()); + policy().latest_supported_version(client_version.is_datagram_protocol()); if((initial_handshake && client_version.known_version()) || (!initial_handshake && client_version == active_state->version())) @@ -334,7 +352,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, negotiated_version = latest_supported; } - if(!m_policy.acceptable_protocol_version(negotiated_version)) + if(!policy().acceptable_protocol_version(negotiated_version)) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client version " + negotiated_version.to_string() + @@ -359,7 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, session_manager(), m_creds, state.client_hello(), - std::chrono::seconds(m_policy.session_ticket_lifetime())); + std::chrono::seconds(policy().session_ticket_lifetime())); bool have_session_ticket_key = false; @@ -387,7 +405,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), @@ -423,7 +441,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} @@ -469,14 +487,14 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello(new Server_Hello( state.handshake_io(), state.hash(), - m_policy, + policy(), rng(), secure_renegotiation_data_for_server_hello(), *state.client_hello(), - make_hello_random(rng(), m_policy), // new session ID + make_hello_random(rng(), policy()), // new session ID state.version(), - choose_ciphersuite(m_policy, state.version(), m_creds, cert_chains, state.client_hello()), - choose_compression(m_policy, state.client_hello()->compression_methods()), + choose_ciphersuite(policy(), state.version(), m_creds, cert_chains, state.client_hello()), + choose_compression(policy(), state.client_hello()->compression_methods()), have_session_ticket_key, m_next_protocol) ); @@ -516,7 +534,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, else { state.server_kex(new Server_Key_Exchange(state.handshake_io(), - state, m_policy, + state, policy(), m_creds, rng(), private_key)); } @@ -534,7 +552,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { state.cert_req( new Certificate_Req(state.handshake_io(), state.hash(), - m_policy, client_auth_CAs, state.version())); + policy(), client_auth_CAs, state.version())); state.set_expected_next(CERTIFICATE); } @@ -565,7 +583,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_kex( new Client_Key_Exchange(contents, state, state.server_rsa_kex_key, - m_creds, m_policy, rng()) + m_creds, policy(), rng()) ); state.compute_session_keys(); @@ -649,7 +667,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new New_Session_Ticket(state.handshake_io(), state.hash(), session_info.encrypt(ticket_key, rng()), - m_policy.session_ticket_lifetime()) + policy().session_ticket_lifetime()) ); } catch(...) {} |