diff options
author | lloyd <[email protected]> | 2012-01-05 23:01:06 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-01-05 23:01:06 +0000 |
commit | f452ca334eeb469d13d816c43227a7ea2f49efeb (patch) | |
tree | 51b21923652a596d3d04f6e24ff601e32ff97eb6 /src/tls | |
parent | 74226be019b1a66f8eae9a6516f2eb28a53fb9e2 (diff) |
Make record reading faster (less copying, no queue at all), at the
expense of significant complexity. Needs careful testing for corner
cases and malicious inputs, but seems to work well with randomly
chosen segmentations in a correctly formatted stream at least.
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/rec_read.cpp | 150 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 26 | ||||
-rw-r--r-- | src/tls/tls_record.h | 35 |
3 files changed, 138 insertions, 73 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index 518540bab..957e86aaa 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -19,6 +19,9 @@ Record_Reader::Record_Reader() m_mac = 0; reset(); set_maximum_fragment_size(0); + + // A single record is never larger than this + m_readbuf.resize(MAX_CIPHERTEXT_SIZE); } /* @@ -31,6 +34,9 @@ void Record_Reader::reset() delete m_mac; m_mac = 0; + zeroise(m_readbuf); + m_readbuf_pos = 0; + m_mac_size = 0; m_block_size = 0; m_iv_size = 0; @@ -137,62 +143,91 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite, throw Invalid_Argument("Record_Reader: Unknown hash " + mac_algo); } -void Record_Reader::add_input(const byte input[], size_t input_size) +void Record_Reader::consume_input(const byte*& input, + size_t& input_size, + size_t& input_consumed, + size_t desired) { - m_input_queue.write(input, input_size); + const size_t space_available = (m_readbuf.size() - m_readbuf_pos); + const size_t taken = std::min(input_size, desired); + + if(taken > space_available) + throw TLS_Exception(RECORD_OVERFLOW, + "Record is larger than allowed maximum size"); + + copy_mem(&m_readbuf[m_readbuf_pos], input, taken); + m_readbuf_pos += taken; + input_consumed += taken; + input_size -= taken; + input += taken; } /* * Retrieve the next record */ -size_t Record_Reader::get_record(byte& msg_type, - MemoryVector<byte>& output) +size_t Record_Reader::add_input(const byte input_array[], size_t input_size, + size_t& input_consumed, + byte& msg_type, + MemoryVector<byte>& msg) { - byte header[5] = { 0 }; + const byte* input = &input_array[0]; - const size_t have_in_queue = m_input_queue.size(); + input_consumed = 0; - if(have_in_queue < sizeof(header)) - return (sizeof(header) - have_in_queue); + const size_t HEADER_SIZE = 5; - /* - * We peek first to make sure we have the full record - */ - m_input_queue.peek(header, sizeof(header)); + if(m_readbuf_pos < HEADER_SIZE) // header incomplete? + { + consume_input(input, input_size, input_consumed, HEADER_SIZE - m_readbuf_pos); + + if(m_readbuf_pos < HEADER_SIZE) + return (HEADER_SIZE - m_readbuf_pos); // header still incomplete + + BOTAN_ASSERT_EQUAL(m_readbuf_pos, HEADER_SIZE, + "Buffer error in SSL header"); + } // SSLv2-format client hello? - if(header[0] & 0x80 && header[2] == 1 && header[3] == 3) + if(m_readbuf[0] & 0x80 && m_readbuf[2] == 1 && m_readbuf[3] >= 3) { - size_t record_len = make_u16bit(header[0], header[1]) & 0x7FFF; + size_t record_len = make_u16bit(m_readbuf[0], m_readbuf[1]) & 0x7FFF; + + consume_input(input, input_size, input_consumed, (record_len + 2) - m_readbuf_pos); + + if(m_readbuf_pos < (record_len + 2)) + return ((record_len + 2) - m_readbuf_pos); - if(have_in_queue < record_len + 2) - return (record_len + 2 - have_in_queue); + BOTAN_ASSERT_EQUAL(m_readbuf_pos, (record_len + 2), + "Buffer error in SSLv2 hello"); msg_type = HANDSHAKE; - output.resize(record_len + 4); - m_input_queue.read(&output[2], record_len + 2); - output[0] = CLIENT_HELLO_SSLV2; - output[1] = 0; - output[2] = header[0] & 0x7F; - output[3] = header[1]; + msg.resize(record_len + 4); + // Fake v3-style handshake message wrapper + msg[0] = CLIENT_HELLO_SSLV2; + msg[1] = 0; + msg[2] = m_readbuf[0] & 0x7F; + msg[3] = m_readbuf[1]; + + copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2); + m_readbuf_pos = 0; return 0; } - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE && - header[0] != APPLICATION_DATA) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE && + m_readbuf[0] != APPLICATION_DATA) { throw TLS_Exception(UNEXPECTED_MESSAGE, "Record_Reader: Unknown record type"); } - const u16bit version = make_u16bit(header[1], header[2]); - const u16bit record_len = make_u16bit(header[3], header[4]); + const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]); + const u16bit record_len = make_u16bit(m_readbuf[3], m_readbuf[4]); - if(m_major && (header[1] != m_major || header[2] != m_minor)) + if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor)) throw TLS_Exception(PROTOCOL_VERSION, "Record_Reader: Got unexpected version"); @@ -200,42 +235,52 @@ size_t Record_Reader::get_record(byte& msg_type, throw TLS_Exception(RECORD_OVERFLOW, "Got message that exceeds maximum size"); - // If insufficient data, return without doing anything - if(have_in_queue < (sizeof(header) + record_len)) - return (sizeof(header) + record_len - have_in_queue); + consume_input(input, input_size, input_consumed, + (HEADER_SIZE + record_len) - m_readbuf_pos); + + if(m_readbuf_pos < (HEADER_SIZE + record_len)) + return ((HEADER_SIZE + record_len) - m_readbuf_pos); + BOTAN_ASSERT_EQUAL(HEADER_SIZE + record_len, m_readbuf_pos, + "Bad buffer handling in record body"); + + /* m_readbuf.resize(record_len); m_input_queue.read(header, sizeof(header)); // pull off the header - m_input_queue.read(&m_readbuf[0], m_readbuf.size()); + m_input_queue.read(&m_readbuf[0], record_len); + */ // Null mac means no encryption either, only valid during handshake if(m_mac_size == 0) { - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE) { throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake"); } - msg_type = header[0]; - std::swap(output, m_readbuf); // move semantics + msg_type = m_readbuf[0]; + msg.resize(record_len); + copy_mem(&msg[0], &m_readbuf[5], record_len); + + m_readbuf_pos = 0; return 0; // got a full record } // Otherwise, decrypt, check MAC, return plaintext // FIXME: process in-place - m_cipher.process_msg(m_readbuf); - size_t got_back = m_cipher.read(&m_readbuf[0], m_readbuf.size(), Pipe::LAST_MESSAGE); - BOTAN_ASSERT_EQUAL(got_back, m_readbuf.size(), "Cipher didn't decrypt full amount"); + m_cipher.process_msg(&m_readbuf[5], record_len); + size_t got_back = m_cipher.read(&m_readbuf[5], record_len, Pipe::LAST_MESSAGE); + BOTAN_ASSERT_EQUAL(got_back, record_len, "Cipher didn't decrypt full amount"); size_t pad_size = 0; if(m_block_size) { - byte pad_value = m_readbuf[m_readbuf.size()-1]; + byte pad_value = m_readbuf[5 + (record_len-1)]; pad_size = pad_value + 1; /* @@ -256,7 +301,7 @@ size_t Record_Reader::get_record(byte& msg_type, bool padding_good = true; for(size_t i = 0; i != pad_size; ++i) - if(m_readbuf[m_readbuf.size()-i-1] != pad_value) + if(m_readbuf[5 + (record_len-i-1)] != pad_value) padding_good = false; if(!padding_good) @@ -264,41 +309,42 @@ size_t Record_Reader::get_record(byte& msg_type, } } - if(m_readbuf.size() < m_mac_size + pad_size + m_iv_size) + if(record_len < m_mac_size + pad_size + m_iv_size) throw Decoding_Error("Record_Reader: Record truncated"); - const u16bit plain_length = m_readbuf.size() - (m_mac_size + pad_size + m_iv_size); + const u16bit plain_length = record_len - (m_mac_size + pad_size + m_iv_size); if(plain_length > m_max_fragment) throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large"); m_mac->update_be(m_seq_no); - m_mac->update(header[0]); // msg_type + m_mac->update(m_readbuf[0]); // msg_type if(version != SSL_V3) for(size_t i = 0; i != 2; ++i) m_mac->update(get_byte(i, version)); m_mac->update_be(plain_length); - m_mac->update(&m_readbuf[m_iv_size], plain_length); + m_mac->update(&m_readbuf[5 + m_iv_size], plain_length); ++m_seq_no; MemoryVector<byte> computed_mac = m_mac->final(); - const size_t mac_offset = m_readbuf.size() - (m_mac_size + pad_size); + const size_t mac_offset = record_len - (m_mac_size + pad_size); if(computed_mac.size() != m_mac_size) throw TLS_Exception(INTERNAL_ERROR, "MAC produced value of unexpected size"); - if(!same_mem(&m_readbuf[mac_offset], &computed_mac[0], m_mac_size)) + if(!same_mem(&m_readbuf[5 + mac_offset], &computed_mac[0], m_mac_size)) throw TLS_Exception(BAD_RECORD_MAC, "Record_Reader: MAC failure"); - msg_type = header[0]; + msg_type = m_readbuf[0]; - output.resize(plain_length); - copy_mem(&output[0], &m_readbuf[m_iv_size], plain_length); + msg.resize(plain_length); + copy_mem(&msg[0], &m_readbuf[5 + m_iv_size], plain_length); + m_readbuf_pos = 0; return 0; } diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 73c4fd4ab..7fda4bc86 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -1,6 +1,6 @@ /* * TLS Channels -* (C) 2011 Jack Lloyd +* (C) 2011-2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -8,6 +8,7 @@ #include <botan/tls_channel.h> #include <botan/internal/tls_alerts.h> #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/assert.h> #include <botan/loadstor.h> namespace Botan { @@ -42,17 +43,21 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) { try { - reader.add_input(buf, buf_size); + while(buf_size) + { + byte rec_type = CONNECTION_CLOSED; + MemoryVector<byte> record; + size_t consumed = 0; - byte rec_type = CONNECTION_CLOSED; - MemoryVector<byte> record; + const size_t needed = reader.add_input(buf, buf_size, + consumed, + rec_type, record); - while(!reader.currently_empty()) - { - const size_t bytes_needed = reader.get_record(rec_type, record); + buf += consumed; + buf_size -= consumed; - if(bytes_needed > 0) - return bytes_needed; + if(buf_size == 0 && needed != 0) + return needed; // need more data to complete record if(rec_type == APPLICATION_DATA) { @@ -95,7 +100,8 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) } } else - throw Unexpected_Message("Unknown message type received"); + throw Unexpected_Message("Unknown TLS message type " + + to_string(rec_type) + " received"); } return 0; // on a record boundary diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 8e89b9f8a..f4f3e697f 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -59,6 +59,9 @@ class BOTAN_DLL Record_Writer ~Record_Writer() { delete m_mac; } private: + Record_Writer(const Record_Writer&) {} + Record_Writer& operator=(const Record_Writer&) { return (*this); } + void send_record(byte type, const byte input[], size_t length); std::tr1::function<void (const byte[], size_t)> m_output_fn; @@ -80,17 +83,21 @@ class BOTAN_DLL Record_Writer class BOTAN_DLL Record_Reader { public: - void add_input(const byte input[], size_t input_size); /** - * @param msg_type (output variable) - * @param buffer (output variable) - * @return Number of bytes still needed (minimum), or 0 if success + * @param input new input data (may be NULL if input_size == 0) + * @param input_size size of input in bytes + * @param input_consumed is set to the number of bytes of input + * that were consumed + * @param msg_type is set to the type of the message just read if + * this function returns 0 + * @param msg is set to the contents of the record + * @return number of bytes still needed (minimum), or 0 if success */ - size_t get_record(byte& msg_type, - MemoryVector<byte>& buffer); - - SecureVector<byte> get_record(byte& msg_type); + size_t add_input(const byte input[], size_t input_size, + size_t& input_consumed, + byte& msg_type, + MemoryVector<byte>& msg); void activate(const TLS_Cipher_Suite& suite, const SessionKeys& keys, @@ -102,16 +109,22 @@ class BOTAN_DLL Record_Reader void reset(); - bool currently_empty() const { return m_input_queue.size() == 0; } - void set_maximum_fragment_size(size_t max_fragment); Record_Reader(); ~Record_Reader() { delete m_mac; } private: + Record_Reader(const Record_Reader&) {} + Record_Reader& operator=(const Record_Reader&) { return (*this); } + + void consume_input(const byte*& input, + size_t& input_size, + size_t& input_consumed, + size_t desired); + MemoryVector<byte> m_readbuf; - SecureQueue m_input_queue; + size_t m_readbuf_pos; Pipe m_cipher; MessageAuthenticationCode* m_mac; |