diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_client.cpp | 89 | ||||
-rw-r--r-- | src/tls/tls_client.h | 5 |
2 files changed, 39 insertions, 55 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index fe25736d7..8a7492d8e 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -37,21 +37,47 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, { m_writer.set_version(Protocol_Version::SSL_V3); + const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname); + + initiate_handshake(false, srp_identifier, next_protocol); + } + +Handshake_State* Client::new_handshake_state() + { + return new Handshake_State(new Stream_Handshake_Reader, + new Stream_Handshake_Writer(m_writer)); + } + +/* +* Send a new client hello to renegotiate +*/ +void Client::renegotiate(bool force_full_renegotiation) + { + if(m_state && m_state->client_hello) + return; // currently in active handshake + + delete m_state; + + initiate_handshake(force_full_renegotiation); + } + +void Client::initiate_handshake(bool force_full_renegotiation, + const std::string& srp_identifier, + std::function<std::string (std::vector<std::string>)> next_protocol) + { m_state = new_handshake_state(); m_state->set_expected_next(SERVER_HELLO); m_state->client_npn_cb = next_protocol; - const std::string srp_identifier = m_creds.srp_identifier("tls-client", hostname); - const bool send_npn_request = static_cast<bool>(next_protocol); - if(hostname != "") + if(!force_full_renegotiation && m_hostname != "") { Session session_info; if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info)) { - if(session_info.srp_identifier() == srp_identifier) + if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier) { m_state->client_hello = new Client_Hello( m_state->handshake_writer(), @@ -67,12 +93,15 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, } } + const Protocol_Version version = m_reader.get_version().valid() ? + m_reader.get_version() : m_policy.pref_version(); + if(!m_state->client_hello) // not resuming { m_state->client_hello = new Client_Hello( m_state->handshake_writer(), m_state->hash, - m_policy.pref_version(), + version, m_policy, m_rng, m_secure_renegotiation.for_client_hello(), @@ -84,56 +113,6 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, m_secure_renegotiation.update(m_state->client_hello); } -Handshake_State* Client::new_handshake_state() - { - return new Handshake_State(new Stream_Handshake_Reader, - new Stream_Handshake_Writer(m_writer)); - } - -/* -* Send a new client hello to renegotiate -*/ -void Client::renegotiate(bool force_full_renegotiation) - { - if(m_state && m_state->client_hello) - return; // currently in active handshake - - delete m_state; - m_state = new_handshake_state(); - - m_state->set_expected_next(SERVER_HELLO); - - if(!force_full_renegotiation) - { - Session session_info; - if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info)) - { - m_state->client_hello = new Client_Hello( - m_state->handshake_writer(), - m_state->hash, - m_policy, - m_rng, - m_secure_renegotiation.for_client_hello(), - session_info); - - m_state->resume_master_secret = session_info.master_secret(); - } - } - - if(!m_state->client_hello) - { - m_state->client_hello = new Client_Hello( - m_state->handshake_writer(), - m_state->hash, - m_reader.get_version(), - m_policy, - m_rng, - m_secure_renegotiation.for_client_hello()); - } - - m_secure_renegotiation.update(m_state->client_hello); - } - void Client::alert_notify(const Alert& alert) { if(alert.type() == Alert::NO_RENEGOTIATION) diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index ad13a94dc..f8af662d3 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -68,6 +68,11 @@ class BOTAN_DLL Client : public Channel void renegotiate(bool force_full_renegotiation = false) override; private: + void initiate_handshake(bool force_full_renegotiation, + const std::string& srp_identifier = "", + std::function<std::string (std::vector<std::string>)> next_protocol = + std::function<std::string (std::vector<std::string>)>()); + void process_handshake_msg(Handshake_Type type, const std::vector<byte>& contents) override; |