aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-13 23:23:42 +0000
committerlloyd <[email protected]>2012-09-13 23:23:42 +0000
commitbddb0f2075bd12ad88a541b1d9ba04d4a80e0767 (patch)
treeb8c8e715ff62738a463dec762535e0beed27e922 /src/tls
parent4393f9d9db263510e59424a41b14f7cde7206825 (diff)
Store the cipher states in the handshake state object as shared_ptrs.
One notable change here is that after we send a close_alert, we ignore any data that follows. That is somewhat unfortunate actually, but overall this change is important (for DTLS).
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/tls_channel.cpp76
-rw-r--r--src/tls/tls_channel.h19
-rw-r--r--src/tls/tls_handshake_state.cpp23
-rw-r--r--src/tls/tls_handshake_state.h19
4 files changed, 94 insertions, 43 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 6448ca2d4..95b3d1bbb 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -44,6 +44,24 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const
return *m_sequence_numbers;
}
+std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state() const
+ {
+ if(m_pending_state)
+ return m_pending_state->read_cipher_state();
+ if(m_active_state)
+ return m_active_state->read_cipher_state();
+ return std::shared_ptr<Connection_Cipher_State>(nullptr);
+ }
+
+std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state() const
+ {
+ if(m_pending_state)
+ return m_pending_state->write_cipher_state();
+ if(m_active_state)
+ return m_active_state->write_cipher_state();
+ return std::shared_ptr<Connection_Cipher_State>(nullptr);
+ }
+
std::vector<X509_Certificate> Channel::peer_cert_chain() const
{
if(!m_active_state)
@@ -91,7 +109,10 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
m_pending_state.reset(new_handshake_state(io.release()));
if(m_active_state)
+ {
m_pending_state->set_version(m_active_state->version());
+ m_pending_state->copy_cipher_states(*m_active_state);
+ }
return *m_pending_state.get();
}
@@ -127,14 +148,7 @@ void Channel::change_cipher_spec_reader(Connection_Side side)
sequence_numbers().new_read_cipher_state();
// flip side as we are reading
- side = (side == CLIENT) ? SERVER : CLIENT;
-
- m_read_cipherstate.reset(
- new Connection_Cipher_State(m_pending_state->version(),
- side,
- m_pending_state->ciphersuite(),
- m_pending_state->session_keys())
- );
+ m_pending_state->new_read_cipher_state((side == CLIENT) ? SERVER : CLIENT);
}
void Channel::change_cipher_spec_writer(Connection_Side side)
@@ -147,12 +161,7 @@ void Channel::change_cipher_spec_writer(Connection_Side side)
sequence_numbers().new_write_cipher_state();
- m_write_cipherstate.reset(
- new Connection_Cipher_State(m_pending_state->version(),
- side,
- m_pending_state->ciphersuite(),
- m_pending_state->session_keys())
- );
+ m_pending_state->new_write_cipher_state(side);
}
void Channel::activate_session()
@@ -179,7 +188,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
{
try
{
- while(buf_size)
+ while(!is_closed() && buf_size)
{
byte rec_type = NO_RECORD;
std::vector<byte> record;
@@ -188,6 +197,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
size_t consumed = 0;
+ std::shared_ptr<Connection_Cipher_State> cipher_state = read_cipher_state();
+
const size_t needed =
read_record(m_readbuf,
buf,
@@ -198,7 +209,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
record_version,
record_sequence,
m_sequence_numbers.get(),
- m_read_cipherstate.get());
+ cipher_state.get());
BOTAN_ASSERT(consumed <= buf_size,
"Record reader consumed sane amount");
@@ -289,27 +300,22 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
m_proc_fn(nullptr, 0, alert_msg);
- if(alert_msg.type() == Alert::CLOSE_NOTIFY)
- {
- if(!m_connection_closed)
- send_alert(Alert(Alert::CLOSE_NOTIFY)); // reply in kind
- m_read_cipherstate.reset();
- }
- else if(alert_msg.is_fatal())
+ if(alert_msg.is_fatal())
{
- // delete state immediately
-
if(m_active_state && m_active_state->server_hello())
m_session_manager.remove_entry(m_active_state->server_hello()->session_id());
+ }
+ if(alert_msg.type() == Alert::CLOSE_NOTIFY)
+ send_alert(Alert(Alert::CLOSE_NOTIFY)); // reply in kind
+
+ if(alert_msg.type() == Alert::CLOSE_NOTIFY || alert_msg.is_fatal())
+ {
m_connection_closed = true;
m_active_state.reset();
m_pending_state.reset();
- m_write_cipherstate.reset();
- m_read_cipherstate.reset();
-
return 0;
}
}
@@ -370,9 +376,11 @@ void Channel::send_record_array(byte type, const byte input[], size_t length)
*
* See http://www.openssl.org/~bodo/tls-cbc.txt for background.
*/
- if(type == APPLICATION_DATA && m_write_cipherstate->cbc_without_explicit_iv())
+ std::shared_ptr<Connection_Cipher_State> cipher_state = write_cipher_state();
+
+ if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv())
{
- write_record(type, &input[0], 1);
+ write_record(cipher_state.get(), type, &input[0], 1);
input += 1;
length -= 1;
}
@@ -380,7 +388,7 @@ void Channel::send_record_array(byte type, const byte input[], size_t length)
while(length)
{
const size_t sending = std::min(length, m_max_fragment);
- write_record(type, &input[0], sending);
+ write_record(cipher_state.get(), type, &input[0], sending);
input += sending;
length -= sending;
@@ -392,7 +400,8 @@ void Channel::send_record(byte record_type, const std::vector<byte>& record)
send_record_array(record_type, &record[0], record.size());
}
-void Channel::write_record(byte record_type, const byte input[], size_t length)
+void Channel::write_record(Connection_Cipher_State* cipher_state,
+ byte record_type, const byte input[], size_t length)
{
if(length > m_max_fragment)
throw Internal_Error("Record is larger than allowed fragment size");
@@ -409,7 +418,7 @@ void Channel::write_record(byte record_type, const byte input[], size_t length)
length,
record_version,
sequence_numbers(),
- m_write_cipherstate.get(),
+ cipher_state,
m_rng);
m_output_fn(&m_writebuf[0], m_writebuf.size());
@@ -449,7 +458,6 @@ void Channel::send_alert(const Alert& alert)
{
m_active_state.reset();
m_pending_state.reset();
- m_write_cipherstate.reset();
m_connection_closed = true;
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index fa1fd3756..95420dcb3 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -21,6 +21,8 @@ namespace Botan {
namespace TLS {
+class Connection_Cipher_State;
+class Connection_Sequence_Numbers;
class Handshake_State;
/**
@@ -54,7 +56,7 @@ class BOTAN_DLL Channel
/**
* @return true iff the connection is active for sending application data
*/
- bool is_active() const { return m_active_state && !is_closed(); }
+ bool is_active() const { return m_active_state.get(); }
/**
* @return true iff the connection has been definitely closed
@@ -160,13 +162,18 @@ class BOTAN_DLL Channel
void send_record_array(byte type, const byte input[], size_t length);
- void write_record(byte type, const byte input[], size_t length);
+ void write_record(Connection_Cipher_State* cipher_state,
+ byte type, const byte input[], size_t length);
bool peer_supports_heartbeats() const;
bool heartbeat_sending_allowed() const;
- class Connection_Sequence_Numbers& sequence_numbers() const;
+ Connection_Sequence_Numbers& sequence_numbers() const;
+
+ std::shared_ptr<Connection_Cipher_State> read_cipher_state() const;
+
+ std::shared_ptr<Connection_Cipher_State> write_cipher_state() const;
/* callbacks */
std::function<bool (const Session&)> m_handshake_fn;
@@ -177,10 +184,8 @@ class BOTAN_DLL Channel
RandomNumberGenerator& m_rng;
Session_Manager& m_session_manager;
- /* cipher/sequence state */
- std::unique_ptr<class Connection_Sequence_Numbers> m_sequence_numbers;
- std::unique_ptr<class Connection_Cipher_State> m_write_cipherstate;
- std::unique_ptr<class Connection_Cipher_State> m_read_cipherstate;
+ /* sequence number state */
+ std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers;
/* I/O buffers */
std::vector<byte> m_writebuf;
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 30474aae0..044e97366 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -7,6 +7,7 @@
#include <botan/internal/tls_handshake_state.h>
#include <botan/internal/tls_messages.h>
+#include <botan/internal/tls_record.h>
#include <botan/internal/assert.h>
#include <botan/lookup.h>
@@ -87,8 +88,8 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
*/
Handshake_State::Handshake_State(Handshake_IO* io,
std::function<void (const Handshake_Message&)> msg_callback) :
- m_handshake_io(io),
m_msg_callback(msg_callback),
+ m_handshake_io(io),
m_version(m_handshake_io->initial_record_version())
{
}
@@ -199,6 +200,26 @@ void Handshake_State::compute_session_keys(const secure_vector<byte>& resume_mas
m_session_keys = Session_Keys(this, resume_master_secret, true);
}
+void Handshake_State::copy_cipher_states(const Handshake_State& prev_state)
+ {
+ m_write_cipher_state = prev_state.m_write_cipher_state;
+ m_read_cipher_state = prev_state.m_read_cipher_state;
+ }
+
+void Handshake_State::new_read_cipher_state(Connection_Side side)
+ {
+ m_read_cipher_state.reset(
+ new Connection_Cipher_State(version(), side, ciphersuite(), session_keys())
+ );
+ }
+
+void Handshake_State::new_write_cipher_state(Connection_Side side)
+ {
+ m_write_cipher_state.reset(
+ new Connection_Cipher_State(version(), side, ciphersuite(), session_keys())
+ );
+ }
+
void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg)
{
const u32bit mask = bitmask_for_handshake_type(handshake_msg);
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 5145958ef..fee9bd5d1 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -41,6 +41,8 @@ class Next_Protocol;
class New_Session_Ticket;
class Finished;
+class Connection_Cipher_State;
+
/**
* SSL/TLS Handshake State
*/
@@ -118,6 +120,10 @@ class Handshake_State
void server_finished(Finished* server_finished);
void client_finished(Finished* client_finished);
+ void new_read_cipher_state(Connection_Side side);
+
+ void new_write_cipher_state(Connection_Side side);
+
const Client_Hello* client_hello() const
{ return m_client_hello.get(); }
@@ -175,11 +181,22 @@ class Handshake_State
m_msg_callback(msg);
}
+ std::shared_ptr<Connection_Cipher_State> read_cipher_state()
+ { return m_read_cipher_state; }
+
+ std::shared_ptr<Connection_Cipher_State> write_cipher_state()
+ { return m_write_cipher_state; }
+
+ void copy_cipher_states(const Handshake_State& prev_state);
+
private:
+ std::function<void (const Handshake_Message&)> m_msg_callback;
+
std::unique_ptr<Handshake_IO> m_handshake_io;
- std::function<void (const Handshake_Message&)> m_msg_callback;
+ std::shared_ptr<Connection_Cipher_State> m_write_cipher_state;
+ std::shared_ptr<Connection_Cipher_State> m_read_cipher_state;
u32bit m_hand_expecting_mask = 0;
u32bit m_hand_received_mask = 0;