diff options
author | lloyd <[email protected]> | 2012-01-05 23:01:06 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-01-05 23:01:06 +0000 |
commit | f452ca334eeb469d13d816c43227a7ea2f49efeb (patch) | |
tree | 51b21923652a596d3d04f6e24ff601e32ff97eb6 /src/tls/rec_read.cpp | |
parent | 74226be019b1a66f8eae9a6516f2eb28a53fb9e2 (diff) |
Make record reading faster (less copying, no queue at all), at the
expense of significant complexity. Needs careful testing for corner
cases and malicious inputs, but seems to work well with randomly
chosen segmentations in a correctly formatted stream at least.
Diffstat (limited to 'src/tls/rec_read.cpp')
-rw-r--r-- | src/tls/rec_read.cpp | 150 |
1 files changed, 98 insertions, 52 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp index 518540bab..957e86aaa 100644 --- a/src/tls/rec_read.cpp +++ b/src/tls/rec_read.cpp @@ -19,6 +19,9 @@ Record_Reader::Record_Reader() m_mac = 0; reset(); set_maximum_fragment_size(0); + + // A single record is never larger than this + m_readbuf.resize(MAX_CIPHERTEXT_SIZE); } /* @@ -31,6 +34,9 @@ void Record_Reader::reset() delete m_mac; m_mac = 0; + zeroise(m_readbuf); + m_readbuf_pos = 0; + m_mac_size = 0; m_block_size = 0; m_iv_size = 0; @@ -137,62 +143,91 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite, throw Invalid_Argument("Record_Reader: Unknown hash " + mac_algo); } -void Record_Reader::add_input(const byte input[], size_t input_size) +void Record_Reader::consume_input(const byte*& input, + size_t& input_size, + size_t& input_consumed, + size_t desired) { - m_input_queue.write(input, input_size); + const size_t space_available = (m_readbuf.size() - m_readbuf_pos); + const size_t taken = std::min(input_size, desired); + + if(taken > space_available) + throw TLS_Exception(RECORD_OVERFLOW, + "Record is larger than allowed maximum size"); + + copy_mem(&m_readbuf[m_readbuf_pos], input, taken); + m_readbuf_pos += taken; + input_consumed += taken; + input_size -= taken; + input += taken; } /* * Retrieve the next record */ -size_t Record_Reader::get_record(byte& msg_type, - MemoryVector<byte>& output) +size_t Record_Reader::add_input(const byte input_array[], size_t input_size, + size_t& input_consumed, + byte& msg_type, + MemoryVector<byte>& msg) { - byte header[5] = { 0 }; + const byte* input = &input_array[0]; - const size_t have_in_queue = m_input_queue.size(); + input_consumed = 0; - if(have_in_queue < sizeof(header)) - return (sizeof(header) - have_in_queue); + const size_t HEADER_SIZE = 5; - /* - * We peek first to make sure we have the full record - */ - m_input_queue.peek(header, sizeof(header)); + if(m_readbuf_pos < HEADER_SIZE) // header incomplete? + { + consume_input(input, input_size, input_consumed, HEADER_SIZE - m_readbuf_pos); + + if(m_readbuf_pos < HEADER_SIZE) + return (HEADER_SIZE - m_readbuf_pos); // header still incomplete + + BOTAN_ASSERT_EQUAL(m_readbuf_pos, HEADER_SIZE, + "Buffer error in SSL header"); + } // SSLv2-format client hello? - if(header[0] & 0x80 && header[2] == 1 && header[3] == 3) + if(m_readbuf[0] & 0x80 && m_readbuf[2] == 1 && m_readbuf[3] >= 3) { - size_t record_len = make_u16bit(header[0], header[1]) & 0x7FFF; + size_t record_len = make_u16bit(m_readbuf[0], m_readbuf[1]) & 0x7FFF; + + consume_input(input, input_size, input_consumed, (record_len + 2) - m_readbuf_pos); + + if(m_readbuf_pos < (record_len + 2)) + return ((record_len + 2) - m_readbuf_pos); - if(have_in_queue < record_len + 2) - return (record_len + 2 - have_in_queue); + BOTAN_ASSERT_EQUAL(m_readbuf_pos, (record_len + 2), + "Buffer error in SSLv2 hello"); msg_type = HANDSHAKE; - output.resize(record_len + 4); - m_input_queue.read(&output[2], record_len + 2); - output[0] = CLIENT_HELLO_SSLV2; - output[1] = 0; - output[2] = header[0] & 0x7F; - output[3] = header[1]; + msg.resize(record_len + 4); + // Fake v3-style handshake message wrapper + msg[0] = CLIENT_HELLO_SSLV2; + msg[1] = 0; + msg[2] = m_readbuf[0] & 0x7F; + msg[3] = m_readbuf[1]; + + copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2); + m_readbuf_pos = 0; return 0; } - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE && - header[0] != APPLICATION_DATA) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE && + m_readbuf[0] != APPLICATION_DATA) { throw TLS_Exception(UNEXPECTED_MESSAGE, "Record_Reader: Unknown record type"); } - const u16bit version = make_u16bit(header[1], header[2]); - const u16bit record_len = make_u16bit(header[3], header[4]); + const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]); + const u16bit record_len = make_u16bit(m_readbuf[3], m_readbuf[4]); - if(m_major && (header[1] != m_major || header[2] != m_minor)) + if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor)) throw TLS_Exception(PROTOCOL_VERSION, "Record_Reader: Got unexpected version"); @@ -200,42 +235,52 @@ size_t Record_Reader::get_record(byte& msg_type, throw TLS_Exception(RECORD_OVERFLOW, "Got message that exceeds maximum size"); - // If insufficient data, return without doing anything - if(have_in_queue < (sizeof(header) + record_len)) - return (sizeof(header) + record_len - have_in_queue); + consume_input(input, input_size, input_consumed, + (HEADER_SIZE + record_len) - m_readbuf_pos); + + if(m_readbuf_pos < (HEADER_SIZE + record_len)) + return ((HEADER_SIZE + record_len) - m_readbuf_pos); + BOTAN_ASSERT_EQUAL(HEADER_SIZE + record_len, m_readbuf_pos, + "Bad buffer handling in record body"); + + /* m_readbuf.resize(record_len); m_input_queue.read(header, sizeof(header)); // pull off the header - m_input_queue.read(&m_readbuf[0], m_readbuf.size()); + m_input_queue.read(&m_readbuf[0], record_len); + */ // Null mac means no encryption either, only valid during handshake if(m_mac_size == 0) { - if(header[0] != CHANGE_CIPHER_SPEC && - header[0] != ALERT && - header[0] != HANDSHAKE) + if(m_readbuf[0] != CHANGE_CIPHER_SPEC && + m_readbuf[0] != ALERT && + m_readbuf[0] != HANDSHAKE) { throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake"); } - msg_type = header[0]; - std::swap(output, m_readbuf); // move semantics + msg_type = m_readbuf[0]; + msg.resize(record_len); + copy_mem(&msg[0], &m_readbuf[5], record_len); + + m_readbuf_pos = 0; return 0; // got a full record } // Otherwise, decrypt, check MAC, return plaintext // FIXME: process in-place - m_cipher.process_msg(m_readbuf); - size_t got_back = m_cipher.read(&m_readbuf[0], m_readbuf.size(), Pipe::LAST_MESSAGE); - BOTAN_ASSERT_EQUAL(got_back, m_readbuf.size(), "Cipher didn't decrypt full amount"); + m_cipher.process_msg(&m_readbuf[5], record_len); + size_t got_back = m_cipher.read(&m_readbuf[5], record_len, Pipe::LAST_MESSAGE); + BOTAN_ASSERT_EQUAL(got_back, record_len, "Cipher didn't decrypt full amount"); size_t pad_size = 0; if(m_block_size) { - byte pad_value = m_readbuf[m_readbuf.size()-1]; + byte pad_value = m_readbuf[5 + (record_len-1)]; pad_size = pad_value + 1; /* @@ -256,7 +301,7 @@ size_t Record_Reader::get_record(byte& msg_type, bool padding_good = true; for(size_t i = 0; i != pad_size; ++i) - if(m_readbuf[m_readbuf.size()-i-1] != pad_value) + if(m_readbuf[5 + (record_len-i-1)] != pad_value) padding_good = false; if(!padding_good) @@ -264,41 +309,42 @@ size_t Record_Reader::get_record(byte& msg_type, } } - if(m_readbuf.size() < m_mac_size + pad_size + m_iv_size) + if(record_len < m_mac_size + pad_size + m_iv_size) throw Decoding_Error("Record_Reader: Record truncated"); - const u16bit plain_length = m_readbuf.size() - (m_mac_size + pad_size + m_iv_size); + const u16bit plain_length = record_len - (m_mac_size + pad_size + m_iv_size); if(plain_length > m_max_fragment) throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large"); m_mac->update_be(m_seq_no); - m_mac->update(header[0]); // msg_type + m_mac->update(m_readbuf[0]); // msg_type if(version != SSL_V3) for(size_t i = 0; i != 2; ++i) m_mac->update(get_byte(i, version)); m_mac->update_be(plain_length); - m_mac->update(&m_readbuf[m_iv_size], plain_length); + m_mac->update(&m_readbuf[5 + m_iv_size], plain_length); ++m_seq_no; MemoryVector<byte> computed_mac = m_mac->final(); - const size_t mac_offset = m_readbuf.size() - (m_mac_size + pad_size); + const size_t mac_offset = record_len - (m_mac_size + pad_size); if(computed_mac.size() != m_mac_size) throw TLS_Exception(INTERNAL_ERROR, "MAC produced value of unexpected size"); - if(!same_mem(&m_readbuf[mac_offset], &computed_mac[0], m_mac_size)) + if(!same_mem(&m_readbuf[5 + mac_offset], &computed_mac[0], m_mac_size)) throw TLS_Exception(BAD_RECORD_MAC, "Record_Reader: MAC failure"); - msg_type = header[0]; + msg_type = m_readbuf[0]; - output.resize(plain_length); - copy_mem(&output[0], &m_readbuf[m_iv_size], plain_length); + msg.resize(plain_length); + copy_mem(&msg[0], &m_readbuf[5 + m_iv_size], plain_length); + m_readbuf_pos = 0; return 0; } |