aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-07 16:49:44 +0000
committerlloyd <[email protected]>2012-09-07 16:49:44 +0000
commit6925789077edbeab13540cde3cf84b1d0e6feefc (patch)
treecebb71b3277288020b62f254c27c550c0448ce69 /src
parent9432fe7c5484c2f3515d40fedf117ddc860f6e14 (diff)
Keep two handshake states around, swap them when
Channel::activate_session is called.
Diffstat (limited to 'src')
-rw-r--r--src/tls/tls_channel.cpp117
-rw-r--r--src/tls/tls_channel.h16
-rw-r--r--src/tls/tls_client.cpp2
-rw-r--r--src/tls/tls_server.cpp2
4 files changed, 69 insertions, 68 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 4a37f1c7a..62fc47c45 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -40,28 +40,28 @@ Channel::~Channel()
Handshake_State& Channel::create_handshake_state()
{
- if(m_state)
+ if(m_pending_state)
throw Internal_Error("create_handshake_state called during handshake");
- m_state.reset(new_handshake_state());
+ m_pending_state.reset(new_handshake_state());
- return *m_state.get();
+ return *m_pending_state.get();
}
void Channel::renegotiate(bool force_full_renegotiation)
{
- if(m_state) // currently in handshake?
+ if(m_pending_state) // currently in handshake?
return;
- m_state.reset(new_handshake_state());
+ m_pending_state.reset(new_handshake_state());
- initiate_handshake(*m_state.get(), force_full_renegotiation);
+ initiate_handshake(*m_pending_state.get(), force_full_renegotiation);
}
void Channel::set_protocol_version(Protocol_Version version)
{
m_current_version = version;
- m_state->set_version(version);
+ m_pending_state->set_version(version);
}
void Channel::set_maximum_fragment_size(size_t max_fragment)
@@ -74,7 +74,10 @@ void Channel::set_maximum_fragment_size(size_t max_fragment)
void Channel::change_cipher_spec_reader(Connection_Side side)
{
- if(m_state->server_hello()->compression_method()!= NO_COMPRESSION)
+ BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(),
+ "Have received server hello");
+
+ if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
m_read_seq_no = 0;
@@ -85,14 +88,17 @@ void Channel::change_cipher_spec_reader(Connection_Side side)
m_read_cipherstate.reset(
new Connection_Cipher_State(current_protocol_version(),
side,
- m_state->ciphersuite(),
- m_state->session_keys())
+ m_pending_state->ciphersuite(),
+ m_pending_state->session_keys())
);
}
void Channel::change_cipher_spec_writer(Connection_Side side)
{
- if(m_state->server_hello()->compression_method()!= NO_COMPRESSION)
+ BOTAN_ASSERT(m_pending_state && m_pending_state->server_hello(),
+ "Have received server hello");
+
+ if(m_pending_state->server_hello()->compression_method()!= NO_COMPRESSION)
throw Internal_Error("Negotiated unknown compression algorithm");
/*
@@ -111,25 +117,32 @@ void Channel::change_cipher_spec_writer(Connection_Side side)
m_write_cipherstate.reset(
new Connection_Cipher_State(current_protocol_version(),
side,
- m_state->ciphersuite(),
- m_state->session_keys())
+ m_pending_state->ciphersuite(),
+ m_pending_state->session_keys())
);
}
-void Channel::activate_session(const std::vector<byte>& session_id)
+void Channel::activate_session()
{
- m_secure_renegotiation.update(m_state->client_finished(),
- m_state->server_finished());
+ m_secure_renegotiation.update(m_pending_state->client_finished(),
+ m_pending_state->server_finished());
- m_state.reset();
- m_handshake_completed = true;
- m_active_session = session_id;
+ std::swap(m_active_state, m_pending_state);
+ m_pending_state.reset();
}
-void Channel::heartbeat_support(bool peer_supports, bool sending_allowed)
+bool Channel::peer_supports_heartbeats() const
{
- m_peer_supports_heartbeats = peer_supports;
- m_heartbeat_sending_allowed = sending_allowed;
+ if(m_active_state && m_active_state->server_hello())
+ return m_active_state->server_hello()->supports_heartbeats();
+ return false;
+ }
+
+bool Channel::heartbeat_sending_allowed() const
+ {
+ if(m_active_state && m_active_state->server_hello())
+ return m_active_state->server_hello()->peer_can_send_heartbeats();
+ return false;
}
size_t Channel::received_data(const byte buf[], size_t buf_size)
@@ -177,25 +190,23 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
{
- if(!m_state)
- m_state.reset(new_handshake_state());
+ if(!m_pending_state)
+ m_pending_state.reset(new_handshake_state());
- m_state->handshake_io().add_input(rec_type,
- &record[0],
- record.size(),
- record_number);
+ m_pending_state->handshake_io().add_input(
+ rec_type, &record[0], record.size(), record_number);
- while(m_state)
+ while(m_pending_state)
{
- auto msg = m_state->get_next_handshake_msg();
+ auto msg = m_pending_state->get_next_handshake_msg();
if(msg.first == HANDSHAKE_NONE) // no full handshake yet
break;
- process_handshake_msg(*m_state.get(), msg.first, msg.second);
+ process_handshake_msg(*m_pending_state.get(), msg.first, msg.second);
}
}
- else if(rec_type == HEARTBEAT && m_peer_supports_heartbeats)
+ else if(rec_type == HEARTBEAT && peer_supports_heartbeats())
{
Heartbeat_Message heartbeat(record);
@@ -203,7 +214,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(heartbeat.is_request())
{
- if(!m_state) // no heartbeats during handshake
+ if(!m_pending_state) // no heartbeats during handshake
{
Heartbeat_Message response(Heartbeat_Message::RESPONSE,
&payload[0], payload.size());
@@ -219,7 +230,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
}
else if(rec_type == APPLICATION_DATA)
{
- if(m_handshake_completed)
+ if(m_active_state)
{
/*
* OpenSSL among others sends empty records in versions
@@ -239,10 +250,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
Alert alert_msg(record);
if(alert_msg.type() == Alert::NO_RENEGOTIATION)
- {
- if(m_handshake_completed && m_state)
- m_state.reset();
- }
+ m_pending_state.reset();
m_proc_fn(nullptr, 0, alert_msg);
@@ -257,15 +265,13 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
{
// delete state immediately
- if(!m_active_session.empty())
- {
- m_session_manager.remove_entry(m_active_session);
- m_active_session.clear();
- }
+ if(m_active_state && m_active_state->server_hello())
+ m_session_manager.remove_entry(m_active_state->server_hello()->session_id());
m_connection_closed = true;
- m_state.reset();
+ m_active_state.reset();
+ m_pending_state.reset();
m_write_cipherstate.reset();
m_read_cipherstate.reset();
@@ -302,10 +308,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
void Channel::heartbeat(const byte payload[], size_t payload_size)
{
- if(!is_active())
- throw std::runtime_error("Heartbeat cannot be sent on inactive TLS connection");
-
- if(m_heartbeat_sending_allowed)
+ if(heartbeat_sending_allowed())
{
Heartbeat_Message heartbeat(Heartbeat_Message::REQUEST,
payload, payload_size);
@@ -362,10 +365,10 @@ void Channel::write_record(byte record_type, const byte input[], size_t length)
Protocol_Version record_version = current_protocol_version();
if(!record_version.valid())
{
- BOTAN_ASSERT(m_state && !m_state->server_hello(),
+ BOTAN_ASSERT(m_pending_state && !m_pending_state->server_hello(),
"In first record of client connection");
- record_version = m_state->handshake_io().initial_record_version();
+ record_version = m_pending_state->handshake_io().initial_record_version();
}
TLS::write_record(m_writebuf,
@@ -406,19 +409,17 @@ void Channel::send_alert(const Alert& alert)
}
if(alert.type() == Alert::NO_RENEGOTIATION)
- m_state.reset();
+ m_pending_state.reset();
- if(alert.is_fatal() && !m_active_session.empty())
- {
- m_session_manager.remove_entry(m_active_session);
- m_active_session.clear();
- }
+ if(alert.is_fatal() && m_active_state && m_active_state->server_hello())
+ m_session_manager.remove_entry(m_active_state->server_hello()->session_id());
- if(!m_connection_closed && (alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal()))
+ if(alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal())
{
m_connection_closed = true;
- m_state.reset();
+ m_active_state.reset();
+ m_pending_state.reset();
m_write_cipherstate.reset();
}
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 29567da4c..109fbc771 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -52,7 +52,7 @@ class BOTAN_DLL Channel
/**
* @return true iff the connection is active for sending application data
*/
- bool is_active() const { return m_handshake_completed && !is_closed(); }
+ bool is_active() const { return m_active_state && !is_closed(); }
/**
* @return true iff the connection has been definitely closed
@@ -114,7 +114,7 @@ class BOTAN_DLL Channel
*/
void send_alert(const Alert& alert);
- void activate_session(const std::vector<byte>& session_id);
+ void activate_session();
void heartbeat_support(bool peer_supports, bool allowed_to_send);
@@ -174,6 +174,10 @@ class BOTAN_DLL Channel
void write_record(byte type, const byte input[], size_t length);
+ bool peer_supports_heartbeats() const;
+
+ bool heartbeat_sending_allowed() const;
+
/* callbacks */
std::function<void (const byte[], size_t, Alert)> m_proc_fn;
std::function<void (const byte[], size_t)> m_output_fn;
@@ -190,17 +194,13 @@ class BOTAN_DLL Channel
u64bit m_read_seq_no = 0;
/* connection parameters */
- std::unique_ptr<class Handshake_State> m_state;
+ std::unique_ptr<class Handshake_State> m_active_state;
+ std::unique_ptr<class Handshake_State> m_pending_state;
Protocol_Version m_current_version;
size_t m_max_fragment = MAX_PLAINTEXT_SIZE;
- bool m_peer_supports_heartbeats = false;
- bool m_heartbeat_sending_allowed = false;
-
bool m_connection_closed = false;
- bool m_handshake_completed = false;
- std::vector<byte> m_active_session;
};
}
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index c88d0319a..0fdc24f59 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -516,7 +516,7 @@ void Client::process_handshake_msg(Handshake_State& state,
m_session_manager.remove_entry(session_info.session_id());
}
- activate_session(session_info.session_id());
+ activate_session();
}
else
throw Unexpected_Message("Unknown handshake message received");
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 5cd00c524..313640008 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -703,7 +703,7 @@ void Server::process_handshake_msg(Handshake_State& state,
);
}
- activate_session(state.server_hello()->session_id());
+ activate_session();
}
else
throw Unexpected_Message("Unknown handshake message received");