aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-12-10 17:53:19 +0000
committerlloyd <[email protected]>2012-12-10 17:53:19 +0000
commitb0164f672c53ba10cf24a5f5a502d9aac4746161 (patch)
treeb94baba373f943ad68d860acb939ffb9277139da
parent71e60dbdb404b715532a9e5d70efdff393602470 (diff)
parent79d3cfa5fd64ed4cfaa0643bb318edd38f22de92 (diff)
merge of '2a4d641c566916555a5127b4ba82a1fa9f9e2b0c'
and '59030896322f59cfd47ba0ff17993ccd263174c6'
-rw-r--r--src/tls/tls_channel.cpp54
-rw-r--r--src/tls/tls_handshake_io.cpp40
-rw-r--r--src/tls/tls_handshake_io.h16
-rw-r--r--src/tls/tls_record.cpp53
-rw-r--r--src/tls/tls_record.h50
5 files changed, 123 insertions, 90 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 5858f5d90..6bbf60a24 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -254,7 +254,7 @@ bool Channel::heartbeat_sending_allowed() const
return false;
}
-size_t Channel::received_data(const byte buf[], size_t buf_size)
+size_t Channel::received_data(const byte input[], size_t input_size)
{
const auto get_cipherstate = [this](u16bit epoch)
{ return this->read_cipher_state_epoch(epoch).get(); };
@@ -263,57 +263,49 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
try
{
- while(!is_closed() && buf_size)
+ while(!is_closed() && input_size)
{
- byte rec_type = NO_RECORD;
- std::vector<byte> record;
- u64bit record_sequence = 0;
- Protocol_Version record_version;
+ Record record;
size_t consumed = 0;
const size_t needed =
read_record(m_readbuf,
- buf,
- buf_size,
+ input,
+ input_size,
consumed,
- rec_type,
record,
- record_version,
- record_sequence,
m_sequence_numbers.get(),
get_cipherstate);
- BOTAN_ASSERT(consumed <= buf_size,
+ BOTAN_ASSERT(consumed <= input_size,
"Record reader consumed sane amount");
- buf += consumed;
- buf_size -= consumed;
+ input += consumed;
+ input_size -= consumed;
- BOTAN_ASSERT(buf_size == 0 || needed == 0,
+ BOTAN_ASSERT(input_size == 0 || needed == 0,
"Got a full record or consumed all input");
- if(buf_size == 0 && needed != 0)
+ if(input_size == 0 && needed != 0)
return needed; // need more data to complete record
- if(rec_type == NO_RECORD)
- continue;
+ BOTAN_ASSERT(record.is_valid(), "Got a full record");
if(record.size() > max_fragment_size)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
"Plaintext record is too large");
- if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
+ if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC)
{
if(!m_pending_state)
{
- create_handshake_state(record_version);
- if(record_version.is_datagram_protocol())
- sequence_numbers().read_accept(record_sequence);
+ create_handshake_state(record.version());
+ if(record.version().is_datagram_protocol())
+ sequence_numbers().read_accept(record.sequence());
}
- m_pending_state->handshake_io().add_input(
- rec_type, &record[0], record.size(), record_sequence);
+ m_pending_state->handshake_io().add_record(record);
while(auto pending = m_pending_state.get())
{
@@ -326,12 +318,12 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
msg.first, msg.second);
}
}
- else if(rec_type == HEARTBEAT && peer_supports_heartbeats())
+ else if(record.type() == HEARTBEAT && peer_supports_heartbeats())
{
if(!active_state())
throw Unexpected_Message("Heartbeat sent before handshake done");
- Heartbeat_Message heartbeat(record);
+ Heartbeat_Message heartbeat(record.contents());
const std::vector<byte>& payload = heartbeat.payload();
@@ -351,7 +343,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
m_proc_fn(&payload[0], payload.size(), Alert(Alert::HEARTBEAT_PAYLOAD));
}
}
- else if(rec_type == APPLICATION_DATA)
+ else if(record.type() == APPLICATION_DATA)
{
if(!active_state())
throw Unexpected_Message("Application data before handshake done");
@@ -362,11 +354,11 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
* following record. Avoid spurious callbacks.
*/
if(record.size() > 0)
- m_proc_fn(&record[0], record.size(), Alert());
+ m_proc_fn(record.bits(), record.size(), Alert());
}
- else if(rec_type == ALERT)
+ else if(record.type() == ALERT)
{
- Alert alert_msg(record);
+ Alert alert_msg(record.contents());
if(alert_msg.type() == Alert::NO_RENEGOTIATION)
m_pending_state.reset();
@@ -392,7 +384,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
}
else
throw Unexpected_Message("Unexpected record type " +
- std::to_string(rec_type) +
+ std::to_string(record.type()) +
" from counterparty");
}
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index 1fae7b5b7..c685c80ef 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -7,6 +7,7 @@
#include <botan/internal/tls_handshake_io.h>
#include <botan/internal/tls_messages.h>
+#include <botan/internal/tls_record.h>
#include <botan/internal/tls_seq_numbers.h>
#include <botan/exceptn.h>
@@ -38,18 +39,15 @@ Protocol_Version Stream_Handshake_IO::initial_record_version() const
return Protocol_Version::TLS_V10;
}
-void Stream_Handshake_IO::add_input(const byte rec_type,
- const byte record[],
- size_t record_size,
- u64bit /*record_number*/)
+void Stream_Handshake_IO::add_record(const Record& record)
{
- if(rec_type == HANDSHAKE)
+ if(record.type() == HANDSHAKE)
{
- m_queue.insert(m_queue.end(), record, record + record_size);
+ m_queue.insert(m_queue.end(), record.bits(), record.bits() + record.size());
}
- else if(rec_type == CHANGE_CIPHER_SPEC)
+ else if(record.type() == CHANGE_CIPHER_SPEC)
{
- if(record_size != 1 || record[0] != 1)
+ if(record.size() != 1 || record.bits()[0] != 1)
throw Decoding_Error("Invalid ChangeCipherSpec");
// Pretend it's a regular handshake message of zero length
@@ -120,14 +118,11 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const
return Protocol_Version::DTLS_V10;
}
-void Datagram_Handshake_IO::add_input(const byte rec_type,
- const byte record[],
- size_t record_size,
- u64bit record_number)
+void Datagram_Handshake_IO::add_record(const Record& record)
{
- const u16bit epoch = static_cast<u16bit>(record_number >> 48);
+ const u16bit epoch = static_cast<u16bit>(record.sequence() >> 48);
- if(rec_type == CHANGE_CIPHER_SPEC)
+ if(record.type() == CHANGE_CIPHER_SPEC)
{
m_ccs_epochs.insert(epoch);
return;
@@ -135,16 +130,19 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
+ const byte* record_bits = record.bits();
+ size_t record_size = record.size();
+
while(record_size)
{
if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
return; // completely bogus? at least degenerate/weird
- const byte msg_type = record[0];
- const size_t msg_len = load_be24(&record[1]);
- const u16bit message_seq = load_be<u16bit>(&record[4], 0);
- const size_t fragment_offset = load_be24(&record[6]);
- const size_t fragment_length = load_be24(&record[9]);
+ const byte msg_type = record_bits[0];
+ const size_t msg_len = load_be24(&record_bits[1]);
+ const u16bit message_seq = load_be<u16bit>(&record_bits[4], 0);
+ const size_t fragment_offset = load_be24(&record_bits[6]);
+ const size_t fragment_length = load_be24(&record_bits[9]);
const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
@@ -153,7 +151,7 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
if(message_seq >= m_in_message_seq)
{
- m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
+ m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
fragment_length,
fragment_offset,
epoch,
@@ -161,7 +159,7 @@ void Datagram_Handshake_IO::add_input(const byte rec_type,
msg_len);
}
- record += total_size;
+ record_bits += total_size;
record_size -= total_size;
}
}
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index 18fde1a83..82d1a8e7e 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -24,6 +24,7 @@ namespace Botan {
namespace TLS {
class Handshake_Message;
+class Record;
/**
* Handshake IO Interface
@@ -39,10 +40,7 @@ class Handshake_IO
const std::vector<byte>& handshake_msg,
Handshake_Type handshake_type) const = 0;
- virtual void add_input(byte record_type,
- const byte record[],
- size_t record_size,
- u64bit record_number) = 0;
+ virtual void add_record(const Record& record) = 0;
/**
* Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available
@@ -76,10 +74,7 @@ class Stream_Handshake_IO : public Handshake_IO
const std::vector<byte>& handshake_msg,
Handshake_Type handshake_type) const override;
- void add_input(byte record_type,
- const byte record[],
- size_t record_size,
- u64bit record_number) override;
+ void add_record(const Record& record) override;
std::pair<Handshake_Type, std::vector<byte>>
get_next_record(bool expecting_ccs) override;
@@ -106,10 +101,7 @@ class Datagram_Handshake_IO : public Handshake_IO
const std::vector<byte>& handshake_msg,
Handshake_Type handshake_type) const override;
- void add_input(const byte rec_type,
- const byte record[],
- size_t record_size,
- u64bit record_number) override;
+ void add_record(const Record& record) override;
std::pair<Handshake_Type, std::vector<byte>>
get_next_record(bool expecting_ccs) override;
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index fab966e72..0557c1796 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -270,10 +270,7 @@ size_t read_record(std::vector<byte>& readbuf,
const byte input[],
size_t input_sz,
size_t& consumed,
- byte& msg_type,
- std::vector<byte>& msg,
- Protocol_Version& record_version,
- u64bit& record_sequence,
+ Record& record,
Connection_Sequence_Numbers* sequence_numbers,
std::function<Connection_Cipher_State* (u16bit)> get_cipherstate)
{
@@ -309,24 +306,27 @@ size_t read_record(std::vector<byte>& readbuf,
BOTAN_ASSERT_EQUAL(readbuf.size(), (record_len + 2),
"Have the entire SSLv2 hello");
- msg_type = HANDSHAKE;
+ // Fake v3-style handshake message wrapper
+ std::vector<byte> sslv2_hello(4 + readbuf.size() - 2);
- msg.resize(record_len + 4);
+ sslv2_hello[0] = CLIENT_HELLO_SSLV2;
+ sslv2_hello[1] = 0;
+ sslv2_hello[2] = readbuf[0] & 0x7F;
+ sslv2_hello[3] = readbuf[1];
- // Fake v3-style handshake message wrapper
- msg[0] = CLIENT_HELLO_SSLV2;
- msg[1] = 0;
- msg[2] = readbuf[0] & 0x7F;
- msg[3] = readbuf[1];
+ copy_mem(&sslv2_hello[4], &readbuf[2], readbuf.size() - 2);
- copy_mem(&msg[4], &readbuf[2], readbuf.size() - 2);
+ record = Record(0,
+ Protocol_Version::TLS_V10,
+ HANDSHAKE,
+ std::move(sslv2_hello));
readbuf.clear();
return 0;
}
}
- record_version = Protocol_Version(readbuf[1], readbuf[2]);
+ Protocol_Version record_version = Protocol_Version(readbuf[1], readbuf[2]);
const bool is_dtls = record_version.is_datagram_protocol();
@@ -359,6 +359,9 @@ size_t read_record(std::vector<byte>& readbuf,
readbuf.size(),
"Have the full record");
+ Record_Type record_type = static_cast<Record_Type>(readbuf[0]);
+
+ u64bit record_sequence = 0;
u16bit epoch = 0;
if(is_dtls)
@@ -385,8 +388,11 @@ size_t read_record(std::vector<byte>& readbuf,
if(epoch == 0) // Unencrypted initial handshake
{
- msg_type = readbuf[0];
- msg.assign(&record_contents[0], &record_contents[record_len]);
+ record = Record(record_sequence,
+ record_version,
+ record_type,
+ &readbuf[header_size],
+ record_len);
readbuf.clear();
return 0; // got a full record
@@ -453,7 +459,7 @@ size_t read_record(std::vector<byte>& readbuf,
throw Decoding_Error("Record sent with invalid length");
cipherstate->mac()->update_be(record_sequence);
- cipherstate->mac()->update(readbuf[0]); // msg_type
+ cipherstate->mac()->update(static_cast<byte>(record_type));
if(cipherstate->mac_includes_record_version())
{
@@ -461,10 +467,11 @@ size_t read_record(std::vector<byte>& readbuf,
cipherstate->mac()->update(record_version.minor_version());
}
- const u16bit plain_length = record_len - mac_pad_iv_size;
+ const byte* plaintext_block = &record_contents[iv_size];
+ const u16bit plaintext_length = record_len - mac_pad_iv_size;
- cipherstate->mac()->update_be(plain_length);
- cipherstate->mac()->update(&record_contents[iv_size], plain_length);
+ cipherstate->mac()->update_be(plaintext_length);
+ cipherstate->mac()->update(plaintext_block, plaintext_length);
std::vector<byte> mac_buf(mac_size);
cipherstate->mac()->final(&mac_buf[0]);
@@ -481,9 +488,11 @@ size_t read_record(std::vector<byte>& readbuf,
if(sequence_numbers)
sequence_numbers->read_accept(record_sequence);
- msg_type = readbuf[0];
- msg.assign(&record_contents[iv_size],
- &record_contents[iv_size + plain_length]);
+ record = Record(record_sequence,
+ record_version,
+ record_type,
+ plaintext_block,
+ plaintext_length);
readbuf.clear();
return 0;
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index a73efccb1..bc86600fa 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -78,6 +78,51 @@ class Connection_Cipher_State
bool m_is_ssl3 = false;
};
+class Record
+ {
+ public:
+ Record() {}
+
+ Record(u64bit sequence,
+ Protocol_Version version,
+ Record_Type type,
+ const byte contents[],
+ size_t contents_size) :
+ m_sequence(sequence),
+ m_version(version),
+ m_type(type),
+ m_contents(contents, contents + contents_size) {}
+
+ Record(u64bit sequence,
+ Protocol_Version version,
+ Record_Type type,
+ std::vector<byte>&& contents) :
+ m_sequence(sequence),
+ m_version(version),
+ m_type(type),
+ m_contents(contents) {}
+
+ bool is_valid() const { return m_type != NO_RECORD; }
+
+ u64bit sequence() const { return m_sequence; }
+
+ Record_Type type() const { return m_type; }
+
+ Protocol_Version version() const { return m_version; }
+
+ const std::vector<byte>& contents() const { return m_contents; }
+
+ const byte* bits() const { return &m_contents[0]; }
+
+ size_t size() const { return m_contents.size(); }
+
+ private:
+ u64bit m_sequence = 0;
+ Protocol_Version m_version = Protocol_Version();
+ Record_Type m_type = NO_RECORD;
+ std::vector<byte> m_contents;
+ };
+
/**
* Create a TLS record
* @param write_buffer the output record is placed here
@@ -105,10 +150,7 @@ size_t read_record(std::vector<byte>& read_buffer,
const byte input[],
size_t input_length,
size_t& input_consumed,
- byte& msg_type,
- std::vector<byte>& msg,
- Protocol_Version& record_version,
- u64bit& record_sequence,
+ Record& output_record,
Connection_Sequence_Numbers* sequence_numbers,
std::function<Connection_Cipher_State* (u16bit)> get_cipherstate);