aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-11-06 21:29:02 +0000
committerlloyd <[email protected]>2012-11-06 21:29:02 +0000
commit38af5416771e39a1b0c2f628cf5ff01790922bee (patch)
treead4407d1cc1e6f43d3a6621a0cbb8acfa74ca3d9 /src/tls
parentdfbe4b328fa29b80f05bf89dd6c20be304312b17 (diff)
Add Channel::pending_state and Channel::active_state, use where possible
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/tls_channel.cpp153
-rw-r--r--src/tls/tls_channel.h4
2 files changed, 84 insertions, 73 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 7065064cc..46c5f4c74 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -76,21 +76,21 @@ std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_current() c
std::vector<X509_Certificate> Channel::peer_cert_chain() const
{
- if(!m_active_state)
- return std::vector<X509_Certificate>();
- return get_peer_cert_chain(*m_active_state);
+ if(auto active = active_state())
+ return get_peer_cert_chain(*active);
+ return std::vector<X509_Certificate>();
}
Handshake_State& Channel::create_handshake_state(Protocol_Version version)
{
const size_t dtls_mtu = 1400; // fixme should be settable
- if(m_pending_state)
+ if(pending_state())
throw Internal_Error("create_handshake_state called during handshake");
- if(m_active_state)
+ if(auto active = active_state())
{
- Protocol_Version active_version = m_active_state->version();
+ Protocol_Version active_version = active->version();
if(active_version.is_datagram_protocol() != version.is_datagram_protocol())
throw std::runtime_error("Active state using version " +
@@ -120,22 +120,22 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
m_pending_state.reset(new_handshake_state(io.release()));
- if(m_active_state)
- m_pending_state->set_version(m_active_state->version());
+ if(auto active = active_state())
+ m_pending_state->set_version(active->version());
return *m_pending_state.get();
}
void Channel::renegotiate(bool force_full_renegotiation)
{
- if(m_pending_state) // currently in handshake?
+ if(pending_state()) // currently in handshake?
return;
- if(!m_active_state)
+ if(auto active = active_state())
+ initiate_handshake(create_handshake_state(active->version()),
+ force_full_renegotiation);
+ else
throw std::runtime_error("Cannot renegotiate on inactive connection");
-
- initiate_handshake(create_handshake_state(m_active_state->version()),
- force_full_renegotiation);
}
void Channel::set_maximum_fragment_size(size_t max_fragment)
@@ -148,10 +148,12 @@ void Channel::set_maximum_fragment_size(size_t max_fragment)
void Channel::change_cipher_spec_reader(Connection_Side side)
{
- BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(),
+ auto pending = pending_state();
+
+ BOTAN_ASSERT(pending && pending->server_hello(),
"Have received server hello");
- if(m_pending_state->server_hello()->compression_method() != NO_COMPRESSION)
+ if(pending->server_hello()->compression_method() != NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
sequence_numbers().new_read_cipher_state();
@@ -163,20 +165,22 @@ void Channel::change_cipher_spec_reader(Connection_Side side)
// flip side as we are reading
std::shared_ptr<Connection_Cipher_State> read_state(
- new Connection_Cipher_State(m_pending_state->version(),
+ new Connection_Cipher_State(pending->version(),
(side == CLIENT) ? SERVER : CLIENT,
- m_pending_state->ciphersuite(),
- m_pending_state->session_keys()));
+ pending->ciphersuite(),
+ pending->session_keys()));
m_read_cipher_states[epoch] = read_state;
}
void Channel::change_cipher_spec_writer(Connection_Side side)
{
- BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(),
+ auto pending = pending_state();
+
+ BOTAN_ASSERT(pending && pending->server_hello(),
"Have received server hello");
- if(m_pending_state->server_hello()->compression_method() != NO_COMPRESSION)
+ if(pending->server_hello()->compression_method() != NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
sequence_numbers().new_write_cipher_state();
@@ -187,17 +191,17 @@ void Channel::change_cipher_spec_writer(Connection_Side side)
"No write cipher state currently set for next epoch");
std::shared_ptr<Connection_Cipher_State> write_state(
- new Connection_Cipher_State(m_pending_state->version(),
+ new Connection_Cipher_State(pending->version(),
side,
- m_pending_state->ciphersuite(),
- m_pending_state->session_keys()));
+ pending->ciphersuite(),
+ pending->session_keys()));
m_write_cipher_states[epoch] = write_state;
}
bool Channel::is_active() const
{
- return m_active_state.get();
+ return (active_state() != nullptr);
}
bool Channel::is_closed() const
@@ -248,15 +252,15 @@ u16bit Channel::get_last_valid_epoch() const
bool Channel::peer_supports_heartbeats() const
{
- if(m_active_state && m_active_state->server_hello())
- return m_active_state->server_hello()->supports_heartbeats();
+ if(auto active = active_state())
+ return active->server_hello()->supports_heartbeats();
return false;
}
bool Channel::heartbeat_sending_allowed() const
{
- if(m_active_state && m_active_state->server_hello())
- return m_active_state->server_hello()->peer_can_send_heartbeats();
+ if(auto active = active_state())
+ return active->server_hello()->peer_can_send_heartbeats();
return false;
}
@@ -318,21 +322,20 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
m_pending_state->handshake_io().add_input(
rec_type, &record[0], record.size(), record_sequence);
- while(m_pending_state)
+ while(auto pending = m_pending_state.get())
{
- auto msg = m_pending_state->get_next_handshake_msg();
+ auto msg = pending->get_next_handshake_msg();
if(msg.first == HANDSHAKE_NONE) // no full handshake yet
break;
- process_handshake_msg(m_active_state.get(),
- *m_pending_state.get(),
+ process_handshake_msg(active_state(), *pending,
msg.first, msg.second);
}
}
else if(rec_type == HEARTBEAT && peer_supports_heartbeats())
{
- if(!m_active_state)
+ if(!active_state())
throw Unexpected_Message("Heartbeat sent before handshake done");
Heartbeat_Message heartbeat(record);
@@ -341,7 +344,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(heartbeat.is_request())
{
- if(!m_pending_state) // no heartbeats during handshake
+ if(!pending_state())
{
Heartbeat_Message response(Heartbeat_Message::RESPONSE,
&payload[0], payload.size());
@@ -357,7 +360,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
}
else if(rec_type == APPLICATION_DATA)
{
- if(!m_active_state)
+ if(!active_state())
throw Unexpected_Message("Application data before handshake done");
/*
@@ -379,8 +382,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(alert_msg.is_fatal())
{
- if(m_active_state && m_active_state->server_hello())
- m_session_manager.remove_entry(m_active_state->server_hello()->session_id());
+ if(auto active = active_state())
+ m_session_manager.remove_entry(active->server_hello()->session_id());
}
if(alert_msg.type() == Alert::CLOSE_NOTIFY)
@@ -529,8 +532,9 @@ void Channel::send_alert(const Alert& alert)
if(alert.type() == Alert::NO_RENEGOTIATION)
m_pending_state.reset();
- if(alert.is_fatal() && m_active_state && m_active_state->server_hello())
- m_session_manager.remove_entry(m_active_state->server_hello()->session_id());
+ if(alert.is_fatal())
+ if(auto active = active_state())
+ m_session_manager.remove_entry(active->server_hello()->session_id());
if(alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal())
{
@@ -545,9 +549,9 @@ void Channel::secure_renegotiation_check(const Client_Hello* client_hello)
{
const bool secure_renegotiation = client_hello->secure_renegotiation();
- if(m_active_state)
+ if(auto active = active_state())
{
- const bool active_sr = m_active_state->client_hello()->secure_renegotiation();
+ const bool active_sr = active->client_hello()->secure_renegotiation();
if(active_sr != secure_renegotiation)
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
@@ -568,9 +572,9 @@ void Channel::secure_renegotiation_check(const Server_Hello* server_hello)
{
const bool secure_renegotiation = server_hello->secure_renegotiation();
- if(m_active_state)
+ if(auto active = active_state())
{
- const bool active_sr = m_active_state->client_hello()->secure_renegotiation();
+ const bool active_sr = active->client_hello()->secure_renegotiation();
if(active_sr != secure_renegotiation)
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
@@ -589,17 +593,17 @@ void Channel::secure_renegotiation_check(const Server_Hello* server_hello)
std::vector<byte> Channel::secure_renegotiation_data_for_client_hello() const
{
- if(m_active_state)
- return m_active_state->client_finished()->verify_data();
+ if(auto active = active_state())
+ return active->client_finished()->verify_data();
return std::vector<byte>();
}
std::vector<byte> Channel::secure_renegotiation_data_for_server_hello() const
{
- if(m_active_state)
+ if(auto active = active_state())
{
- std::vector<byte> buf = m_active_state->client_finished()->verify_data();
- buf += m_active_state->server_finished()->verify_data();
+ std::vector<byte> buf = active->client_finished()->verify_data();
+ buf += active->server_finished()->verify_data();
return buf;
}
@@ -608,10 +612,13 @@ std::vector<byte> Channel::secure_renegotiation_data_for_server_hello() const
bool Channel::secure_renegotiation_supported() const
{
- if(m_active_state)
- return m_active_state->server_hello()->secure_renegotiation();
- if(m_pending_state && m_pending_state->server_hello())
- return m_pending_state->server_hello()->secure_renegotiation();
+ if(auto active = active_state())
+ return active->server_hello()->secure_renegotiation();
+
+ if(auto pending = pending_state())
+ if(auto hello = pending->server_hello())
+ return hello->secure_renegotiation();
+
return false;
}
@@ -619,32 +626,32 @@ SymmetricKey Channel::key_material_export(const std::string& label,
const std::string& context,
size_t length) const
{
- if(!m_active_state)
- throw std::runtime_error("Channel::key_material_export connection not active");
-
- Handshake_State& state = *m_active_state;
+ if(auto active = active_state())
+ {
+ std::unique_ptr<KDF> prf(active->protocol_specific_prf());
- std::unique_ptr<KDF> prf(state.protocol_specific_prf());
+ const secure_vector<byte>& master_secret =
+ active->session_keys().master_secret();
- const secure_vector<byte>& master_secret =
- state.session_keys().master_secret();
+ std::vector<byte> salt;
+ salt += to_byte_vector(label);
+ salt += active->client_hello()->random();
+ salt += active->server_hello()->random();
- std::vector<byte> salt;
- salt += to_byte_vector(label);
- salt += state.client_hello()->random();
- salt += state.server_hello()->random();
+ if(context != "")
+ {
+ size_t context_size = context.length();
+ if(context_size > 0xFFFF)
+ throw std::runtime_error("key_material_export context is too long");
+ salt.push_back(get_byte<u16bit>(0, context_size));
+ salt.push_back(get_byte<u16bit>(1, context_size));
+ salt += to_byte_vector(context);
+ }
- if(context != "")
- {
- size_t context_size = context.length();
- if(context_size > 0xFFFF)
- throw std::runtime_error("key_material_export context is too long");
- salt.push_back(get_byte<u16bit>(0, context_size));
- salt.push_back(get_byte<u16bit>(1, context_size));
- salt += to_byte_vector(context);
+ return prf->derive_key(length, master_secret, salt);
}
-
- return prf->derive_key(length, master_secret, salt);
+ else
+ throw std::runtime_error("Channel::key_material_export connection not active");
}
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 77e7c81f1..3e400ed5d 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -192,6 +192,10 @@ class BOTAN_DLL Channel
u16bit get_last_valid_epoch() const;
+ const Handshake_State* active_state() const { return m_active_state.get(); }
+
+ const Handshake_State* pending_state() const { return m_pending_state.get(); }
+
/* callbacks */
std::function<bool (const Session&)> m_handshake_fn;
std::function<void (const byte[], size_t, Alert)> m_proc_fn;