aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2019-07-05 07:20:52 -0400
committerJack Lloyd <[email protected]>2019-07-05 07:20:52 -0400
commite0e13b7ee7fea7358939116a6e496da81356f0bb (patch)
tree8c209244b3baf6d49bb42e7e86948b8303d13362 /src/lib
parent51d9595d6842747c7723a5ebb4ac43054bed4e2a (diff)
parentb6c81dd38d60327d6e6118599f163933d6eee256 (diff)
Merge GH #2021 TLS record layer cleanups
Diffstat (limited to 'src/lib')
-rw-r--r--src/lib/tls/tls_channel.cpp60
-rw-r--r--src/lib/tls/tls_channel.h1
-rw-r--r--src/lib/tls/tls_record.cpp177
-rw-r--r--src/lib/tls/tls_record.h131
4 files changed, 173 insertions, 196 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index ced5dd3f1..eef9270eb 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -305,22 +305,20 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
{
while(!is_closed() && input_size)
{
- secure_vector<uint8_t> record_data;
- uint64_t record_sequence = 0;
- Record_Type record_type = NO_RECORD;
- Protocol_Version record_version;
-
size_t consumed = 0;
- Record_Raw_Input raw_input(input, input_size, consumed, m_is_datagram);
- Record record(record_data, &record_sequence, &record_version, &record_type);
- const size_t needed =
- read_record(m_readbuf,
- raw_input,
- record,
+ const Record_Header record =
+ read_record(m_is_datagram,
+ m_readbuf,
+ input,
+ input_size,
+ consumed,
+ m_record_buf,
m_sequence_numbers.get(),
[this](uint16_t epoch) { return read_cipher_state_epoch(epoch); });
+ const size_t needed = record.needed();
+
BOTAN_ASSERT(consumed > 0, "Got to eat something");
BOTAN_ASSERT(consumed <= input_size,
@@ -332,20 +330,20 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
BOTAN_ASSERT(input_size == 0 || needed == 0,
"Got a full record or consumed all input");
- // Ignore invalid records in DTLS
- if(m_is_datagram && *record.get_type() == NO_RECORD)
- return 0;
-
if(input_size == 0 && needed != 0)
return needed; // need more data to complete record
- if(record_data.size() > MAX_PLAINTEXT_SIZE)
+ // Ignore invalid records in DTLS
+ if(m_is_datagram && record.type() == NO_RECORD)
+ return 0;
+
+ if(m_record_buf.size() > MAX_PLAINTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
"TLS plaintext record is larger than allowed maximum");
if(auto pending = pending_state())
{
- if(pending->server_hello() != nullptr && record_version != pending->version())
+ if(pending->server_hello() != nullptr && record.version() != pending->version())
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Received unexpected record version");
@@ -353,7 +351,7 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
}
else if(auto active = active_state())
{
- if(record_version != active->version())
+ if(record.version() != active->version())
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Received unexpected record version");
@@ -362,31 +360,31 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size)
else
{
// For initial records just check for basic sanity
- if(record_version.major_version() != 3 &&
- record_version.major_version() != 0xFE)
+ if(record.version().major_version() != 3 &&
+ record.version().major_version() != 0xFE)
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Received unexpected record version in initial record");
}
}
- if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC)
+ if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC)
{
- process_handshake_ccs(record_data, record_sequence, record_type, record_version);
+ process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version());
}
- else if(record_type == APPLICATION_DATA)
+ else if(record.type() == APPLICATION_DATA)
{
if(pending_state() != nullptr)
throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "Can't interleave application and handshake data");
- process_application_data(record_sequence, record_data);
+ process_application_data(record.sequence(), m_record_buf);
}
- else if(record_type == ALERT)
+ else if(record.type() == ALERT)
{
- process_alert(record_data);
+ process_alert(m_record_buf);
}
- else if(record_type != NO_RECORD)
+ else if(record.type() != NO_RECORD)
throw Unexpected_Message("Unexpected record type " +
- std::to_string(record_type) +
+ std::to_string(record.type()) +
" from counterparty");
}
@@ -520,12 +518,12 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch
const Protocol_Version record_version =
(m_pending_state) ? (m_pending_state->version()) : (m_active_state->version());
- Record_Message record_message(record_type, 0, input, length);
-
TLS::write_record(m_writebuf,
- record_message,
+ record_type,
record_version,
sequence_numbers().next_write_sequence(epoch),
+ input,
+ length,
cipher_state,
m_rng);
diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h
index 63cbcf0fc..2a2b74332 100644
--- a/src/lib/tls/tls_channel.h
+++ b/src/lib/tls/tls_channel.h
@@ -303,6 +303,7 @@ class BOTAN_PUBLIC_API(2,0) Channel
/* I/O buffers */
secure_vector<uint8_t> m_writebuf;
secure_vector<uint8_t> m_readbuf;
+ secure_vector<uint8_t> m_record_buf;
};
}
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 27714af0b..3304b70eb 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -189,41 +189,43 @@ inline void append_u16_len(secure_vector<uint8_t>& output, size_t len_field)
}
void write_record(secure_vector<uint8_t>& output,
- Record_Message msg,
+ uint8_t record_type,
Protocol_Version version,
- uint64_t seq,
+ uint64_t record_sequence,
+ const uint8_t* message,
+ size_t message_len,
Connection_Cipher_State* cs,
RandomNumberGenerator& rng)
{
output.clear();
- output.push_back(msg.get_type());
+ output.push_back(record_type);
output.push_back(version.major_version());
output.push_back(version.minor_version());
if(version.is_datagram_protocol())
{
for(size_t i = 0; i != 8; ++i)
- output.push_back(get_byte(i, seq));
+ output.push_back(get_byte(i, record_sequence));
}
if(!cs) // initial unencrypted handshake records
{
- append_u16_len(output, msg.get_size());
- output.insert(output.end(), msg.get_data(), msg.get_data() + msg.get_size());
+ append_u16_len(output, message_len);
+ output.insert(output.end(), message, message + message_len);
return;
}
AEAD_Mode& aead = cs->aead();
- std::vector<uint8_t> aad = cs->format_ad(seq, msg.get_type(), version, static_cast<uint16_t>(msg.get_size()));
+ std::vector<uint8_t> aad = cs->format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len));
- const size_t ctext_size = aead.output_length(msg.get_size());
+ const size_t ctext_size = aead.output_length(message_len);
const size_t rec_size = ctext_size + cs->nonce_bytes_from_record();
aead.set_ad(aad);
- const std::vector<uint8_t> nonce = cs->aead_nonce(seq, rng);
+ const std::vector<uint8_t> nonce = cs->aead_nonce(record_sequence, rng);
append_u16_len(output, rec_size);
@@ -236,7 +238,7 @@ void write_record(secure_vector<uint8_t>& output,
}
const size_t header_size = output.size();
- output += std::make_pair(msg.get_data(), msg.get_size());
+ output += std::make_pair(message, message_len);
aead.start(nonce);
aead.finish(output, header_size);
@@ -300,35 +302,36 @@ void decrypt_record(secure_vector<uint8_t>& output,
aead.start(nonce);
- const size_t offset = output.size();
- output += std::make_pair(msg, msg_length);
- aead.finish(output, offset);
+ output.assign(msg, msg + msg_length);
+ aead.finish(output, 0);
}
-size_t read_tls_record(secure_vector<uint8_t>& readbuf,
- Record_Raw_Input& raw_input,
- Record& rec,
- Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate)
+Record_Header read_tls_record(secure_vector<uint8_t>& readbuf,
+ const uint8_t input[],
+ size_t input_len,
+ size_t& consumed,
+ secure_vector<uint8_t>& recbuf,
+ Connection_Sequence_Numbers* sequence_numbers,
+ get_cipherstate_fn get_cipherstate)
{
if(readbuf.size() < TLS_HEADER_SIZE) // header incomplete?
{
- if(size_t needed = fill_buffer_to(readbuf,
- raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(),
- TLS_HEADER_SIZE))
- return needed;
+ if(size_t needed = fill_buffer_to(readbuf, input, input_len, consumed, TLS_HEADER_SIZE))
+ {
+ return Record_Header(needed);
+ }
BOTAN_ASSERT_EQUAL(readbuf.size(), TLS_HEADER_SIZE, "Have an entire header");
}
- *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]);
+ const Protocol_Version version(readbuf[1], readbuf[2]);
- if(rec.get_protocol_version()->is_datagram_protocol())
+ if(version.is_datagram_protocol())
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Expected TLS but got a record with DTLS version");
const size_t record_size = make_uint16(readbuf[TLS_HEADER_SIZE-2],
- readbuf[TLS_HEADER_SIZE-1]);
+ readbuf[TLS_HEADER_SIZE-1]);
if(record_size > MAX_CIPHERTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
@@ -338,38 +341,36 @@ size_t read_tls_record(secure_vector<uint8_t>& readbuf,
throw TLS_Exception(Alert::DECODE_ERROR,
"Received a completely empty record");
- if(size_t needed = fill_buffer_to(readbuf,
- raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(),
- TLS_HEADER_SIZE + record_size))
- return needed;
+ if(size_t needed = fill_buffer_to(readbuf, input, input_len, consumed, TLS_HEADER_SIZE + record_size))
+ {
+ return Record_Header(needed);
+ }
BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_size,
readbuf.size(),
"Have the full record");
- *rec.get_type() = static_cast<Record_Type>(readbuf[0]);
+ const Record_Type type = static_cast<Record_Type>(readbuf[0]);
uint16_t epoch = 0;
+ uint64_t sequence = 0;
if(sequence_numbers)
{
- *rec.get_sequence() = sequence_numbers->next_read_sequence();
+ sequence = sequence_numbers->next_read_sequence();
epoch = sequence_numbers->current_read_epoch();
}
else
{
// server initial handshake case
- *rec.get_sequence() = 0;
epoch = 0;
}
- uint8_t* record_contents = &readbuf[TLS_HEADER_SIZE];
-
if(epoch == 0) // Unencrypted initial handshake
{
- rec.get_data().assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_size);
+ recbuf.assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_size);
readbuf.clear();
- return 0; // got a full record
+ return Record_Header(sequence, version, type);
}
// Otherwise, decrypt, check MAC, return plaintext
@@ -377,45 +378,46 @@ size_t read_tls_record(secure_vector<uint8_t>& readbuf,
BOTAN_ASSERT(cs, "Have cipherstate for this epoch");
- decrypt_record(rec.get_data(),
- record_contents,
+ decrypt_record(recbuf,
+ &readbuf[TLS_HEADER_SIZE],
record_size,
- *rec.get_sequence(),
- *rec.get_protocol_version(),
- *rec.get_type(),
+ sequence,
+ version,
+ type,
*cs);
if(sequence_numbers)
- sequence_numbers->read_accept(*rec.get_sequence());
+ sequence_numbers->read_accept(sequence);
readbuf.clear();
- return 0;
+ return Record_Header(sequence, version, type);
}
-size_t read_dtls_record(secure_vector<uint8_t>& readbuf,
- Record_Raw_Input& raw_input,
- Record& rec,
- Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate)
+Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf,
+ const uint8_t input[],
+ size_t input_len,
+ size_t& consumed,
+ secure_vector<uint8_t>& recbuf,
+ Connection_Sequence_Numbers* sequence_numbers,
+ get_cipherstate_fn get_cipherstate)
{
if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete?
{
- if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE))
+ if(fill_buffer_to(readbuf, input, input_len, consumed, DTLS_HEADER_SIZE))
{
readbuf.clear();
- return 0;
+ return Record_Header(0);
}
BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header");
}
- *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]);
+ const Protocol_Version version(readbuf[1], readbuf[2]);
- if(rec.get_protocol_version()->is_datagram_protocol() == false)
+ if(version.is_datagram_protocol() == false)
{
readbuf.clear();
- *rec.get_type() = NO_RECORD;
- return 0;
+ return Record_Header(0);
}
const size_t record_size = make_uint16(readbuf[DTLS_HEADER_SIZE-2],
@@ -425,44 +427,39 @@ size_t read_dtls_record(secure_vector<uint8_t>& readbuf,
{
// Too large to be valid, ignore it
readbuf.clear();
- *rec.get_type() = NO_RECORD;
- return 0;
+ return Record_Header(0);
}
- if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE + record_size))
+ if(fill_buffer_to(readbuf, input, input_len, consumed, DTLS_HEADER_SIZE + record_size))
{
// Truncated packet?
readbuf.clear();
- *rec.get_type() = NO_RECORD;
- return 0;
+ return Record_Header(0);
}
BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_size, readbuf.size(),
"Have the full record");
- *rec.get_type() = static_cast<Record_Type>(readbuf[0]);
+ const Record_Type type = static_cast<Record_Type>(readbuf[0]);
uint16_t epoch = 0;
- *rec.get_sequence() = load_be<uint64_t>(&readbuf[3], 0);
- epoch = (*rec.get_sequence() >> 48);
+ const uint64_t sequence = load_be<uint64_t>(&readbuf[3], 0);
+ epoch = (sequence >> 48);
- if(sequence_numbers && sequence_numbers->already_seen(*rec.get_sequence()))
+ if(sequence_numbers && sequence_numbers->already_seen(sequence))
{
readbuf.clear();
- *rec.get_type() = NO_RECORD;
- return 0;
+ return Record_Header(0);
}
- uint8_t* record_contents = &readbuf[DTLS_HEADER_SIZE];
-
if(epoch == 0) // Unencrypted initial handshake
{
- rec.get_data().assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_size);
+ recbuf.assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_size);
readbuf.clear();
if(sequence_numbers)
- sequence_numbers->read_accept(*rec.get_sequence());
- return 0; // got a full record
+ sequence_numbers->read_accept(sequence);
+ return Record_Header(sequence, version, type);
}
try
@@ -472,42 +469,44 @@ size_t read_dtls_record(secure_vector<uint8_t>& readbuf,
BOTAN_ASSERT(cs, "Have cipherstate for this epoch");
- decrypt_record(rec.get_data(),
- record_contents,
+ decrypt_record(recbuf,
+ &readbuf[DTLS_HEADER_SIZE],
record_size,
- *rec.get_sequence(),
- *rec.get_protocol_version(),
- *rec.get_type(),
+ sequence,
+ version,
+ type,
*cs);
}
catch(std::exception&)
{
readbuf.clear();
- *rec.get_type() = NO_RECORD;
- return 0;
+ return Record_Header(0);
}
if(sequence_numbers)
- sequence_numbers->read_accept(*rec.get_sequence());
+ sequence_numbers->read_accept(sequence);
readbuf.clear();
- return 0;
+ return Record_Header(sequence, version, type);
}
}
-size_t read_record(secure_vector<uint8_t>& readbuf,
- Record_Raw_Input& raw_input,
- Record& rec,
- Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate)
+Record_Header read_record(bool is_datagram,
+ secure_vector<uint8_t>& readbuf,
+ const uint8_t input[],
+ size_t input_len,
+ size_t& consumed,
+ secure_vector<uint8_t>& recbuf,
+ Connection_Sequence_Numbers* sequence_numbers,
+ get_cipherstate_fn get_cipherstate)
{
- if(raw_input.is_datagram())
- return read_dtls_record(readbuf, raw_input, rec,
- sequence_numbers, get_cipherstate);
+ if(is_datagram)
+ return read_dtls_record(readbuf, input, input_len, consumed,
+ recbuf, sequence_numbers, get_cipherstate);
else
- return read_tls_record(readbuf, raw_input, rec,
- sequence_numbers, get_cipherstate);
+ return read_tls_record(readbuf, input, input_len, consumed,
+ recbuf, sequence_numbers, get_cipherstate);
}
}
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index d0ffc0270..3e3475c03 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -77,93 +77,69 @@ class Connection_Cipher_State final
size_t m_nonce_bytes_from_record;
};
-class Record final
+class Record_Header final
{
public:
- Record(secure_vector<uint8_t>& data,
- uint64_t* sequence,
- Protocol_Version* protocol_version,
- Record_Type* type)
- : m_data(data), m_sequence(sequence), m_protocol_version(protocol_version),
- m_type(type), m_size(data.size()) {}
-
- secure_vector<uint8_t>& get_data() { return m_data; }
-
- Protocol_Version* get_protocol_version() { return m_protocol_version; }
-
- uint64_t* get_sequence() { return m_sequence; }
-
- Record_Type* get_type() { return m_type; }
-
- size_t& get_size() { return m_size; }
-
- private:
- secure_vector<uint8_t>& m_data;
- uint64_t* m_sequence;
- Protocol_Version* m_protocol_version;
- Record_Type* m_type;
- size_t m_size;
- };
+ Record_Header(uint64_t sequence,
+ Protocol_Version version,
+ Record_Type type) :
+ m_needed(0),
+ m_sequence(sequence),
+ m_version(version),
+ m_type(type)
+ {}
+
+ Record_Header(size_t needed) :
+ m_needed(needed),
+ m_sequence(0),
+ m_version(Protocol_Version()),
+ m_type(NO_RECORD)
+ {}
+
+ size_t needed() const { return m_needed; }
+
+ Protocol_Version version() const
+ {
+ BOTAN_ASSERT_NOMSG(m_needed == 0);
+ return m_version;
+ }
-class Record_Message final
- {
- public:
- Record_Message(const uint8_t* data, size_t size)
- : m_type(0), m_sequence(0), m_data(data), m_size(size) {}
- Record_Message(uint8_t type, uint64_t sequence, const uint8_t* data, size_t size)
- : m_type(type), m_sequence(sequence), m_data(data),
- m_size(size) {}
+ uint64_t sequence() const
+ {
+ BOTAN_ASSERT_NOMSG(m_needed == 0);
+ return m_sequence;
+ }
- uint8_t& get_type() { return m_type; }
- uint64_t& get_sequence() { return m_sequence; }
- const uint8_t* get_data() { return m_data; }
- size_t& get_size() { return m_size; }
+ Record_Type type() const
+ {
+ BOTAN_ASSERT_NOMSG(m_needed == 0);
+ return m_type;
+ }
private:
- uint8_t m_type;
+ size_t m_needed;
uint64_t m_sequence;
- const uint8_t* m_data;
- size_t m_size;
-};
-
-class Record_Raw_Input final
- {
- public:
- Record_Raw_Input(const uint8_t* data, size_t size, size_t& consumed,
- bool is_datagram)
- : m_data(data), m_size(size), m_consumed(consumed),
- m_is_datagram(is_datagram) {}
-
- const uint8_t*& get_data() { return m_data; }
-
- size_t& get_size() { return m_size; }
-
- size_t& get_consumed() { return m_consumed; }
- void set_consumed(size_t consumed) { m_consumed = consumed; }
-
- bool is_datagram() { return m_is_datagram; }
-
- private:
- const uint8_t* m_data;
- size_t m_size;
- size_t& m_consumed;
- bool m_is_datagram;
+ Protocol_Version m_version;
+ Record_Type m_type;
};
-
/**
* Create a TLS record
* @param write_buffer the output record is placed here
-* @param rec_msg is the plaintext message
-* @param version is the protocol version
-* @param msg_sequence is the sequence number
+* @param record_type the record layer type
+* @param record_version the record layer version
+* @param record_sequence the record layer sequence number
+* @param message the record contents
+* @param message_len is size of message
* @param cipherstate is the writing cipher state
* @param rng is a random number generator
*/
void write_record(secure_vector<uint8_t>& write_buffer,
- Record_Message rec_msg,
- Protocol_Version version,
- uint64_t msg_sequence,
+ uint8_t record_type,
+ Protocol_Version record_version,
+ uint64_t record_sequence,
+ const uint8_t* message,
+ size_t message_len,
Connection_Cipher_State* cipherstate,
RandomNumberGenerator& rng);
@@ -174,11 +150,14 @@ typedef std::function<std::shared_ptr<Connection_Cipher_State> (uint16_t)> get_c
* Decode a TLS record
* @return zero if full message, else number of bytes still needed
*/
-size_t read_record(secure_vector<uint8_t>& read_buffer,
- Record_Raw_Input& raw_input,
- Record& rec,
- Connection_Sequence_Numbers* sequence_numbers,
- get_cipherstate_fn get_cipherstate);
+Record_Header read_record(bool is_datagram,
+ secure_vector<uint8_t>& read_buffer,
+ const uint8_t input[],
+ size_t input_len,
+ size_t& consumed,
+ secure_vector<uint8_t>& record_buf,
+ Connection_Sequence_Numbers* sequence_numbers,
+ get_cipherstate_fn get_cipherstate);
}