aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-11-06 20:58:33 +0000
committerlloyd <[email protected]>2012-11-06 20:58:33 +0000
commitdfbe4b328fa29b80f05bf89dd6c20be304312b17 (patch)
treee450a2f1cdd6073e1517eac4b101ff32bc2fafdc /src/tls
parente4622803c9e91e14942eae91f041e879bf8957fb (diff)
Store cipher states in Channel instead of Handshake_State. Keep all
around by default, expiring them as they are no longer needed. Expiration logic for DTLS needs some work.
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/tls_channel.cpp109
-rw-r--r--src/tls/tls_channel.h19
-rw-r--r--src/tls/tls_handshake_state.cpp20
-rw-r--r--src/tls/tls_handshake_state.h17
-rw-r--r--src/tls/tls_record.cpp15
-rw-r--r--src/tls/tls_record.h9
-rw-r--r--src/tls/tls_seq_numbers.h22
7 files changed, 145 insertions, 66 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index be3f1c784..7065064cc 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -44,6 +44,36 @@ Connection_Sequence_Numbers& Channel::sequence_numbers() const
return *m_sequence_numbers;
}
+std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_epoch(u16bit epoch) const
+ {
+ auto i = m_read_cipher_states.find(epoch);
+
+ BOTAN_ASSERT(i != m_read_cipher_states.end(),
+ "Have a cipher state for the specified epoch");
+
+ return i->second;
+ }
+
+std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_epoch(u16bit epoch) const
+ {
+ auto i = m_write_cipher_states.find(epoch);
+
+ BOTAN_ASSERT(i != m_write_cipher_states.end(),
+ "Have a cipher state for the specified epoch");
+
+ return i->second;
+ }
+
+std::shared_ptr<Connection_Cipher_State> Channel::read_cipher_state_current() const
+ {
+ return read_cipher_state_epoch(sequence_numbers().current_read_epoch());
+ }
+
+std::shared_ptr<Connection_Cipher_State> Channel::write_cipher_state_current() const
+ {
+ return write_cipher_state_epoch(sequence_numbers().current_write_epoch());
+ }
+
std::vector<X509_Certificate> Channel::peer_cert_chain() const
{
if(!m_active_state)
@@ -91,10 +121,7 @@ Handshake_State& Channel::create_handshake_state(Protocol_Version version)
m_pending_state.reset(new_handshake_state(io.release()));
if(m_active_state)
- {
m_pending_state->set_version(m_active_state->version());
- m_pending_state->copy_cipher_states(*m_active_state);
- }
return *m_pending_state.get();
}
@@ -129,8 +156,19 @@ void Channel::change_cipher_spec_reader(Connection_Side side)
sequence_numbers().new_read_cipher_state();
+ const u16bit epoch = sequence_numbers().current_read_epoch();
+
+ BOTAN_ASSERT(m_read_cipher_states.count(epoch) == 0,
+ "No read cipher state currently set for next epoch");
+
// flip side as we are reading
- m_pending_state->new_read_cipher_state((side == CLIENT) ? SERVER : CLIENT);
+ std::shared_ptr<Connection_Cipher_State> read_state(
+ new Connection_Cipher_State(m_pending_state->version(),
+ (side == CLIENT) ? SERVER : CLIENT,
+ m_pending_state->ciphersuite(),
+ m_pending_state->session_keys()));
+
+ m_read_cipher_states[epoch] = read_state;
}
void Channel::change_cipher_spec_writer(Connection_Side side)
@@ -143,7 +181,18 @@ void Channel::change_cipher_spec_writer(Connection_Side side)
sequence_numbers().new_write_cipher_state();
- m_pending_state->new_write_cipher_state(side);
+ const u16bit epoch = sequence_numbers().current_write_epoch();
+
+ BOTAN_ASSERT(m_write_cipher_states.count(epoch) == 0,
+ "No write cipher state currently set for next epoch");
+
+ std::shared_ptr<Connection_Cipher_State> write_state(
+ new Connection_Cipher_State(m_pending_state->version(),
+ side,
+ m_pending_state->ciphersuite(),
+ m_pending_state->session_keys()));
+
+ m_write_cipher_states[epoch] = write_state;
}
bool Channel::is_active() const
@@ -160,6 +209,41 @@ void Channel::activate_session()
{
std::swap(m_active_state, m_pending_state);
m_pending_state.reset();
+
+ const u16bit last_valid_epoch = get_last_valid_epoch();
+
+ const auto obsolete_epoch =
+ [last_valid_epoch](u16bit epoch) { return (epoch < last_valid_epoch); };
+
+ map_remove_if(obsolete_epoch, m_write_cipher_states);
+ map_remove_if(obsolete_epoch, m_read_cipher_states);
+ }
+
+u16bit Channel::get_last_valid_epoch() const
+ {
+ if(m_active_state->version().is_datagram_protocol())
+ {
+ // DTLS: find first epoch less than TCP MSL
+
+ // FIXME: what about lost/retransmitted flights?
+ const std::chrono::seconds tcp_msl(120);
+
+ for(auto i : m_read_cipher_states)
+ {
+ if(i.second->age() <= tcp_msl)
+ return i.first;
+
+ if(i.first == sequence_numbers().current_read_epoch())
+ return i.first;
+ }
+
+ throw std::logic_error("Could not find current DTLS epoch");
+ }
+ else
+ {
+ // TLS is easy case
+ return sequence_numbers().current_write_epoch();
+ }
}
bool Channel::peer_supports_heartbeats() const
@@ -189,11 +273,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
size_t consumed = 0;
- std::shared_ptr<Connection_Cipher_State> cipher_state;
- if(m_pending_state)
- cipher_state = m_pending_state->read_cipher_state();
- else if(m_active_state)
- cipher_state = m_active_state->read_cipher_state();
+ auto cipher_state = read_cipher_state_current();
const size_t needed =
read_record(m_readbuf,
@@ -231,7 +311,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(!m_pending_state)
{
create_handshake_state(record_version);
- sequence_numbers().read_accept(record_sequence);
+ if(record_version.is_datagram_protocol())
+ sequence_numbers().read_accept(record_sequence);
}
m_pending_state->handshake_io().add_input(
@@ -372,12 +453,8 @@ void Channel::send_record_array(byte type, const byte input[], size_t length)
*
* See http://www.openssl.org/~bodo/tls-cbc.txt for background.
*/
- std::shared_ptr<Connection_Cipher_State> cipher_state;
- if(m_pending_state)
- cipher_state = m_pending_state->write_cipher_state();
- else if(m_active_state)
- cipher_state = m_active_state->write_cipher_state();
+ auto cipher_state = write_cipher_state_current();
if(type == APPLICATION_DATA && cipher_state->cbc_without_explicit_iv())
{
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 1a30c5604..77e7c81f1 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -16,6 +16,7 @@
#include <vector>
#include <string>
#include <memory>
+#include <map>
namespace Botan {
@@ -181,6 +182,16 @@ class BOTAN_DLL Channel
Connection_Sequence_Numbers& sequence_numbers() const;
+ std::shared_ptr<Connection_Cipher_State> read_cipher_state_epoch(u16bit epoch) const;
+
+ std::shared_ptr<Connection_Cipher_State> write_cipher_state_epoch(u16bit epoch) const;
+
+ std::shared_ptr<Connection_Cipher_State> read_cipher_state_current() const;
+
+ std::shared_ptr<Connection_Cipher_State> write_cipher_state_current() const;
+
+ u16bit get_last_valid_epoch() const;
+
/* callbacks */
std::function<bool (const Session&)> m_handshake_fn;
std::function<void (const byte[], size_t, Alert)> m_proc_fn;
@@ -197,7 +208,13 @@ class BOTAN_DLL Channel
std::vector<byte> m_writebuf;
std::vector<byte> m_readbuf;
- /* connection parameters */
+ /* cipher states for each epoch - epoch 0 is plaintext, thus null cipher state */
+ std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_write_cipher_states =
+ { { 0, nullptr } };
+ std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_read_cipher_states =
+ { { 0, nullptr } };
+
+ /* pending and active connection states */
std::unique_ptr<Handshake_State> m_active_state;
std::unique_ptr<Handshake_State> m_pending_state;
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 044e97366..8ff0fb585 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -200,26 +200,6 @@ void Handshake_State::compute_session_keys(const secure_vector<byte>& resume_mas
m_session_keys = Session_Keys(this, resume_master_secret, true);
}
-void Handshake_State::copy_cipher_states(const Handshake_State& prev_state)
- {
- m_write_cipher_state = prev_state.m_write_cipher_state;
- m_read_cipher_state = prev_state.m_read_cipher_state;
- }
-
-void Handshake_State::new_read_cipher_state(Connection_Side side)
- {
- m_read_cipher_state.reset(
- new Connection_Cipher_State(version(), side, ciphersuite(), session_keys())
- );
- }
-
-void Handshake_State::new_write_cipher_state(Connection_Side side)
- {
- m_write_cipher_state.reset(
- new Connection_Cipher_State(version(), side, ciphersuite(), session_keys())
- );
- }
-
void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg)
{
const u32bit mask = bitmask_for_handshake_type(handshake_msg);
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index fee9bd5d1..9afcd0374 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -41,8 +41,6 @@ class Next_Protocol;
class New_Session_Ticket;
class Finished;
-class Connection_Cipher_State;
-
/**
* SSL/TLS Handshake State
*/
@@ -120,10 +118,6 @@ class Handshake_State
void server_finished(Finished* server_finished);
void client_finished(Finished* client_finished);
- void new_read_cipher_state(Connection_Side side);
-
- void new_write_cipher_state(Connection_Side side);
-
const Client_Hello* client_hello() const
{ return m_client_hello.get(); }
@@ -181,23 +175,12 @@ class Handshake_State
m_msg_callback(msg);
}
- std::shared_ptr<Connection_Cipher_State> read_cipher_state()
- { return m_read_cipher_state; }
-
- std::shared_ptr<Connection_Cipher_State> write_cipher_state()
- { return m_write_cipher_state; }
-
- void copy_cipher_states(const Handshake_State& prev_state);
-
private:
std::function<void (const Handshake_Message&)> m_msg_callback;
std::unique_ptr<Handshake_IO> m_handshake_io;
- std::shared_ptr<Connection_Cipher_State> m_write_cipher_state;
- std::shared_ptr<Connection_Cipher_State> m_read_cipher_state;
-
u32bit m_hand_expecting_mask = 0;
u32bit m_hand_received_mask = 0;
Protocol_Version m_version;
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index d0bc8bc69..b2addf116 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -24,6 +24,7 @@ Connection_Cipher_State::Connection_Cipher_State(Protocol_Version version,
Connection_Side side,
const Ciphersuite& suite,
const Session_Keys& keys) :
+ m_start_time(std::chrono::system_clock::now()),
m_is_ssl3(version == Protocol_Version::SSL_V3)
{
SymmetricKey mac_key, cipher_key;
@@ -341,13 +342,6 @@ size_t read_record(std::vector<byte>& readbuf,
const size_t header_size =
(record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
- if(record_version.is_datagram_protocol())
- record_sequence = load_be<u64bit>(&readbuf[3], 0);
- else if(sequence_numbers)
- record_sequence = sequence_numbers->next_read_sequence();
- else
- record_sequence = 0; // server initial handshake case
-
const size_t record_len = make_u16bit(readbuf[header_size-2],
readbuf[header_size-1]);
@@ -364,6 +358,13 @@ size_t read_record(std::vector<byte>& readbuf,
readbuf.size(),
"Have the full record");
+ if(record_version.is_datagram_protocol())
+ record_sequence = load_be<u64bit>(&readbuf[3], 0);
+ else if(sequence_numbers)
+ record_sequence = sequence_numbers->next_read_sequence();
+ else
+ record_sequence = 0; // server initial handshake case
+
if(sequence_numbers && sequence_numbers->already_seen(record_sequence))
return 0;
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 5c5d64d0d..dbe77cbd2 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -15,6 +15,7 @@
#include <botan/mac.h>
#include <vector>
#include <memory>
+#include <chrono>
namespace Botan {
@@ -59,7 +60,15 @@ class Connection_Cipher_State
bool cbc_without_explicit_iv() const
{ return (m_block_size > 0) && (m_iv_size == 0); }
+
+ std::chrono::seconds age() const
+ {
+ return std::chrono::duration_cast<std::chrono::seconds>(
+ std::chrono::system_clock::now() - m_start_time);
+ }
+
private:
+ std::chrono::system_clock::time_point m_start_time;
std::unique_ptr<BlockCipher> m_block_cipher;
secure_vector<byte> m_block_cipher_cbc_state;
std::unique_ptr<StreamCipher> m_stream_cipher;
diff --git a/src/tls/tls_seq_numbers.h b/src/tls/tls_seq_numbers.h
index c9a334e4b..4a8a0fab8 100644
--- a/src/tls/tls_seq_numbers.h
+++ b/src/tls/tls_seq_numbers.h
@@ -20,9 +20,12 @@ class Connection_Sequence_Numbers
virtual void new_read_cipher_state() = 0;
virtual void new_write_cipher_state() = 0;
- virtual u64bit next_write_sequence() = 0;
+ virtual u16bit current_read_epoch() const = 0;
+ virtual u16bit current_write_epoch() const = 0;
+ virtual u64bit next_write_sequence() = 0;
virtual u64bit next_read_sequence() = 0;
+
virtual bool already_seen(u64bit seq) const = 0;
virtual void read_accept(u64bit seq) = 0;
};
@@ -30,23 +33,28 @@ class Connection_Sequence_Numbers
class Stream_Sequence_Numbers : public Connection_Sequence_Numbers
{
public:
- void new_read_cipher_state() override { m_read_seq_no = 0; }
- void new_write_cipher_state() override { m_write_seq_no = 0; }
+ void new_read_cipher_state() override { m_read_seq_no = 0; m_read_epoch += 1; }
+ void new_write_cipher_state() override { m_write_seq_no = 0; m_write_epoch += 1; }
- u64bit next_write_sequence() override { return m_write_seq_no++; }
+ u16bit current_read_epoch() const override { return m_read_epoch; }
+ u16bit current_write_epoch() const override { return m_write_epoch; }
+ u64bit next_write_sequence() override { return m_write_seq_no++; }
u64bit next_read_sequence() override { return m_read_seq_no; }
+
bool already_seen(u64bit) const override { return false; }
void read_accept(u64bit) override { m_read_seq_no++; }
private:
u64bit m_write_seq_no = 0;
u64bit m_read_seq_no = 0;
+ u16bit m_read_epoch = 0;
+ u16bit m_write_epoch = 0;
};
class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers
{
public:
- void new_read_cipher_state() override {}
+ void new_read_cipher_state() override { m_read_epoch += 1; }
void new_write_cipher_state() override
{
@@ -54,6 +62,9 @@ class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers
m_write_seq_no = ((m_write_seq_no >> 48) + 1) << 48;
}
+ u16bit current_read_epoch() const override { return m_read_epoch; }
+ u16bit current_write_epoch() const override { return (m_write_seq_no >> 48); }
+
u64bit next_write_sequence() override { return m_write_seq_no++; }
u64bit next_read_sequence() override
@@ -101,6 +112,7 @@ class Datagram_Sequence_Numbers : public Connection_Sequence_Numbers
private:
u64bit m_write_seq_no = 0;
+ u16bit m_read_epoch = 0;
u64bit m_window_highest = 0;
u64bit m_window_bits = 0;
};