diff options
author | lloyd <[email protected]> | 2012-08-08 18:22:12 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-08-08 18:22:12 +0000 |
commit | 67dc8001da77de044c21a262087e666fe205c10f (patch) | |
tree | d2e50e65b03f1792209f88f32f49420985b01b64 /src/tls | |
parent | 1f8370e2a54a68a1fb18cb48babf721086e45dc3 (diff) |
DTLS needs some help with ChangeCipherSpec because it is not included
in the message_seq count. When we are asking for the next handshake
msg, tell the handshake IO layer if we are expecting a CCS or not.
Then DTLS just needs to track which epoch(s) it has seen the CCS for,
and which epoch it is currently in. This is all ignored by the stream
IO layer.
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_channel.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 18 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 24 |
5 files changed, 51 insertions, 11 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index ed1c5fc75..3e5bdbabd 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -79,7 +79,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) while(m_state) { - auto msg = m_state->handshake_io().get_next_record(); + auto msg = m_state->get_next_handshake_msg(); if(msg.first == HANDSHAKE_NONE) // no full handshake yet break; diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index f46fdb094..5901eae04 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -61,7 +61,7 @@ void Stream_Handshake_IO::add_input(const byte rec_type, } std::pair<Handshake_Type, std::vector<byte>> -Stream_Handshake_IO::get_next_record() +Stream_Handshake_IO::get_next_record(bool) { if(m_queue.size() >= 4) { @@ -119,9 +119,11 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, size_t record_size, u64bit record_number) { + const u16bit epoch = static_cast<u16bit>(record_number >> 48); + if(rec_type == CHANGE_CIPHER_SPEC) { - m_ccs_epochs.insert(static_cast<u16bit>(record_number >> 48)); + m_ccs_epochs.insert(epoch); return; } @@ -147,8 +149,18 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, } std::pair<Handshake_Type, std::vector<byte>> -Datagram_Handshake_IO::get_next_record() +Datagram_Handshake_IO::get_next_record(bool expecting_ccs) { + if(expecting_ccs) + { + const u16bit current_epoch = 0; // fixme + + if(m_ccs_epochs.count(current_epoch)) + return std::make_pair(HANDSHAKE_CCS, std::vector<byte>()); + else + return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); + } + auto i = m_messages.find(m_in_message_seq); if(i == m_messages.end() || !i->second.complete()) diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index 7ed459049..b83425281 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -46,7 +46,8 @@ class Handshake_IO /** * Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available */ - virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0; + virtual std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) = 0; Handshake_IO() {} @@ -78,7 +79,8 @@ class Stream_Handshake_IO : public Handshake_IO size_t record_size, u64bit record_number) override; - std::pair<Handshake_Type, std::vector<byte> > get_next_record() override; + std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) override; private: std::deque<byte> m_queue; Record_Writer& m_writer; @@ -105,7 +107,8 @@ class Datagram_Handshake_IO : public Handshake_IO size_t record_size, u64bit record_number) override; - std::pair<Handshake_Type, std::vector<byte>> get_next_record() override; + std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) override; private: class Handshake_Reassembly { diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 082461dc9..24e00772d 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -222,6 +222,15 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const return (m_hand_received_mask & mask); } +std::pair<Handshake_Type, std::vector<byte>> +Handshake_State::get_next_handshake_msg() + { + const bool expecting_ccs = + (bitmask_for_handshake_type(HANDSHAKE_CCS) & m_hand_expecting_mask); + + return m_handshake_io->get_next_record(expecting_ccs); + } + std::string Handshake_State::srp_identifier() const { if(ciphersuite().valid() && ciphersuite().kex_algo() == "SRP_SHA") diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 81a603c6f..486e6bfb8 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -56,10 +56,26 @@ class Handshake_State Handshake_IO& handshake_io() { return *m_handshake_io; } - bool received_handshake_msg(Handshake_Type handshake_msg) const; - - void confirm_transition_to(Handshake_Type handshake_msg); - void set_expected_next(Handshake_Type handshake_msg); + /** + * Return true iff we have received a particular message already + * @param msg_type the message type + */ + bool received_handshake_msg(Handshake_Type msg_type) const; + + /** + * Confirm that we were expecting this message type + * @param msg_type the message type + */ + void confirm_transition_to(Handshake_Type msg_type); + + /** + * Record that we are expecting a particular message type next + * @param msg_type the message type + */ + void set_expected_next(Handshake_Type msg_type); + + std::pair<Handshake_Type, std::vector<byte>> + get_next_handshake_msg(); const std::vector<byte>& session_ticket() const; |