aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/rec_read.cpp150
-rw-r--r--src/tls/tls_channel.cpp26
-rw-r--r--src/tls/tls_record.h35
3 files changed, 138 insertions, 73 deletions
diff --git a/src/tls/rec_read.cpp b/src/tls/rec_read.cpp
index 518540bab..957e86aaa 100644
--- a/src/tls/rec_read.cpp
+++ b/src/tls/rec_read.cpp
@@ -19,6 +19,9 @@ Record_Reader::Record_Reader()
m_mac = 0;
reset();
set_maximum_fragment_size(0);
+
+ // A single record is never larger than this
+ m_readbuf.resize(MAX_CIPHERTEXT_SIZE);
}
/*
@@ -31,6 +34,9 @@ void Record_Reader::reset()
delete m_mac;
m_mac = 0;
+ zeroise(m_readbuf);
+ m_readbuf_pos = 0;
+
m_mac_size = 0;
m_block_size = 0;
m_iv_size = 0;
@@ -137,62 +143,91 @@ void Record_Reader::activate(const TLS_Cipher_Suite& suite,
throw Invalid_Argument("Record_Reader: Unknown hash " + mac_algo);
}
-void Record_Reader::add_input(const byte input[], size_t input_size)
+void Record_Reader::consume_input(const byte*& input,
+ size_t& input_size,
+ size_t& input_consumed,
+ size_t desired)
{
- m_input_queue.write(input, input_size);
+ const size_t space_available = (m_readbuf.size() - m_readbuf_pos);
+ const size_t taken = std::min(input_size, desired);
+
+ if(taken > space_available)
+ throw TLS_Exception(RECORD_OVERFLOW,
+ "Record is larger than allowed maximum size");
+
+ copy_mem(&m_readbuf[m_readbuf_pos], input, taken);
+ m_readbuf_pos += taken;
+ input_consumed += taken;
+ input_size -= taken;
+ input += taken;
}
/*
* Retrieve the next record
*/
-size_t Record_Reader::get_record(byte& msg_type,
- MemoryVector<byte>& output)
+size_t Record_Reader::add_input(const byte input_array[], size_t input_size,
+ size_t& input_consumed,
+ byte& msg_type,
+ MemoryVector<byte>& msg)
{
- byte header[5] = { 0 };
+ const byte* input = &input_array[0];
- const size_t have_in_queue = m_input_queue.size();
+ input_consumed = 0;
- if(have_in_queue < sizeof(header))
- return (sizeof(header) - have_in_queue);
+ const size_t HEADER_SIZE = 5;
- /*
- * We peek first to make sure we have the full record
- */
- m_input_queue.peek(header, sizeof(header));
+ if(m_readbuf_pos < HEADER_SIZE) // header incomplete?
+ {
+ consume_input(input, input_size, input_consumed, HEADER_SIZE - m_readbuf_pos);
+
+ if(m_readbuf_pos < HEADER_SIZE)
+ return (HEADER_SIZE - m_readbuf_pos); // header still incomplete
+
+ BOTAN_ASSERT_EQUAL(m_readbuf_pos, HEADER_SIZE,
+ "Buffer error in SSL header");
+ }
// SSLv2-format client hello?
- if(header[0] & 0x80 && header[2] == 1 && header[3] == 3)
+ if(m_readbuf[0] & 0x80 && m_readbuf[2] == 1 && m_readbuf[3] >= 3)
{
- size_t record_len = make_u16bit(header[0], header[1]) & 0x7FFF;
+ size_t record_len = make_u16bit(m_readbuf[0], m_readbuf[1]) & 0x7FFF;
+
+ consume_input(input, input_size, input_consumed, (record_len + 2) - m_readbuf_pos);
+
+ if(m_readbuf_pos < (record_len + 2))
+ return ((record_len + 2) - m_readbuf_pos);
- if(have_in_queue < record_len + 2)
- return (record_len + 2 - have_in_queue);
+ BOTAN_ASSERT_EQUAL(m_readbuf_pos, (record_len + 2),
+ "Buffer error in SSLv2 hello");
msg_type = HANDSHAKE;
- output.resize(record_len + 4);
- m_input_queue.read(&output[2], record_len + 2);
- output[0] = CLIENT_HELLO_SSLV2;
- output[1] = 0;
- output[2] = header[0] & 0x7F;
- output[3] = header[1];
+ msg.resize(record_len + 4);
+ // Fake v3-style handshake message wrapper
+ msg[0] = CLIENT_HELLO_SSLV2;
+ msg[1] = 0;
+ msg[2] = m_readbuf[0] & 0x7F;
+ msg[3] = m_readbuf[1];
+
+ copy_mem(&msg[4], &m_readbuf[2], m_readbuf_pos - 2);
+ m_readbuf_pos = 0;
return 0;
}
- if(header[0] != CHANGE_CIPHER_SPEC &&
- header[0] != ALERT &&
- header[0] != HANDSHAKE &&
- header[0] != APPLICATION_DATA)
+ if(m_readbuf[0] != CHANGE_CIPHER_SPEC &&
+ m_readbuf[0] != ALERT &&
+ m_readbuf[0] != HANDSHAKE &&
+ m_readbuf[0] != APPLICATION_DATA)
{
throw TLS_Exception(UNEXPECTED_MESSAGE,
"Record_Reader: Unknown record type");
}
- const u16bit version = make_u16bit(header[1], header[2]);
- const u16bit record_len = make_u16bit(header[3], header[4]);
+ const u16bit version = make_u16bit(m_readbuf[1], m_readbuf[2]);
+ const u16bit record_len = make_u16bit(m_readbuf[3], m_readbuf[4]);
- if(m_major && (header[1] != m_major || header[2] != m_minor))
+ if(m_major && (m_readbuf[1] != m_major || m_readbuf[2] != m_minor))
throw TLS_Exception(PROTOCOL_VERSION,
"Record_Reader: Got unexpected version");
@@ -200,42 +235,52 @@ size_t Record_Reader::get_record(byte& msg_type,
throw TLS_Exception(RECORD_OVERFLOW,
"Got message that exceeds maximum size");
- // If insufficient data, return without doing anything
- if(have_in_queue < (sizeof(header) + record_len))
- return (sizeof(header) + record_len - have_in_queue);
+ consume_input(input, input_size, input_consumed,
+ (HEADER_SIZE + record_len) - m_readbuf_pos);
+
+ if(m_readbuf_pos < (HEADER_SIZE + record_len))
+ return ((HEADER_SIZE + record_len) - m_readbuf_pos);
+ BOTAN_ASSERT_EQUAL(HEADER_SIZE + record_len, m_readbuf_pos,
+ "Bad buffer handling in record body");
+
+ /*
m_readbuf.resize(record_len);
m_input_queue.read(header, sizeof(header)); // pull off the header
- m_input_queue.read(&m_readbuf[0], m_readbuf.size());
+ m_input_queue.read(&m_readbuf[0], record_len);
+ */
// Null mac means no encryption either, only valid during handshake
if(m_mac_size == 0)
{
- if(header[0] != CHANGE_CIPHER_SPEC &&
- header[0] != ALERT &&
- header[0] != HANDSHAKE)
+ if(m_readbuf[0] != CHANGE_CIPHER_SPEC &&
+ m_readbuf[0] != ALERT &&
+ m_readbuf[0] != HANDSHAKE)
{
throw TLS_Exception(DECODE_ERROR, "Invalid msg type received during handshake");
}
- msg_type = header[0];
- std::swap(output, m_readbuf); // move semantics
+ msg_type = m_readbuf[0];
+ msg.resize(record_len);
+ copy_mem(&msg[0], &m_readbuf[5], record_len);
+
+ m_readbuf_pos = 0;
return 0; // got a full record
}
// Otherwise, decrypt, check MAC, return plaintext
// FIXME: process in-place
- m_cipher.process_msg(m_readbuf);
- size_t got_back = m_cipher.read(&m_readbuf[0], m_readbuf.size(), Pipe::LAST_MESSAGE);
- BOTAN_ASSERT_EQUAL(got_back, m_readbuf.size(), "Cipher didn't decrypt full amount");
+ m_cipher.process_msg(&m_readbuf[5], record_len);
+ size_t got_back = m_cipher.read(&m_readbuf[5], record_len, Pipe::LAST_MESSAGE);
+ BOTAN_ASSERT_EQUAL(got_back, record_len, "Cipher didn't decrypt full amount");
size_t pad_size = 0;
if(m_block_size)
{
- byte pad_value = m_readbuf[m_readbuf.size()-1];
+ byte pad_value = m_readbuf[5 + (record_len-1)];
pad_size = pad_value + 1;
/*
@@ -256,7 +301,7 @@ size_t Record_Reader::get_record(byte& msg_type,
bool padding_good = true;
for(size_t i = 0; i != pad_size; ++i)
- if(m_readbuf[m_readbuf.size()-i-1] != pad_value)
+ if(m_readbuf[5 + (record_len-i-1)] != pad_value)
padding_good = false;
if(!padding_good)
@@ -264,41 +309,42 @@ size_t Record_Reader::get_record(byte& msg_type,
}
}
- if(m_readbuf.size() < m_mac_size + pad_size + m_iv_size)
+ if(record_len < m_mac_size + pad_size + m_iv_size)
throw Decoding_Error("Record_Reader: Record truncated");
- const u16bit plain_length = m_readbuf.size() - (m_mac_size + pad_size + m_iv_size);
+ const u16bit plain_length = record_len - (m_mac_size + pad_size + m_iv_size);
if(plain_length > m_max_fragment)
throw TLS_Exception(RECORD_OVERFLOW, "Plaintext record is too large");
m_mac->update_be(m_seq_no);
- m_mac->update(header[0]); // msg_type
+ m_mac->update(m_readbuf[0]); // msg_type
if(version != SSL_V3)
for(size_t i = 0; i != 2; ++i)
m_mac->update(get_byte(i, version));
m_mac->update_be(plain_length);
- m_mac->update(&m_readbuf[m_iv_size], plain_length);
+ m_mac->update(&m_readbuf[5 + m_iv_size], plain_length);
++m_seq_no;
MemoryVector<byte> computed_mac = m_mac->final();
- const size_t mac_offset = m_readbuf.size() - (m_mac_size + pad_size);
+ const size_t mac_offset = record_len - (m_mac_size + pad_size);
if(computed_mac.size() != m_mac_size)
throw TLS_Exception(INTERNAL_ERROR,
"MAC produced value of unexpected size");
- if(!same_mem(&m_readbuf[mac_offset], &computed_mac[0], m_mac_size))
+ if(!same_mem(&m_readbuf[5 + mac_offset], &computed_mac[0], m_mac_size))
throw TLS_Exception(BAD_RECORD_MAC, "Record_Reader: MAC failure");
- msg_type = header[0];
+ msg_type = m_readbuf[0];
- output.resize(plain_length);
- copy_mem(&output[0], &m_readbuf[m_iv_size], plain_length);
+ msg.resize(plain_length);
+ copy_mem(&msg[0], &m_readbuf[5 + m_iv_size], plain_length);
+ m_readbuf_pos = 0;
return 0;
}
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 73c4fd4ab..7fda4bc86 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -1,6 +1,6 @@
/*
* TLS Channels
-* (C) 2011 Jack Lloyd
+* (C) 2011-2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -8,6 +8,7 @@
#include <botan/tls_channel.h>
#include <botan/internal/tls_alerts.h>
#include <botan/internal/tls_handshake_state.h>
+#include <botan/internal/assert.h>
#include <botan/loadstor.h>
namespace Botan {
@@ -42,17 +43,21 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
{
try
{
- reader.add_input(buf, buf_size);
+ while(buf_size)
+ {
+ byte rec_type = CONNECTION_CLOSED;
+ MemoryVector<byte> record;
+ size_t consumed = 0;
- byte rec_type = CONNECTION_CLOSED;
- MemoryVector<byte> record;
+ const size_t needed = reader.add_input(buf, buf_size,
+ consumed,
+ rec_type, record);
- while(!reader.currently_empty())
- {
- const size_t bytes_needed = reader.get_record(rec_type, record);
+ buf += consumed;
+ buf_size -= consumed;
- if(bytes_needed > 0)
- return bytes_needed;
+ if(buf_size == 0 && needed != 0)
+ return needed; // need more data to complete record
if(rec_type == APPLICATION_DATA)
{
@@ -95,7 +100,8 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
}
}
else
- throw Unexpected_Message("Unknown message type received");
+ throw Unexpected_Message("Unknown TLS message type " +
+ to_string(rec_type) + " received");
}
return 0; // on a record boundary
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 8e89b9f8a..f4f3e697f 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -59,6 +59,9 @@ class BOTAN_DLL Record_Writer
~Record_Writer() { delete m_mac; }
private:
+ Record_Writer(const Record_Writer&) {}
+ Record_Writer& operator=(const Record_Writer&) { return (*this); }
+
void send_record(byte type, const byte input[], size_t length);
std::tr1::function<void (const byte[], size_t)> m_output_fn;
@@ -80,17 +83,21 @@ class BOTAN_DLL Record_Writer
class BOTAN_DLL Record_Reader
{
public:
- void add_input(const byte input[], size_t input_size);
/**
- * @param msg_type (output variable)
- * @param buffer (output variable)
- * @return Number of bytes still needed (minimum), or 0 if success
+ * @param input new input data (may be NULL if input_size == 0)
+ * @param input_size size of input in bytes
+ * @param input_consumed is set to the number of bytes of input
+ * that were consumed
+ * @param msg_type is set to the type of the message just read if
+ * this function returns 0
+ * @param msg is set to the contents of the record
+ * @return number of bytes still needed (minimum), or 0 if success
*/
- size_t get_record(byte& msg_type,
- MemoryVector<byte>& buffer);
-
- SecureVector<byte> get_record(byte& msg_type);
+ size_t add_input(const byte input[], size_t input_size,
+ size_t& input_consumed,
+ byte& msg_type,
+ MemoryVector<byte>& msg);
void activate(const TLS_Cipher_Suite& suite,
const SessionKeys& keys,
@@ -102,16 +109,22 @@ class BOTAN_DLL Record_Reader
void reset();
- bool currently_empty() const { return m_input_queue.size() == 0; }
-
void set_maximum_fragment_size(size_t max_fragment);
Record_Reader();
~Record_Reader() { delete m_mac; }
private:
+ Record_Reader(const Record_Reader&) {}
+ Record_Reader& operator=(const Record_Reader&) { return (*this); }
+
+ void consume_input(const byte*& input,
+ size_t& input_size,
+ size_t& input_consumed,
+ size_t desired);
+
MemoryVector<byte> m_readbuf;
- SecureQueue m_input_queue;
+ size_t m_readbuf_pos;
Pipe m_cipher;
MessageAuthenticationCode* m_mac;