aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-08-08 18:22:12 +0000
committerlloyd <[email protected]>2012-08-08 18:22:12 +0000
commit67dc8001da77de044c21a262087e666fe205c10f (patch)
treed2e50e65b03f1792209f88f32f49420985b01b64
parent1f8370e2a54a68a1fb18cb48babf721086e45dc3 (diff)
DTLS needs some help with ChangeCipherSpec because it is not included
in the message_seq count. When we are asking for the next handshake msg, tell the handshake IO layer if we are expecting a CCS or not. Then DTLS just needs to track which epoch(s) it has seen the CCS for, and which epoch it is currently in. This is all ignored by the stream IO layer.
-rw-r--r--src/tls/tls_channel.cpp2
-rw-r--r--src/tls/tls_handshake_io.cpp18
-rw-r--r--src/tls/tls_handshake_io.h9
-rw-r--r--src/tls/tls_handshake_state.cpp9
-rw-r--r--src/tls/tls_handshake_state.h24
5 files changed, 51 insertions, 11 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index ed1c5fc75..3e5bdbabd 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -79,7 +79,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
while(m_state)
{
- auto msg = m_state->handshake_io().get_next_record();
+ auto msg = m_state->get_next_handshake_msg();
if(msg.first == HANDSHAKE_NONE) // no full handshake yet
break;
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index f46fdb094..5901eae04 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -61,7 +61,7 @@ void Stream_Handshake_IO::add_input(const byte rec_type,
}
std::pair<Handshake_Type, std::vector<byte>>
-Stream_Handshake_IO::get_next_record()
+Stream_Handshake_IO::get_next_record(bool)
{
if(m_queue.size() >= 4)
{
@@ -119,9 +119,11 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
size_t record_size,
u64bit record_number)
{
+ const u16bit epoch = static_cast<u16bit>(record_number >> 48);
+
if(rec_type == CHANGE_CIPHER_SPEC)
{
- m_ccs_epochs.insert(static_cast<u16bit>(record_number >> 48));
+ m_ccs_epochs.insert(epoch);
return;
}
@@ -147,8 +149,18 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
}
std::pair<Handshake_Type, std::vector<byte>>
-Datagram_Handshake_IO::get_next_record()
+Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
{
+ if(expecting_ccs)
+ {
+ const u16bit current_epoch = 0; // fixme
+
+ if(m_ccs_epochs.count(current_epoch))
+ return std::make_pair(HANDSHAKE_CCS, std::vector<byte>());
+ else
+ return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
+ }
+
auto i = m_messages.find(m_in_message_seq);
if(i == m_messages.end() || !i->second.complete())
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index 7ed459049..b83425281 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -46,7 +46,8 @@ class Handshake_IO
/**
* Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available
*/
- virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0;
+ virtual std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) = 0;
Handshake_IO() {}
@@ -78,7 +79,8 @@ class Stream_Handshake_IO : public Handshake_IO
size_t record_size,
u64bit record_number) override;
- std::pair<Handshake_Type, std::vector<byte> > get_next_record() override;
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
Record_Writer& m_writer;
@@ -105,7 +107,8 @@ class Datagram_Handshake_IO : public Handshake_IO
size_t record_size,
u64bit record_number) override;
- std::pair<Handshake_Type, std::vector<byte>> get_next_record() override;
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) override;
private:
class Handshake_Reassembly
{
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 082461dc9..24e00772d 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -222,6 +222,15 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const
return (m_hand_received_mask & mask);
}
+std::pair<Handshake_Type, std::vector<byte>>
+Handshake_State::get_next_handshake_msg()
+ {
+ const bool expecting_ccs =
+ (bitmask_for_handshake_type(HANDSHAKE_CCS) & m_hand_expecting_mask);
+
+ return m_handshake_io->get_next_record(expecting_ccs);
+ }
+
std::string Handshake_State::srp_identifier() const
{
if(ciphersuite().valid() && ciphersuite().kex_algo() == "SRP_SHA")
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 81a603c6f..486e6bfb8 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -56,10 +56,26 @@ class Handshake_State
Handshake_IO& handshake_io() { return *m_handshake_io; }
- bool received_handshake_msg(Handshake_Type handshake_msg) const;
-
- void confirm_transition_to(Handshake_Type handshake_msg);
- void set_expected_next(Handshake_Type handshake_msg);
+ /**
+ * Return true iff we have received a particular message already
+ * @param msg_type the message type
+ */
+ bool received_handshake_msg(Handshake_Type msg_type) const;
+
+ /**
+ * Confirm that we were expecting this message type
+ * @param msg_type the message type
+ */
+ void confirm_transition_to(Handshake_Type msg_type);
+
+ /**
+ * Record that we are expecting a particular message type next
+ * @param msg_type the message type
+ */
+ void set_expected_next(Handshake_Type msg_type);
+
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_handshake_msg();
const std::vector<byte>& session_ticket() const;