aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-01-05 23:01:06 +0000
committerlloyd <[email protected]>2012-01-05 23:01:06 +0000
commitf452ca334eeb469d13d816c43227a7ea2f49efeb (patch)
tree51b21923652a596d3d04f6e24ff601e32ff97eb6 /src
parent74226be019b1a66f8eae9a6516f2eb28a53fb9e2 (diff)
Make record reading faster (less copying, no queue at all), at the
expense of significant complexity. Needs careful testing for corner cases and malicious inputs, but seems to work well with randomly chosen segmentations in a correctly formatted stream at least.
Diffstat (limited to 'src')
-rw-r--r--src/tls/rec_read.cpp150
-rw-r--r--src/tls/tls_channel.cpp26
-rw-r--r--src/tls/tls_record.h35
3 files changed, 138 insertions, 73 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp
index 518540bab..957e86aaa 100644
--- a/src/tls/rec_read.cpp
+++ b/src/tls/rec_read.cpp
@@ -19,6 +19,9 @@ Record_Reader::Record_Reader()
m_mac = 0;
reset();
set_maximum_fragment_size(0);
+
+ // A single record is never larger than this
+ m_readbuf.resize(MAX_CIPHERTEXT_SIZE);
}
/*
@@ -31,6 +34,9 @@ void Record_Reader::reset()
delete m_mac;
m_mac = 0;
+ zeroise(m_readbuf);
+ m_readbuf_pos = 0;
+
m_mac_size = 0;
m_block_size = 0;
m_iv_size = 0;
@@ -137,62 +143,91 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite,
throw Invalid_Argument("Record_Reader: Unknown hash " + mac_algo);
}
-void Record_Reader::add_input(const byte input[], size_t input_size)
+void Record_Reader::consume_input(const byte*& input,
+ size_t& input_size,
+ size_t& input_consumed,
+ size_t desired)
{
- m_input_queue.write(input, input_size);
+ const size_t space_available = (m_readbuf.size() - m_readbuf_pos);
+ const size_t taken = std::min(input_size, desired);
+
+ if(taken > space_available)
+ throw TLS_Exception(RECORD_OVERFLOW,
+ "Record is larger than allowed maximum size");
+
+ copy_mem(&m_readbuf[m_readbuf_pos], input, taken);
+ m_readbuf_pos += taken;
+ input_consumed += taken;
+ input_size -= taken;
+ input += taken;
}
/*
* Retrieve the next record
*/
-size_t Record_Reader::get_record(byte& msg_type,
- MemoryVector<byte>& output)
+size_t Record_Reader::add_input(const byte input_array[], size_t input_size,
+ size_t& input_consumed,
+ byte& msg_type,
+ MemoryVector<byte>& msg)
{
- byte header[5] = { 0 };
+ const byte* input = &input_array[0];
- const size_t have_in_queue = m_input_queue.size();
+ input_consumed = 0;
- if(have_in_queue < sizeof(header))
- return (sizeof(header) - have_in_queue);
+ const size_t HEADER_SIZE = 5;
- /*
- * We peek first to make sure we have the full record
- */
- m_input_queue.peek(header, sizeof(header));
+ if(m_readbuf_pos < HEADER_SIZE) // header incomplete?
+ {
+ consume_input(input, input_size, input_consumed, HEADER_SIZE - m_readbuf_pos);
+
+ if(m_readbuf_pos < HEADER_SIZE)
+ return (HEADER_SIZE - m_readbuf_pos); // header still incomplete
+
+ BOTAN_ASSERT_EQUAL(m_readbuf_pos, HEADER_SIZE,
+ "Buffer error in SSL header");
+ }
// SSLv2-format client hello?
- if(header[0] & 0x80 && header[2] == 1 && header[3] == 3)
+ if(m_readbuf[0] & 0x80 && m_readbuf[2] == 1 && m_readbuf[3] >= 3)
{
- size_t record_len = make_u16bit(header[0], header[1]) & 0x7FFF;
+ size_t record_len = make_u16bit(m_readbuf[0], m_readbuf[1]) & 0x7FFF;
+
+ consume_input(input, input_size, input_consumed, (record_len + 2) - m_readbuf_pos);
+
+ if(m_readbuf_pos < (record_len + 2))
+ return ((record_len + 2) - m_readbuf_pos);
- if(have_in_queue < record_len + 2)
- return (record_len + 2 - have_in_queue);
+ BOTAN_ASSERT_EQUAL(m_readbuf_pos, (record_len + 2),
+ "Buffer error in SSLv2 hello");
msg_type = HANDSHAKE;
- output.resize(record_len + 4);
- m_input_queue.read(&output[2], record_len + 2);
- output[0] = CLIENT_HELLO_SSLV2;
- output[1] = 0;
- output[2] = header[0] & 0x7F;
- output[3] = header[1];
+ msg.resize(record_len + 4);
+ // Fake v3-style handshake message wrapper
+ msg[0] = CLIENT_HELLO_SSLV2;
+ msg[1] = 0;
+ msg[2] = m_readbuf[0] & 0x7F;
+ msg[3] = m_readbuf[1];
+
+ copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2);
+ m_readbuf_pos = 0;
return 0;
}
- if(header[0] != CHANGE_CIPHER_SPEC &&
- header[0] != ALERT &&
- header[0] != HANDSHAKE &&
- header[0] != APPLICATION_DATA)
+ if(m_readbuf[0] != CHANGE_CIPHER_SPEC &&
+ m_readbuf[0] != ALERT &&
+ m_readbuf[0] != HANDSHAKE &&
+ m_readbuf[0] != APPLICATION_DATA)
{
throw TLS_Exception(UNEXPECTED_MESSAGE,
"Record_Reader: Unknown record type");
}
- const u16bit version = make_u16bit(header[1], header[2]);
- const u16bit record_len = make_u16bit(header[3], header[4]);
+ const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]);
+ const u16bit record_len = make_u16bit(m_readbuf[3], m_readbuf[4]);
- if(m_major && (header[1] != m_major || header[2] != m_minor))
+ if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor))
throw TLS_Exception(PROTOCOL_VERSION,
"Record_Reader: Got unexpected version");
@@ -200,42 +235,52 @@ size_t Record_Reader::get_record(byte& msg_type,
throw TLS_Exception(RECORD_OVERFLOW,
"Got message that exceeds maximum size");
- // If insufficient data, return without doing anything
- if(have_in_queue < (sizeof(header) + record_len))
- return (sizeof(header) + record_len - have_in_queue);
+ consume_input(input, input_size, input_consumed,
+ (HEADER_SIZE + record_len) - m_readbuf_pos);
+
+ if(m_readbuf_pos < (HEADER_SIZE + record_len))
+ return ((HEADER_SIZE + record_len) - m_readbuf_pos);
+ BOTAN_ASSERT_EQUAL(HEADER_SIZE + record_len, m_readbuf_pos,
+ "Bad buffer handling in record body");
+
+ /*
m_readbuf.resize(record_len);
m_input_queue.read(header, sizeof(header)); // pull off the header
- m_input_queue.read(&m_readbuf[0], m_readbuf.size());
+ m_input_queue.read(&m_readbuf[0], record_len);
+ */
// Null mac means no encryption either, only valid during handshake
if(m_mac_size == 0)
{
- if(header[0] != CHANGE_CIPHER_SPEC &&
- header[0] != ALERT &&
- header[0] != HANDSHAKE)
+ if(m_readbuf[0] != CHANGE_CIPHER_SPEC &&
+ m_readbuf[0] != ALERT &&
+ m_readbuf[0] != HANDSHAKE)
{
throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake");
}
- msg_type = header[0];
- std::swap(output, m_readbuf); // move semantics
+ msg_type = m_readbuf[0];
+ msg.resize(record_len);
+ copy_mem(&msg[0], &m_readbuf[5], record_len);
+
+ m_readbuf_pos = 0;
return 0; // got a full record
}
// Otherwise, decrypt, check MAC, return plaintext
// FIXME: process in-place
- m_cipher.process_msg(m_readbuf);
- size_t got_back = m_cipher.read(&m_readbuf[0], m_readbuf.size(), Pipe::LAST_MESSAGE);
- BOTAN_ASSERT_EQUAL(got_back, m_readbuf.size(), "Cipher didn't decrypt full amount");
+ m_cipher.process_msg(&m_readbuf[5], record_len);
+ size_t got_back = m_cipher.read(&m_readbuf[5], record_len, Pipe::LAST_MESSAGE);
+ BOTAN_ASSERT_EQUAL(got_back, record_len, "Cipher didn't decrypt full amount");
size_t pad_size = 0;
if(m_block_size)
{
- byte pad_value = m_readbuf[m_readbuf.size()-1];
+ byte pad_value = m_readbuf[5 + (record_len-1)];
pad_size = pad_value + 1;
/*
@@ -256,7 +301,7 @@ size_t Record_Reader::get_record(byte& msg_type,
bool padding_good = true;
for(size_t i = 0; i != pad_size; ++i)
- if(m_readbuf[m_readbuf.size()-i-1] != pad_value)
+ if(m_readbuf[5 + (record_len-i-1)] != pad_value)
padding_good = false;
if(!padding_good)
@@ -264,41 +309,42 @@ size_t Record_Reader::get_record(byte& msg_type,
}
}
- if(m_readbuf.size() < m_mac_size + pad_size + m_iv_size)
+ if(record_len < m_mac_size + pad_size + m_iv_size)
throw Decoding_Error("Record_Reader: Record truncated");
- const u16bit plain_length = m_readbuf.size() - (m_mac_size + pad_size + m_iv_size);
+ const u16bit plain_length = record_len - (m_mac_size + pad_size + m_iv_size);
if(plain_length > m_max_fragment)
throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large");
m_mac->update_be(m_seq_no);
- m_mac->update(header[0]); // msg_type
+ m_mac->update(m_readbuf[0]); // msg_type
if(version != SSL_V3)
for(size_t i = 0; i != 2; ++i)
m_mac->update(get_byte(i, version));
m_mac->update_be(plain_length);
- m_mac->update(&m_readbuf[m_iv_size], plain_length);
+ m_mac->update(&m_readbuf[5 + m_iv_size], plain_length);
++m_seq_no;
MemoryVector<byte> computed_mac = m_mac->final();
- const size_t mac_offset = m_readbuf.size() - (m_mac_size + pad_size);
+ const size_t mac_offset = record_len - (m_mac_size + pad_size);
if(computed_mac.size() != m_mac_size)
throw TLS_Exception(INTERNAL_ERROR,
"MAC produced value of unexpected size");
- if(!same_mem(&m_readbuf[mac_offset], &computed_mac[0], m_mac_size))
+ if(!same_mem(&m_readbuf[5 + mac_offset], &computed_mac[0], m_mac_size))
throw TLS_Exception(BAD_RECORD_MAC, "Record_Reader: MAC failure");
- msg_type = header[0];
+ msg_type = m_readbuf[0];
- output.resize(plain_length);
- copy_mem(&output[0], &m_readbuf[m_iv_size], plain_length);
+ msg.resize(plain_length);
+ copy_mem(&msg[0], &m_readbuf[5 + m_iv_size], plain_length);
+ m_readbuf_pos = 0;
return 0;
}
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 73c4fd4ab..7fda4bc86 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -1,6 +1,6 @@
/*
* TLS Channels
-* (C) 2011 Jack Lloyd
+* (C) 2011-2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -8,6 +8,7 @@
#include <botan/tls_channel.h>
#include <botan/internal/tls_alerts.h>
#include <botan/internal/tls_handshake_state.h>
+#include <botan/internal/assert.h>
#include <botan/loadstor.h>
namespace Botan {
@@ -42,17 +43,21 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
{
try
{
- reader.add_input(buf, buf_size);
+ while(buf_size)
+ {
+ byte rec_type = CONNECTION_CLOSED;
+ MemoryVector<byte> record;
+ size_t consumed = 0;
- byte rec_type = CONNECTION_CLOSED;
- MemoryVector<byte> record;
+ const size_t needed = reader.add_input(buf, buf_size,
+ consumed,
+ rec_type, record);
- while(!reader.currently_empty())
- {
- const size_t bytes_needed = reader.get_record(rec_type, record);
+ buf += consumed;
+ buf_size -= consumed;
- if(bytes_needed > 0)
- return bytes_needed;
+ if(buf_size == 0 && needed != 0)
+ return needed; // need more data to complete record
if(rec_type == APPLICATION_DATA)
{
@@ -95,7 +100,8 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
}
}
else
- throw Unexpected_Message("Unknown message type received");
+ throw Unexpected_Message("Unknown TLS message type " +
+ to_string(rec_type) + " received");
}
return 0; // on a record boundary
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 8e89b9f8a..f4f3e697f 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -59,6 +59,9 @@ class BOTAN_DLL Record_Writer
~Record_Writer() { delete m_mac; }
private:
+ Record_Writer(const Record_Writer&) {}
+ Record_Writer& operator=(const Record_Writer&) { return (*this); }
+
void send_record(byte type, const byte input[], size_t length);
std::tr1::function<void (const byte[], size_t)> m_output_fn;
@@ -80,17 +83,21 @@ class BOTAN_DLL Record_Writer
class BOTAN_DLL Record_Reader
{
public:
- void add_input(const byte input[], size_t input_size);
/**
- * @param msg_type (output variable)
- * @param buffer (output variable)
- * @return Number of bytes still needed (minimum), or 0 if success
+ * @param input new input data (may be NULL if input_size == 0)
+ * @param input_size size of input in bytes
+ * @param input_consumed is set to the number of bytes of input
+ * that were consumed
+ * @param msg_type is set to the type of the message just read if
+ * this function returns 0
+ * @param msg is set to the contents of the record
+ * @return number of bytes still needed (minimum), or 0 if success
*/
- size_t get_record(byte& msg_type,
- MemoryVector<byte>& buffer);
-
- SecureVector<byte> get_record(byte& msg_type);
+ size_t add_input(const byte input[], size_t input_size,
+ size_t& input_consumed,
+ byte& msg_type,
+ MemoryVector<byte>& msg);
void activate(const TLS_Cipher_Suite& suite,
const SessionKeys& keys,
@@ -102,16 +109,22 @@ class BOTAN_DLL Record_Reader
void reset();
- bool currently_empty() const { return m_input_queue.size() == 0; }
-
void set_maximum_fragment_size(size_t max_fragment);
Record_Reader();
~Record_Reader() { delete m_mac; }
private:
+ Record_Reader(const Record_Reader&) {}
+ Record_Reader& operator=(const Record_Reader&) { return (*this); }
+
+ void consume_input(const byte*& input,
+ size_t& input_size,
+ size_t& input_consumed,
+ size_t desired);
+
MemoryVector<byte> m_readbuf;
- SecureQueue m_input_queue;
+ size_t m_readbuf_pos;
Pipe m_cipher;
MessageAuthenticationCode* m_mac;