diff options
-rw-r--r-- | src/tls/rec_read.cpp | 150 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 26 | ||||
-rw-r--r-- | src/tls/tls_record.h | 35 |
3 files changed, 138 insertions, 73 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index 518540bab..957e86aaa 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -19,6 +19,9 @@ Record_Reader::Record_Reader() m_mac = 0; reset(); set_maximum_fragment_size(0); + + // A single record is never larger than this + m_readbuf.resize(MAX_CIPHERTEXT_SIZE); } /* @@ -31,6 +34,9 @@ void Record_Reader::reset() delete m_mac; m_mac = 0; + zeroise(m_readbuf); + m_readbuf_pos = 0; + m_mac_size = 0; m_block_size = 0; m_iv_size = 0; @@ -137,62 +143,91 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite, throw Invalid_Argument("Record_Reader: Unknown hash " + mac_algo); } -void Record_Reader::add_input(const byte input[], size_t input_size) +void Record_Reader::consume_input(const byte*& input, + size_t& input_size, + size_t& input_consumed, + size_t desired) { - m_input_queue.write(input, input_size); + const size_t space_available = (m_readbuf.size() - m_readbuf_pos); + const size_t taken = std::min(input_size, desired); + + if(taken > space_available) + throw TLS_Exception(RECORD_OVERFLOW, + "Record is larger than allowed maximum size"); + + copy_mem(&m_readbuf[m_readbuf_pos], input, taken); + m_readbuf_pos += taken; + input_consumed += taken; + input_size -= taken; + input += taken; } /* * Retrieve the next record */ -size_t Record_Reader::get_record(byte& msg_type, - MemoryVector<byte>& output) +size_t Record_Reader::add_input(const byte input_array[], size_t input_size, + size_t& input_consumed, + byte& msg_type, + MemoryVector<byte>& msg) { - byte header[5] = { 0 }; + const byte* input = &input_array[0]; - const size_t have_in_queue = m_input_queue.size(); + input_consumed = 0; - if(have_in_queue < sizeof(header)) - return (sizeof(header) - have_in_queue); + const size_t HEADER_SIZE = 5; - /* - * We peek first to make sure we have the full record - */ - m_input_queue.peek(header, sizeof(header)); + if(m_readbuf_pos < HEADER_SIZE) // header incomplete? + { + consume_input(input, input_size, input_consumed, HEADER_SIZE - m_readbuf_pos); + + if(m_readbuf_pos < HEADER_SIZE) + return (HEADER_SIZE - m_readbuf_pos); // header still incomplete + + BOTAN_ASSERT_EQUAL(m_readbuf_pos, HEADER_SIZE, + "Buffer error in SSL header"); + } // SSLv2-format client hello? - if(header[0] & 0x80 && header[2] == 1 && header[3] == 3) + if(m_readbuf[0] & 0x80 && m_readbuf[2] == 1 && m_readbuf[3] >= 3) { - size_t record_len = make_u16bit(header[0], header[1]) & 0x7FFF; + size_t record_len = make_u16bit(m_readbuf[0], m_readbuf[1]) & 0x7FFF; + + consume_input(input, input_size, input_consumed, (record_len + 2) - m_readbuf_pos); + + if(m_readbuf_pos < (record_len + 2)) + return ((record_len + 2) - m_readbuf_pos); - if(have_in_queue < record_len + 2) - return (record_len + 2 - have_in_queue); + BOTAN_ASSERT_EQUAL(m_readbuf_pos, (record_len + 2), + "Buffer error in SSLv2 hello"); msg_type = HANDSHAKE; - output.resize(record_len + 4); - m_input_queue.read(&output[2], record_len + 2); - output[0] = CLIENT_HELLO_SSLV2; - output[1] = 0; - output[2] = header[0] & 0x7F; - output[3] = header[1]; + msg.resize(record_len + 4); + // Fake v3-style handshake message wrapper + msg[0] = CLIENT_HELLO_SSLV2; + msg[1] = 0; + msg[2] = m_readbuf[0] & 0x7F; + msg[3] = m_readbuf[1]; + + copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2); + m_readbuf_pos = 0; return 0; } - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE && - header[0] != APPLICATION_DATA) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE && + m_readbuf[0] != APPLICATION_DATA) { throw TLS_Exception(UNEXPECTED_MESSAGE, "Record_Reader: Unknown record type"); } - const u16bit version = make_u16bit(header[1], header[2]); - const u16bit record_len = make_u16bit(header[3], header[4]); + const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]); + const u16bit record_len = make_u16bit(m_readbuf[3], m_readbuf[4]); - if(m_major && (header[1] != m_major || header[2] != m_minor)) + if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor)) throw TLS_Exception(PROTOCOL_VERSION, "Record_Reader: Got unexpected version"); @@ -200,42 +235,52 @@ size_t Record_Reader::get_record(byte& msg_type, throw TLS_Exception(RECORD_OVERFLOW, "Got message that exceeds maximum size"); - // If insufficient data, return without doing anything - if(have_in_queue < (sizeof(header) + record_len)) - return (sizeof(header) + record_len - have_in_queue); + consume_input(input, input_size, input_consumed, + (HEADER_SIZE + record_len) - m_readbuf_pos); + + if(m_readbuf_pos < (HEADER_SIZE + record_len)) + return ((HEADER_SIZE + record_len) - m_readbuf_pos); + BOTAN_ASSERT_EQUAL(HEADER_SIZE + record_len, m_readbuf_pos, + "Bad buffer handling in record body"); + + /* m_readbuf.resize(record_len); m_input_queue.read(header, sizeof(header)); // pull off the header - m_input_queue.read(&m_readbuf[0], m_readbuf.size()); + m_input_queue.read(&m_readbuf[0], record_len); + */ // Null mac means no encryption either, only valid during handshake if(m_mac_size == 0) { - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE) { throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake"); } - msg_type = header[0]; - std::swap(output, m_readbuf); // move semantics + msg_type = m_readbuf[0]; + msg.resize(record_len); + copy_mem(&msg[0], &m_readbuf[5], record_len); + + m_readbuf_pos = 0; return 0; // got a full record } // Otherwise, decrypt, check MAC, return plaintext // FIXME: process in-place - m_cipher.process_msg(m_readbuf); - size_t got_back = m_cipher.read(&m_readbuf[0], m_readbuf.size(), Pipe::LAST_MESSAGE); - BOTAN_ASSERT_EQUAL(got_back, m_readbuf.size(), "Cipher didn't decrypt full amount"); + m_cipher.process_msg(&m_readbuf[5], record_len); + size_t got_back = m_cipher.read(&m_readbuf[5], record_len, Pipe::LAST_MESSAGE); + BOTAN_ASSERT_EQUAL(got_back, record_len, "Cipher didn't decrypt full amount"); size_t pad_size = 0; if(m_block_size) { - byte pad_value = m_readbuf[m_readbuf.size()-1]; + byte pad_value = m_readbuf[5 + (record_len-1)]; pad_size = pad_value + 1; /* @@ -256,7 +301,7 @@ size_t Record_Reader::get_record(byte& msg_type, bool padding_good = true; for(size_t i = 0; i != pad_size; ++i) - if(m_readbuf[m_readbuf.size()-i-1] != pad_value) + if(m_readbuf[5 + (record_len-i-1)] != pad_value) padding_good = false; if(!padding_good) @@ -264,41 +309,42 @@ size_t Record_Reader::get_record(byte& msg_type, } } - if(m_readbuf.size() < m_mac_size + pad_size + m_iv_size) + if(record_len < m_mac_size + pad_size + m_iv_size) throw Decoding_Error("Record_Reader: Record truncated"); - const u16bit plain_length = m_readbuf.size() - (m_mac_size + pad_size + m_iv_size); + const u16bit plain_length = record_len - (m_mac_size + pad_size + m_iv_size); if(plain_length > m_max_fragment) throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large"); m_mac->update_be(m_seq_no); - m_mac->update(header[0]); // msg_type + m_mac->update(m_readbuf[0]); // msg_type if(version != SSL_V3) for(size_t i = 0; i != 2; ++i) m_mac->update(get_byte(i, version)); m_mac->update_be(plain_length); - m_mac->update(&m_readbuf[m_iv_size], plain_length); + m_mac->update(&m_readbuf[5 + m_iv_size], plain_length); ++m_seq_no; MemoryVector<byte> computed_mac = m_mac->final(); - const size_t mac_offset = m_readbuf.size() - (m_mac_size + pad_size); + const size_t mac_offset = record_len - (m_mac_size + pad_size); if(computed_mac.size() != m_mac_size) throw TLS_Exception(INTERNAL_ERROR, "MAC produced value of unexpected size"); - if(!same_mem(&m_readbuf[mac_offset], &computed_mac[0], m_mac_size)) + if(!same_mem(&m_readbuf[5 + mac_offset], &computed_mac[0], m_mac_size)) throw TLS_Exception(BAD_RECORD_MAC, "Record_Reader: MAC failure"); - msg_type = header[0]; + msg_type = m_readbuf[0]; - output.resize(plain_length); - copy_mem(&output[0], &m_readbuf[m_iv_size], plain_length); + msg.resize(plain_length); + copy_mem(&msg[0], &m_readbuf[5 + m_iv_size], plain_length); + m_readbuf_pos = 0; return 0; } diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 73c4fd4ab..7fda4bc86 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -1,6 +1,6 @@ /* * TLS Channels -* (C) 2011 Jack Lloyd +* (C) 2011-2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -8,6 +8,7 @@ #include <botan/tls_channel.h> #include <botan/internal/tls_alerts.h> #include <botan/internal/tls_handshake_state.h> +#include <botan/internal/assert.h> #include <botan/loadstor.h> namespace Botan { @@ -42,17 +43,21 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) { try { - reader.add_input(buf, buf_size); + while(buf_size) + { + byte rec_type = CONNECTION_CLOSED; + MemoryVector<byte> record; + size_t consumed = 0; - byte rec_type = CONNECTION_CLOSED; - MemoryVector<byte> record; + const size_t needed = reader.add_input(buf, buf_size, + consumed, + rec_type, record); - while(!reader.currently_empty()) - { - const size_t bytes_needed = reader.get_record(rec_type, record); + buf += consumed; + buf_size -= consumed; - if(bytes_needed > 0) - return bytes_needed; + if(buf_size == 0 && needed != 0) + return needed; // need more data to complete record if(rec_type == APPLICATION_DATA) { @@ -95,7 +100,8 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) } } else - throw Unexpected_Message("Unknown message type received"); + throw Unexpected_Message("Unknown TLS message type " + + to_string(rec_type) + " received"); } return 0; // on a record boundary diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 8e89b9f8a..f4f3e697f 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -59,6 +59,9 @@ class BOTAN_DLL Record_Writer ~Record_Writer() { delete m_mac; } private: + Record_Writer(const Record_Writer&) {} + Record_Writer& operator=(const Record_Writer&) { return (*this); } + void send_record(byte type, const byte input[], size_t length); std::tr1::function<void (const byte[], size_t)> m_output_fn; @@ -80,17 +83,21 @@ class BOTAN_DLL Record_Writer class BOTAN_DLL Record_Reader { public: - void add_input(const byte input[], size_t input_size); /** - * @param msg_type (output variable) - * @param buffer (output variable) - * @return Number of bytes still needed (minimum), or 0 if success + * @param input new input data (may be NULL if input_size == 0) + * @param input_size size of input in bytes + * @param input_consumed is set to the number of bytes of input + * that were consumed + * @param msg_type is set to the type of the message just read if + * this function returns 0 + * @param msg is set to the contents of the record + * @return number of bytes still needed (minimum), or 0 if success */ - size_t get_record(byte& msg_type, - MemoryVector<byte>& buffer); - - SecureVector<byte> get_record(byte& msg_type); + size_t add_input(const byte input[], size_t input_size, + size_t& input_consumed, + byte& msg_type, + MemoryVector<byte>& msg); void activate(const TLS_Cipher_Suite& suite, const SessionKeys& keys, @@ -102,16 +109,22 @@ class BOTAN_DLL Record_Reader void reset(); - bool currently_empty() const { return m_input_queue.size() == 0; } - void set_maximum_fragment_size(size_t max_fragment); Record_Reader(); ~Record_Reader() { delete m_mac; } private: + Record_Reader(const Record_Reader&) {} + Record_Reader& operator=(const Record_Reader&) { return (*this); } + + void consume_input(const byte*& input, + size_t& input_size, + size_t& input_consumed, + size_t desired); + MemoryVector<byte> m_readbuf; - SecureQueue m_input_queue; + size_t m_readbuf_pos; Pipe m_cipher; MessageAuthenticationCode* m_mac; |