aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/tls_channel.cpp2
-rw-r--r--src/tls/tls_handshake_io.cpp18
-rw-r--r--src/tls/tls_handshake_io.h9
-rw-r--r--src/tls/tls_handshake_state.cpp9
-rw-r--r--src/tls/tls_handshake_state.h24
5 files changed, 51 insertions, 11 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index ed1c5fc75..3e5bdbabd 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -79,7 +79,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
while(m_state)
{
- auto msg = m_state->handshake_io().get_next_record();
+ auto msg = m_state->get_next_handshake_msg();
if(msg.first == HANDSHAKE_NONE) // no full handshake yet
break;
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index f46fdb094..5901eae04 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -61,7 +61,7 @@ void Stream_Handshake_IO::add_input(const byte rec_type,
}
std::pair<Handshake_Type, std::vector<byte>>
-Stream_Handshake_IO::get_next_record()
+Stream_Handshake_IO::get_next_record(bool)
{
if(m_queue.size() >= 4)
{
@@ -119,9 +119,11 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
size_t record_size,
u64bit record_number)
{
+ const u16bit epoch = static_cast<u16bit>(record_number >> 48);
+
if(rec_type == CHANGE_CIPHER_SPEC)
{
- m_ccs_epochs.insert(static_cast<u16bit>(record_number >> 48));
+ m_ccs_epochs.insert(epoch);
return;
}
@@ -147,8 +149,18 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
}
std::pair<Handshake_Type, std::vector<byte>>
-Datagram_Handshake_IO::get_next_record()
+Datagram_Handshake_IO::get_next_record(bool expecting_ccs)
{
+ if(expecting_ccs)
+ {
+ const u16bit current_epoch = 0; // fixme
+
+ if(m_ccs_epochs.count(current_epoch))
+ return std::make_pair(HANDSHAKE_CCS, std::vector<byte>());
+ else
+ return std::make_pair(HANDSHAKE_NONE, std::vector<byte>());
+ }
+
auto i = m_messages.find(m_in_message_seq);
if(i == m_messages.end() || !i->second.complete())
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index 7ed459049..b83425281 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -46,7 +46,8 @@ class Handshake_IO
/**
* Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available
*/
- virtual std::pair<Handshake_Type, std::vector<byte> > get_next_record() = 0;
+ virtual std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) = 0;
Handshake_IO() {}
@@ -78,7 +79,8 @@ class Stream_Handshake_IO : public Handshake_IO
size_t record_size,
u64bit record_number) override;
- std::pair<Handshake_Type, std::vector<byte> > get_next_record() override;
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
Record_Writer& m_writer;
@@ -105,7 +107,8 @@ class Datagram_Handshake_IO : public Handshake_IO
size_t record_size,
u64bit record_number) override;
- std::pair<Handshake_Type, std::vector<byte>> get_next_record() override;
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_record(bool expecting_ccs) override;
private:
class Handshake_Reassembly
{
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 082461dc9..24e00772d 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -222,6 +222,15 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const
return (m_hand_received_mask & mask);
}
+std::pair<Handshake_Type, std::vector<byte>>
+Handshake_State::get_next_handshake_msg()
+ {
+ const bool expecting_ccs =
+ (bitmask_for_handshake_type(HANDSHAKE_CCS) & m_hand_expecting_mask);
+
+ return m_handshake_io->get_next_record(expecting_ccs);
+ }
+
std::string Handshake_State::srp_identifier() const
{
if(ciphersuite().valid() && ciphersuite().kex_algo() == "SRP_SHA")
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 81a603c6f..486e6bfb8 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -56,10 +56,26 @@ class Handshake_State
Handshake_IO& handshake_io() { return *m_handshake_io; }
- bool received_handshake_msg(Handshake_Type handshake_msg) const;
-
- void confirm_transition_to(Handshake_Type handshake_msg);
- void set_expected_next(Handshake_Type handshake_msg);
+ /**
+ * Return true iff we have received a particular message already
+ * @param msg_type the message type
+ */
+ bool received_handshake_msg(Handshake_Type msg_type) const;
+
+ /**
+ * Confirm that we were expecting this message type
+ * @param msg_type the message type
+ */
+ void confirm_transition_to(Handshake_Type msg_type);
+
+ /**
+ * Record that we are expecting a particular message type next
+ * @param msg_type the message type
+ */
+ void set_expected_next(Handshake_Type msg_type);
+
+ std::pair<Handshake_Type, std::vector<byte>>
+ get_next_handshake_msg();
const std::vector<byte>& session_ticket() const;