diff options
-rw-r--r-- | src/tls/rec_read.cpp | 7 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 11 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 15 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 11 | ||||
-rw-r--r-- | src/tls/tls_record.h | 4 |
5 files changed, 33 insertions, 15 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index d57e70f59..276ae5732 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -219,7 +219,8 @@ size_t tls_padding_check(Protocol_Version version, size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, size_t& consumed, byte& msg_type, - std::vector<byte>& msg) + std::vector<byte>& msg, + u64bit& msg_sequence) { const byte* input = &input_array[0]; @@ -263,6 +264,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2); m_readbuf_pos = 0; + msg_sequence = m_seq_no++; return 0; } } @@ -317,6 +319,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, copy_mem(&msg[0], &m_readbuf[TLS_HEADER_SIZE], record_len); m_readbuf_pos = 0; + msg_sequence = m_seq_no++; return 0; // got a full record } @@ -364,7 +367,7 @@ size_t Record_Reader::add_input(const byte input_array[], size_t input_sz, m_mac->update_be(plain_length); m_mac->update(&m_readbuf[TLS_HEADER_SIZE + m_iv_size], plain_length); - ++m_seq_no; + msg_sequence = m_seq_no++; m_mac->final(&m_macbuf[0]); diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index a09dc5afc..060799826 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -45,11 +45,15 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) { byte rec_type = CONNECTION_CLOSED; std::vector<byte> record; + u64bit record_number = 0; + size_t consumed = 0; const size_t needed = m_reader.add_input(buf, buf_size, consumed, - rec_type, record); + rec_type, + record, + record_number); BOTAN_ASSERT(consumed <= buf_size, "Record reader consumed sane amount"); @@ -68,7 +72,10 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(!m_state) m_state.reset(new_handshake_state()); - m_state->handshake_io().add_input(rec_type, &record[0], record.size()); + m_state->handshake_io().add_input(rec_type, + &record[0], + record.size(), + record_number); while(m_state && m_state->handshake_io().have_full_record()) { diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index 452fa8f15..5cb52caf2 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -40,7 +40,8 @@ Protocol_Version Stream_Handshake_IO::initial_record_version() const void Stream_Handshake_IO::add_input(const byte rec_type, const byte record[], - size_t record_size) + size_t record_size, + u64bit /*record_number*/) { if(rec_type == HANDSHAKE) { @@ -51,6 +52,7 @@ void Stream_Handshake_IO::add_input(const byte rec_type, if(record_size != 1 || record[0] != 1) throw Decoding_Error("Invalid ChangeCipherSpec"); + // Pretend it's a regular handshake message of zero length const byte ccs_hs[] = { HANDSHAKE_CCS, 0, 0, 0 }; m_queue.insert(m_queue.end(), ccs_hs, ccs_hs + sizeof(ccs_hs)); } @@ -131,12 +133,12 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const void Datagram_Handshake_IO::add_input(const byte rec_type, const byte record[], - size_t record_size) + size_t record_size, + u64bit record_number) { if(rec_type == CHANGE_CIPHER_SPEC) { - const u16bit message_seq = 666; // fixme - m_messages[message_seq].add_fragment(nullptr, 0, 0, HANDSHAKE_CCS, 0); + m_ccs_epochs.insert(static_cast<u16bit>(record_number >> 48)); return; } @@ -182,13 +184,12 @@ std::pair<Handshake_Type, std::vector<byte> > Datagram_Handshake_IO::get_next_re if(i == m_messages.end() || !i->second.complete()) throw Internal_Error("Datagram_Handshake_IO::get_next_record called without a full record"); - - //return i->second.message(); auto m = i->second.message(); m_in_message_seq += 1; return m; + //return i->second.message(); } void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment( @@ -214,7 +215,7 @@ void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment( bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const { - return true; // fixme! + return true; // FIXME } std::pair<Handshake_Type, std::vector<byte>> diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index da3bfd5c8..cd79a0b72 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -14,6 +14,7 @@ #include <vector> #include <deque> #include <map> +#include <set> #include <utility> namespace Botan { @@ -39,7 +40,8 @@ class Handshake_IO virtual void add_input(byte record_type, const byte record[], - size_t record_size) = 0; + size_t record_size, + u64bit record_number) = 0; virtual bool empty() const = 0; @@ -74,7 +76,8 @@ class Stream_Handshake_IO : public Handshake_IO void add_input(byte record_type, const byte record[], - size_t record_size) override; + size_t record_size, + u64bit record_number) override; bool empty() const override; @@ -104,7 +107,8 @@ class Datagram_Handshake_IO : public Handshake_IO void add_input(const byte rec_type, const byte record[], - size_t record_size) override; + size_t record_size, + u64bit record_number) override; bool empty() const override; @@ -132,6 +136,7 @@ class Datagram_Handshake_IO : public Handshake_IO }; std::map<u16bit, Handshake_Reassembly> m_messages; + std::set<u16bit> m_ccs_epochs; u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 584fba52f..820de0958 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -90,12 +90,14 @@ class BOTAN_DLL Record_Reader * @param msg_type is set to the type of the message just read if * this function returns 0 * @param msg is set to the contents of the record + * @param msg_sequence is set to this records sequence number * @return number of bytes still needed (minimum), or 0 if success */ size_t add_input(const byte input[], size_t input_size, size_t& input_consumed, byte& msg_type, - std::vector<byte>& msg); + std::vector<byte>& msg, + u64bit& msg_sequence); void change_cipher_spec(Connection_Side side, const Ciphersuite& suite, |