aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/tls_alert.cpp2
-rw-r--r--src/tls/tls_alert.h2
-rw-r--r--src/tls/tls_channel.cpp36
-rw-r--r--src/tls/tls_record.cpp226
-rw-r--r--src/tls/tls_record.h50
5 files changed, 120 insertions, 196 deletions
diff --git a/src/tls/tls_alert.cpp b/src/tls/tls_alert.cpp
index f548bd57b..15bb2a2dc 100644
--- a/src/tls/tls_alert.cpp
+++ b/src/tls/tls_alert.cpp
@@ -12,7 +12,7 @@ namespace Botan {
namespace TLS {
-Alert::Alert(const std::vector<byte>& buf)
+Alert::Alert(const secure_vector<byte>& buf)
{
if(buf.size() != 2)
throw Decoding_Error("Alert: Bad size " + std::to_string(buf.size()) +
diff --git a/src/tls/tls_alert.h b/src/tls/tls_alert.h
index 12ab57d6b..bf32178ee 100644
--- a/src/tls/tls_alert.h
+++ b/src/tls/tls_alert.h
@@ -90,7 +90,7 @@ class BOTAN_DLL Alert
* Deserialize an Alert message
* @param buf the serialized alert
*/
- Alert(const std::vector<byte>& buf);
+ Alert(const secure_vector<byte>& buf);
/**
* Create a new Alert
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index c00970c49..7c7d65961 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -278,7 +278,10 @@ size_t Channel::received_data(const byte input[], size_t input_size)
{
while(!is_closed() && input_size)
{
- Record record;
+ secure_vector<byte> record;
+ u64bit record_sequence = 0;
+ Record_Type record_type = NO_RECORD;
+ Protocol_Version record_version;
size_t consumed = 0;
@@ -288,6 +291,9 @@ size_t Channel::received_data(const byte input[], size_t input_size)
input_size,
consumed,
record,
+ &record_sequence,
+ &record_version,
+ &record_type,
m_sequence_numbers.get(),
get_cipherstate);
@@ -303,22 +309,22 @@ size_t Channel::received_data(const byte input[], size_t input_size)
if(input_size == 0 && needed != 0)
return needed; // need more data to complete record
- BOTAN_ASSERT(record.is_valid(), "Got a full record");
-
if(record.size() > max_fragment_size)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
"Plaintext record is too large");
- if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC)
+ if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC)
{
if(!m_pending_state)
{
- create_handshake_state(record.version());
- if(record.version().is_datagram_protocol())
- sequence_numbers().read_accept(record.sequence());
+ create_handshake_state(record_version);
+ if(record_version.is_datagram_protocol())
+ sequence_numbers().read_accept(record_sequence);
}
- m_pending_state->handshake_io().add_record(record.contents(), record.type(), record.sequence());
+ m_pending_state->handshake_io().add_record(unlock(record),
+ record_type,
+ record_sequence);
while(auto pending = m_pending_state.get())
{
@@ -331,12 +337,12 @@ size_t Channel::received_data(const byte input[], size_t input_size)
msg.first, msg.second);
}
}
- else if(record.type() == HEARTBEAT && peer_supports_heartbeats())
+ else if(record_type == HEARTBEAT && peer_supports_heartbeats())
{
if(!active_state())
throw Unexpected_Message("Heartbeat sent before handshake done");
- Heartbeat_Message heartbeat(record.contents());
+ Heartbeat_Message heartbeat(unlock(record));
const std::vector<byte>& payload = heartbeat.payload();
@@ -356,7 +362,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
m_proc_fn(&payload[0], payload.size(), Alert(Alert::HEARTBEAT_PAYLOAD));
}
}
- else if(record.type() == APPLICATION_DATA)
+ else if(record_type == APPLICATION_DATA)
{
if(!active_state())
throw Unexpected_Message("Application data before handshake done");
@@ -367,11 +373,11 @@ size_t Channel::received_data(const byte input[], size_t input_size)
* following record. Avoid spurious callbacks.
*/
if(record.size() > 0)
- m_proc_fn(record.bits(), record.size(), Alert());
+ m_proc_fn(&record[0], record.size(), Alert());
}
- else if(record.type() == ALERT)
+ else if(record_type == ALERT)
{
- Alert alert_msg(record.contents());
+ Alert alert_msg(record);
if(alert_msg.type() == Alert::NO_RENEGOTIATION)
m_pending_state.reset();
@@ -395,7 +401,7 @@ size_t Channel::received_data(const byte input[], size_t input_size)
}
else
throw Unexpected_Message("Unexpected record type " +
- std::to_string(record.type()) +
+ std::to_string(record_type) +
" from counterparty");
}
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index d9b222a85..2bfe78152 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -297,60 +297,6 @@ size_t fill_buffer_to(secure_vector<byte>& readbuf,
}
/*
-* MAC scheme used in SSLv3/TLSv1 for RC4 and CBC ciphers
-*/
-bool traditional_mac_check(Record& output_record,
- byte record_contents[], size_t record_len,
- size_t pad_size,
- volatile bool padding_bad,
- u64bit record_sequence,
- Protocol_Version record_version,
- Record_Type record_type,
- Connection_Cipher_State& cipherstate)
- {
- const size_t mac_size = cipherstate.mac_size();
- const size_t iv_size = cipherstate.iv_size();
-
- cipherstate.mac()->update_be(record_sequence);
- cipherstate.mac()->update(static_cast<byte>(record_type));
-
- if(cipherstate.mac_includes_record_version())
- {
- cipherstate.mac()->update(record_version.major_version());
- cipherstate.mac()->update(record_version.minor_version());
- }
-
- const size_t mac_pad_iv_size = mac_size + pad_size + iv_size;
-
- if(record_len < mac_pad_iv_size)
- throw Decoding_Error("Record sent with invalid length");
-
- const byte* plaintext_block = &record_contents[iv_size];
- const u16bit plaintext_length = record_len - mac_pad_iv_size;
-
- cipherstate.mac()->update_be(plaintext_length);
- cipherstate.mac()->update(plaintext_block, plaintext_length);
-
- std::vector<byte> mac_buf(mac_size);
- cipherstate.mac()->final(&mac_buf[0]);
-
- const size_t mac_offset = record_len - (mac_size + pad_size);
-
- const bool mac_bad = !same_mem(&record_contents[mac_offset], &mac_buf[0], mac_size);
-
- if(mac_bad || padding_bad)
- throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure");
-
- output_record = Record(record_sequence,
- record_version,
- record_type,
- plaintext_block,
- plaintext_length);
-
- return true;
- }
-
-/*
* Checks the TLS padding. Returns 0 if the padding is invalid (we
* count the padding_length field as part of the padding size so a
* valid padding will always be at least one byte long), or the length
@@ -436,7 +382,7 @@ void cbc_decrypt_record(byte record_contents[], size_t record_len,
cipherstate.cbc_state() = last_ciphertext;
}
-bool decrypt_record(Record& output_record,
+void decrypt_record(secure_vector<byte>& output,
byte record_contents[], size_t record_len,
u64bit record_sequence,
Protocol_Version record_version,
@@ -458,57 +404,76 @@ bool decrypt_record(Record& output_record,
cipherstate.format_ad(record_sequence, record_type, record_version, ptext_size)
);
- // fixme - making a copy, should steal from Record
- secure_vector<byte> buffer;
- buffer += aead->start_vec(nonce);
+ output += aead->start_vec(nonce);
- const size_t offset = buffer.size();
- buffer += std::make_pair(&msg[0], msg_length);
- aead->finish(buffer, offset);
+ const size_t offset = output.size();
+ output += std::make_pair(&msg[0], msg_length);
+ aead->finish(output, offset);
- BOTAN_ASSERT(buffer.size() == ptext_size + offset, "Produced expected size");
+ BOTAN_ASSERT(output.size() == ptext_size + offset, "Produced expected size");
+ }
+ else
+ {
+ // GenericBlockCipher / GenericStreamCipher case
- output_record = Record(record_sequence,
- record_version,
- record_type,
- &buffer[0],
- buffer.size());
+ volatile bool padding_bad = false;
+ size_t pad_size = 0;
- return true;
- }
+ if(StreamCipher* sc = cipherstate.stream_cipher())
+ {
+ sc->cipher1(record_contents, record_len);
+ // no padding to check or remove
+ }
+ else if(BlockCipher* bc = cipherstate.block_cipher())
+ {
+ cbc_decrypt_record(record_contents, record_len, cipherstate, *bc);
- volatile bool padding_bad = false;
- size_t pad_size = 0;
+ pad_size = tls_padding_check(cipherstate.cipher_padding_single_byte(),
+ cipherstate.block_size(),
+ record_contents, record_len);
- if(StreamCipher* sc = cipherstate.stream_cipher())
- {
- sc->cipher1(record_contents, record_len);
- // no padding to check or remove
- }
- else if(BlockCipher* bc = cipherstate.block_cipher())
- {
- cbc_decrypt_record(record_contents, record_len, cipherstate, *bc);
+ padding_bad = (pad_size == 0);
+ }
+ else
+ {
+ throw Internal_Error("No cipher state set but needed to decrypt");
+ }
- pad_size = tls_padding_check(cipherstate.cipher_padding_single_byte(),
- cipherstate.block_size(),
- record_contents, record_len);
+ const size_t mac_size = cipherstate.mac_size();
+ const size_t iv_size = cipherstate.iv_size();
- padding_bad = (pad_size == 0);
- }
- else
- {
- throw Internal_Error("No cipher state set but needed to decrypt");
- }
+ cipherstate.mac()->update_be(record_sequence);
+ cipherstate.mac()->update(static_cast<byte>(record_type));
- return traditional_mac_check(output_record,
- record_contents,
- record_len,
- pad_size,
- padding_bad,
- record_sequence,
- record_version,
- record_type,
- cipherstate);
+ if(cipherstate.mac_includes_record_version())
+ {
+ cipherstate.mac()->update(record_version.major_version());
+ cipherstate.mac()->update(record_version.minor_version());
+ }
+
+ const size_t mac_pad_iv_size = mac_size + pad_size + iv_size;
+
+ if(record_len < mac_pad_iv_size)
+ throw Decoding_Error("Record sent with invalid length");
+
+ const byte* plaintext_block = &record_contents[iv_size];
+ const u16bit plaintext_length = record_len - mac_pad_iv_size;
+
+ cipherstate.mac()->update_be(plaintext_length);
+ cipherstate.mac()->update(plaintext_block, plaintext_length);
+
+ std::vector<byte> mac_buf(mac_size);
+ cipherstate.mac()->final(&mac_buf[0]);
+
+ const size_t mac_offset = record_len - (mac_size + pad_size);
+
+ const bool mac_bad = !same_mem(&record_contents[mac_offset], &mac_buf[0], mac_size);
+
+ if(mac_bad || padding_bad)
+ throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure");
+
+ output.assign(plaintext_block, plaintext_block + plaintext_length);
+ }
}
}
@@ -517,7 +482,10 @@ size_t read_record(secure_vector<byte>& readbuf,
const byte input[],
size_t input_sz,
size_t& consumed,
- Record& record,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
std::function<Connection_Cipher_State* (u16bit)> get_cipherstate)
{
@@ -554,28 +522,26 @@ size_t read_record(secure_vector<byte>& readbuf,
"Have the entire SSLv2 hello");
// Fake v3-style handshake message wrapper
- std::vector<byte> sslv2_hello(4 + readbuf.size() - 2);
-
- sslv2_hello[0] = CLIENT_HELLO_SSLV2;
- sslv2_hello[1] = 0;
- sslv2_hello[2] = readbuf[0] & 0x7F;
- sslv2_hello[3] = readbuf[1];
+ *record_version = Protocol_Version::TLS_V10;
+ *record_sequence = 0;
+ *record_type = HANDSHAKE;
- copy_mem(&sslv2_hello[4], &readbuf[2], readbuf.size() - 2);
+ record.resize(4 + readbuf.size() - 2);
- record = Record(0,
- Protocol_Version::TLS_V10,
- HANDSHAKE,
- std::move(sslv2_hello));
+ record[0] = CLIENT_HELLO_SSLV2;
+ record[1] = 0;
+ record[2] = readbuf[0] & 0x7F;
+ record[3] = readbuf[1];
+ copy_mem(&record[4], &readbuf[2], readbuf.size() - 2);
readbuf.clear();
return 0;
}
}
- Protocol_Version record_version = Protocol_Version(readbuf[1], readbuf[2]);
+ *record_version = Protocol_Version(readbuf[1], readbuf[2]);
- const bool is_dtls = record_version.is_datagram_protocol();
+ const bool is_dtls = record_version->is_datagram_protocol();
if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE)
{
@@ -606,41 +572,35 @@ size_t read_record(secure_vector<byte>& readbuf,
readbuf.size(),
"Have the full record");
- Record_Type record_type = static_cast<Record_Type>(readbuf[0]);
+ *record_type = static_cast<Record_Type>(readbuf[0]);
- u64bit record_sequence = 0;
u16bit epoch = 0;
if(is_dtls)
{
- record_sequence = load_be<u64bit>(&readbuf[3], 0);
- epoch = (record_sequence >> 48);
+ *record_sequence = load_be<u64bit>(&readbuf[3], 0);
+ epoch = (*record_sequence >> 48);
}
else if(sequence_numbers)
{
- record_sequence = sequence_numbers->next_read_sequence();
+ *record_sequence = sequence_numbers->next_read_sequence();
epoch = sequence_numbers->current_read_epoch();
}
else
{
// server initial handshake case
- record_sequence = 0;
+ *record_sequence = 0;
epoch = 0;
}
- if(sequence_numbers && sequence_numbers->already_seen(record_sequence))
+ if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
return 0;
byte* record_contents = &readbuf[header_size];
if(epoch == 0) // Unencrypted initial handshake
{
- record = Record(record_sequence,
- record_version,
- record_type,
- &readbuf[header_size],
- record_len);
-
+ record.assign(&readbuf[header_size], &readbuf[header_size + record_len]);
readbuf.clear();
return 0; // got a full record
}
@@ -652,16 +612,16 @@ size_t read_record(secure_vector<byte>& readbuf,
BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
- const bool ok = decrypt_record(record,
- record_contents,
- record_len,
- record_sequence,
- record_version,
- record_type,
- *cipherstate);
+ decrypt_record(record,
+ record_contents,
+ record_len,
+ *record_sequence,
+ *record_version,
+ *record_type,
+ *cipherstate);
- if(ok && sequence_numbers)
- sequence_numbers->read_accept(record_sequence);
+ if(sequence_numbers)
+ sequence_numbers->read_accept(*record_sequence);
readbuf.clear();
return 0;
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 68893af89..ef27a0a02 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -94,51 +94,6 @@ class Connection_Cipher_State
bool m_is_ssl3 = false;
};
-class Record
- {
- public:
- Record() {}
-
- Record(u64bit sequence,
- Protocol_Version version,
- Record_Type type,
- const byte contents[],
- size_t contents_size) :
- m_sequence(sequence),
- m_version(version),
- m_type(type),
- m_contents(contents, contents + contents_size) {}
-
- Record(u64bit sequence,
- Protocol_Version version,
- Record_Type type,
- std::vector<byte>&& contents) :
- m_sequence(sequence),
- m_version(version),
- m_type(type),
- m_contents(contents) {}
-
- bool is_valid() const { return m_type != NO_RECORD; }
-
- u64bit sequence() const { return m_sequence; }
-
- Record_Type type() const { return m_type; }
-
- Protocol_Version version() const { return m_version; }
-
- const std::vector<byte>& contents() const { return m_contents; }
-
- const byte* bits() const { return &m_contents[0]; }
-
- size_t size() const { return m_contents.size(); }
-
- private:
- u64bit m_sequence = 0;
- Protocol_Version m_version = Protocol_Version();
- Record_Type m_type = NO_RECORD;
- std::vector<byte> m_contents;
- };
-
/**
* Create a TLS record
* @param write_buffer the output record is placed here
@@ -166,7 +121,10 @@ size_t read_record(secure_vector<byte>& read_buffer,
const byte input[],
size_t input_length,
size_t& input_consumed,
- Record& output_record,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
Connection_Sequence_Numbers* sequence_numbers,
std::function<Connection_Cipher_State* (u16bit)> get_cipherstate);