diff options
-rw-r--r-- | src/tls/tls_channel.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 18 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.cpp | 9 | ||||
-rw-r--r-- | src/tls/tls_handshake_state.h | 24 |
5 files changed, 51 insertions, 11 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index ed1c5fc75..3e5bdbabd 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -79,7 +79,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) while(m_state) { - auto msg = m_state->handshake_io().get_next_record(); + auto msg = m_state->get_next_handshake_msg(); if(msg.first == HANDSHAKE_NONE) // no full handshake yet break; diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index f46fdb094..5901eae04 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -61,7 +61,7 @@ void Stream_Handshake_IO::add_input(const byte rec_type, } std::pair<Handshake_Type, std::vector<byte>> -Stream_Handshake_IO::get_next_record() +Stream_Handshake_IO::get_next_record(bool) { if(m_queue.size() >= 4) { @@ -119,9 +119,11 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, size_t record_size, u64bit record_number) { + const u16bit epoch = static_cast<u16bit>(record_number >> 48); + if(rec_type == CHANGE_CIPHER_SPEC) { - m_ccs_epochs.insert(static_cast<u16bit>(record_number >> 48)); + m_ccs_epochs.insert(epoch); return; } @@ -147,8 +149,18 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, } std::pair<Handshake_Type, std::vector<byte>> -Datagram_Handshake_IO::get_next_record() +Datagram_Handshake_IO::get_next_record(bool expecting_ccs) { + if(expecting_ccs) + { + const u16bit current_epoch = 0; // fixme + + if(m_ccs_epochs.count(current_epoch)) + return std::make_pair(HANDSHAKE_CCS, std::vector<byte>()); + else + return std::make_pair(HANDSHAKE_NONE, std::vector<byte>()); + } + auto i = m_messages.find(m_in_message_seq); if(i == m_messages.end() || !i->second.complete()) diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index 7ed459049..b83425281 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -46,7 +46,8 @@ class Handshake_IO /** * Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available */ - virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0; + virtual std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) = 0; Handshake_IO() {} @@ -78,7 +79,8 @@ class Stream_Handshake_IO : public Handshake_IO size_t record_size, u64bit record_number) override; - std::pair<Handshake_Type, std::vector<byte> > get_next_record() override; + std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) override; private: std::deque<byte> m_queue; Record_Writer& m_writer; @@ -105,7 +107,8 @@ class Datagram_Handshake_IO : public Handshake_IO size_t record_size, u64bit record_number) override; - std::pair<Handshake_Type, std::vector<byte>> get_next_record() override; + std::pair<Handshake_Type, std::vector<byte>> + get_next_record(bool expecting_ccs) override; private: class Handshake_Reassembly { diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 082461dc9..24e00772d 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -222,6 +222,15 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const return (m_hand_received_mask & mask); } +std::pair<Handshake_Type, std::vector<byte>> +Handshake_State::get_next_handshake_msg() + { + const bool expecting_ccs = + (bitmask_for_handshake_type(HANDSHAKE_CCS) & m_hand_expecting_mask); + + return m_handshake_io->get_next_record(expecting_ccs); + } + std::string Handshake_State::srp_identifier() const { if(ciphersuite().valid() && ciphersuite().kex_algo() == "SRP_SHA") diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 81a603c6f..486e6bfb8 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -56,10 +56,26 @@ class Handshake_State Handshake_IO& handshake_io() { return *m_handshake_io; } - bool received_handshake_msg(Handshake_Type handshake_msg) const; - - void confirm_transition_to(Handshake_Type handshake_msg); - void set_expected_next(Handshake_Type handshake_msg); + /** + * Return true iff we have received a particular message already + * @param msg_type the message type + */ + bool received_handshake_msg(Handshake_Type msg_type) const; + + /** + * Confirm that we were expecting this message type + * @param msg_type the message type + */ + void confirm_transition_to(Handshake_Type msg_type); + + /** + * Record that we are expecting a particular message type next + * @param msg_type the message type + */ + void set_expected_next(Handshake_Type msg_type); + + std::pair<Handshake_Type, std::vector<byte>> + get_next_handshake_msg(); const std::vector<byte>& session_ticket() const; |