aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/tls/tls_handshake_io.cpp133
-rw-r--r--src/tls/tls_handshake_io.h51
2 files changed, 184 insertions, 0 deletions
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index fe1b9c790..8ab0bea9c 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -119,6 +119,139 @@ std::vector<byte> Stream_Handshake_IO::send(Handshake_Message& msg)
return buf;
}
+void Datagram_Handshake_IO::add_input(const byte rec_type,
+ const byte record[],
+ size_t record_size)
+ {
+ if(rec_type == CHANGE_CIPHER_SPEC)
+ {
+ const u16bit message_seq = 666; // fixme
+ m_messages[message_seq].add_fragment(nullptr, 0, 0, HANDSHAKE_CCS, 0);
+ return;
+ }
+
+ const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
+
+ if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
+ return; // completely bogus? at least degenerate/weird
+
+ const byte msg_type = record[0];
+ const size_t msg_len = load_be24(&record[1]);
+ const u16bit message_seq = load_be<u16bit>(&record[4], 0);
+ const size_t fragment_offset = load_be24(&record[6]);
+ const size_t fragment_length = load_be24(&record[9]);
+
+ if(fragment_length + DTLS_HANDSHAKE_HEADER_LEN != record_size)
+ throw Decoding_Error("Bogus DTLS handshake, header sizes do not match");
+
+ m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
+ fragment_length,
+ fragment_offset,
+ msg_type,
+ msg_len);
+ }
+
+bool Datagram_Handshake_IO::empty() const
+ {
+ return m_messages.find(m_in_message_seq) == m_messages.end();
+ }
+
+bool Datagram_Handshake_IO::have_full_record() const
+ {
+ auto i = m_messages.find(m_in_message_seq);
+
+ const bool complete = (i != m_messages.end() && i->second.complete());
+
+ return complete;
+ }
+
+std::pair<Handshake_Type, std::vector<byte> > Datagram_Handshake_IO::get_next_record()
+ {
+ auto i = m_messages.find(m_in_message_seq);
+
+ if(i == m_messages.end() || !i->second.complete())
+ throw Internal_Error("Datagram_Handshake_IO::get_next_record called without a full record");
+
+
+ //return i->second.message();
+ auto m = i->second.message();
+
+ m_in_message_seq += 1;
+
+ return m;
+ }
+
+void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
+ const byte fragment[],
+ size_t fragment_length,
+ size_t fragment_offset,
+ byte msg_type,
+ size_t msg_length)
+ {
+ if(m_msg_type == HANDSHAKE_NONE)
+ {
+ m_msg_type = msg_type;
+ m_msg_length = msg_length;
+#warning DoS should resize as inputs are added (?)
+ m_buffer.resize(m_msg_length);
+ }
+
+ if(msg_type != m_msg_type || msg_length != m_msg_length)
+ throw Decoding_Error("Datagram_Handshake_IO - inconsistent values");
+
+ copy_mem(&m_buffer[fragment_offset], fragment, fragment_length);
+ }
+
+bool Datagram_Handshake_IO::Handshake_Reassembly::complete() const
+ {
+ return true; // fixme!
+ }
+
+std::pair<Handshake_Type, std::vector<byte>>
+Datagram_Handshake_IO::Handshake_Reassembly::message() const
+ {
+ if(!complete())
+ throw Internal_Error("Datagram_Handshake_IO - message not complete");
+
+ auto msg = std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_buffer);
+
+ return msg;
+ }
+
+std::vector<byte>
+Datagram_Handshake_IO::format(const std::vector<byte>& msg,
+ Handshake_Type type)
+ {
+ std::vector<byte> send_buf(12 + msg.size());
+
+ const size_t buf_size = msg.size();
+
+ send_buf[0] = type;
+
+ store_be24(&send_buf[1], buf_size);
+
+ store_be(static_cast<u16bit>(m_in_message_seq - 1), &send_buf[4]);
+
+ store_be24(&send_buf[6], 0); // fragment_offset
+ store_be24(&send_buf[9], buf_size); // fragment_length
+
+ copy_mem(&send_buf[12], &msg[0], msg.size());
+
+ return send_buf;
+ }
+
+std::vector<byte>
+Datagram_Handshake_IO::send(Handshake_Message& msg)
+ {
+ const std::vector<byte> buf = format(msg.serialize(), msg.type());
+
+ // FIXME: fragment to mtu size
+ m_writer.send(HANDSHAKE, &buf[0], buf.size());
+
+ return buf;
+
+ }
+
}
}
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index f71b2c034..039f92121 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -12,6 +12,7 @@
#include <botan/loadstor.h>
#include <vector>
#include <deque>
+#include <map>
#include <utility>
namespace Botan {
@@ -80,6 +81,56 @@ class Stream_Handshake_IO : public Handshake_IO
Record_Writer& m_writer;
};
+/**
+* Handshake IO for datagram-based handshakes
+*/
+class Datagram_Handshake_IO : public Handshake_IO
+ {
+ public:
+ Datagram_Handshake_IO(Record_Writer& writer) : m_writer(writer) {}
+
+ std::vector<byte> send(Handshake_Message& msg) override;
+
+ std::vector<byte> format(
+ const std::vector<byte>& handshake_msg,
+ Handshake_Type handshake_type) override;
+
+ void add_input(const byte rec_type,
+ const byte record[],
+ size_t record_size) override;
+
+ bool empty() const override;
+
+ bool have_full_record() const override;
+
+ std::pair<Handshake_Type, std::vector<byte>> get_next_record() override;
+ private:
+ class Handshake_Reassembly
+ {
+ public:
+ void add_fragment(const byte fragment[],
+ size_t fragment_length,
+ size_t fragment_offset,
+ byte msg_type,
+ size_t msg_length);
+
+ bool complete() const;
+
+ std::pair<Handshake_Type, std::vector<byte>> message() const;
+ private:
+ byte m_msg_type = HANDSHAKE_NONE;
+ size_t m_msg_length = 0;
+
+ std::vector<byte> m_buffer;
+ };
+
+ std::map<u16bit, Handshake_Reassembly> m_messages;
+
+ u16bit m_in_message_seq = 0;
+ u16bit m_out_message_seq = 0;
+ Record_Writer& m_writer;
+ };
+
}
}