aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/tls_client.cpp89
-rw-r--r--src/tls/tls_client.h5
2 files changed, 39 insertions, 55 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index fe25736d7..8a7492d8e 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -37,21 +37,47 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
{
m_writer.set_version(Protocol_Version::SSL_V3);
+ const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname);
+
+ initiate_handshake(false, srp_identifier, next_protocol);
+ }
+
+Handshake_State* Client::new_handshake_state()
+ {
+ return new Handshake_State(new Stream_Handshake_Reader,
+ new Stream_Handshake_Writer(m_writer));
+ }
+
+/*
+* Send a new client hello to renegotiate
+*/
+void Client::renegotiate(bool force_full_renegotiation)
+ {
+ if(m_state && m_state->client_hello)
+ return; // currently in active handshake
+
+ delete m_state;
+
+ initiate_handshake(force_full_renegotiation);
+ }
+
+void Client::initiate_handshake(bool force_full_renegotiation,
+ const std::string& srp_identifier,
+ std::function<std::string (std::vector<std::string>)> next_protocol)
+ {
m_state = new_handshake_state();
m_state->set_expected_next(SERVER_HELLO);
m_state->client_npn_cb = next_protocol;
- const std::string srp_identifier = m_creds.srp_identifier("tls-client", hostname);
-
const bool send_npn_request = static_cast<bool>(next_protocol);
- if(hostname != "")
+ if(!force_full_renegotiation && m_hostname != "")
{
Session session_info;
if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info))
{
- if(session_info.srp_identifier() == srp_identifier)
+ if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier)
{
m_state->client_hello = new Client_Hello(
m_state->handshake_writer(),
@@ -67,12 +93,15 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
}
}
+ const Protocol_Version version = m_reader.get_version().valid() ?
+ m_reader.get_version() : m_policy.pref_version();
+
if(!m_state->client_hello) // not resuming
{
m_state->client_hello = new Client_Hello(
m_state->handshake_writer(),
m_state->hash,
- m_policy.pref_version(),
+ version,
m_policy,
m_rng,
m_secure_renegotiation.for_client_hello(),
@@ -84,56 +113,6 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
m_secure_renegotiation.update(m_state->client_hello);
}
-Handshake_State* Client::new_handshake_state()
- {
- return new Handshake_State(new Stream_Handshake_Reader,
- new Stream_Handshake_Writer(m_writer));
- }
-
-/*
-* Send a new client hello to renegotiate
-*/
-void Client::renegotiate(bool force_full_renegotiation)
- {
- if(m_state && m_state->client_hello)
- return; // currently in active handshake
-
- delete m_state;
- m_state = new_handshake_state();
-
- m_state->set_expected_next(SERVER_HELLO);
-
- if(!force_full_renegotiation)
- {
- Session session_info;
- if(m_session_manager.load_from_host_info(m_hostname, m_port, session_info))
- {
- m_state->client_hello = new Client_Hello(
- m_state->handshake_writer(),
- m_state->hash,
- m_policy,
- m_rng,
- m_secure_renegotiation.for_client_hello(),
- session_info);
-
- m_state->resume_master_secret = session_info.master_secret();
- }
- }
-
- if(!m_state->client_hello)
- {
- m_state->client_hello = new Client_Hello(
- m_state->handshake_writer(),
- m_state->hash,
- m_reader.get_version(),
- m_policy,
- m_rng,
- m_secure_renegotiation.for_client_hello());
- }
-
- m_secure_renegotiation.update(m_state->client_hello);
- }
-
void Client::alert_notify(const Alert& alert)
{
if(alert.type() == Alert::NO_RENEGOTIATION)
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index ad13a94dc..f8af662d3 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -68,6 +68,11 @@ class BOTAN_DLL Client : public Channel
void renegotiate(bool force_full_renegotiation = false) override;
private:
+ void initiate_handshake(bool force_full_renegotiation,
+ const std::string& srp_identifier = "",
+ std::function<std::string (std::vector<std::string>)> next_protocol =
+ std::function<std::string (std::vector<std::string>)>());
+
void process_handshake_msg(Handshake_Type type,
const std::vector<byte>& contents) override;