aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/tls
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2019-05-24 05:17:59 -0400
committerJack Lloyd <[email protected]>2019-05-24 05:17:59 -0400
commit78dff743222447cd626c6a7a1d94c5ccd46de02b (patch)
tree5cfe2f209ee497141857203ffabe191e86d61455 /src/lib/tls
parente16ec9353e3aa379b730fdb8d9473bc2cccb4b72 (diff)
Avoid unnecessary copies during TLS handshake
Diffstat (limited to 'src/lib/tls')
-rw-r--r--src/lib/tls/tls_channel.cpp6
-rw-r--r--src/lib/tls/tls_handshake_io.cpp37
-rw-r--r--src/lib/tls/tls_handshake_io.h9
3 files changed, 28 insertions, 24 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index b066f649e..9bff836f2 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -441,7 +441,8 @@ void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record,
else if(epoch == sequence_numbers().current_read_epoch() - 1)
{
BOTAN_ASSERT(m_active_state, "Have active state here");
- m_active_state->handshake_io().add_record(unlock(record),
+ m_active_state->handshake_io().add_record(record.data(),
+ record.size(),
record_type,
record_sequence);
}
@@ -460,7 +461,8 @@ void Channel::process_handshake_ccs(const secure_vector<uint8_t>& record,
// May have been created in above conditional
if(m_pending_state)
{
- m_pending_state->handshake_io().add_record(unlock(record),
+ m_pending_state->handshake_io().add_record(record.data(),
+ record.size(),
record_type,
record_sequence);
diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp
index 8834e0008..acc30b102 100644
--- a/src/lib/tls/tls_handshake_io.cpp
+++ b/src/lib/tls/tls_handshake_io.cpp
@@ -46,16 +46,17 @@ Protocol_Version Stream_Handshake_IO::initial_record_version() const
return Protocol_Version::TLS_V10;
}
-void Stream_Handshake_IO::add_record(const std::vector<uint8_t>& record,
+void Stream_Handshake_IO::add_record(const uint8_t record[],
+ size_t record_len,
Record_Type record_type, uint64_t)
{
if(record_type == HANDSHAKE)
{
- m_queue.insert(m_queue.end(), record.begin(), record.end());
+ m_queue.insert(m_queue.end(), record, record + record_len);
}
else if(record_type == CHANGE_CIPHER_SPEC)
{
- if(record.size() != 1 || record[0] != 1)
+ if(record_len != 1 || record[0] != 1)
throw Decoding_Error("Invalid ChangeCipherSpec");
// Pretend it's a regular handshake message of zero length
@@ -181,7 +182,8 @@ bool Datagram_Handshake_IO::timeout_check()
return true;
}
-void Datagram_Handshake_IO::add_record(const std::vector<uint8_t>& record,
+void Datagram_Handshake_IO::add_record(const uint8_t record[],
+ size_t record_len,
Record_Type record_type,
uint64_t record_sequence)
{
@@ -189,7 +191,7 @@ void Datagram_Handshake_IO::add_record(const std::vector<uint8_t>& record,
if(record_type == CHANGE_CIPHER_SPEC)
{
- if(record.size() != 1 || record[0] != 1)
+ if(record_len != 1 || record[0] != 1)
throw Decoding_Error("Invalid ChangeCipherSpec");
// TODO: check this is otherwise empty
@@ -199,28 +201,25 @@ void Datagram_Handshake_IO::add_record(const std::vector<uint8_t>& record,
const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
- const uint8_t* record_bits = record.data();
- size_t record_size = record.size();
-
- while(record_size)
+ while(record_len)
{
- if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
+ if(record_len < DTLS_HANDSHAKE_HEADER_LEN)
return; // completely bogus? at least degenerate/weird
- const uint8_t msg_type = record_bits[0];
- const size_t msg_len = load_be24(&record_bits[1]);
- const uint16_t message_seq = load_be<uint16_t>(&record_bits[4], 0);
- const size_t fragment_offset = load_be24(&record_bits[6]);
- const size_t fragment_length = load_be24(&record_bits[9]);
+ const uint8_t msg_type = record[0];
+ const size_t msg_len = load_be24(&record[1]);
+ const uint16_t message_seq = load_be<uint16_t>(&record[4], 0);
+ const size_t fragment_offset = load_be24(&record[6]);
+ const size_t fragment_length = load_be24(&record[9]);
const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
- if(record_size < total_size)
+ if(record_len < total_size)
throw Decoding_Error("Bad lengths in DTLS header");
if(message_seq >= m_in_message_seq)
{
- m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
+ m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN],
fragment_length,
fragment_offset,
epoch,
@@ -232,8 +231,8 @@ void Datagram_Handshake_IO::add_record(const std::vector<uint8_t>& record,
// TODO: detect retransmitted flight
}
- record_bits += total_size;
- record_size -= total_size;
+ record += total_size;
+ record_len -= total_size;
}
}
diff --git a/src/lib/tls/tls_handshake_io.h b/src/lib/tls/tls_handshake_io.h
index 8e1a0eca7..66579459d 100644
--- a/src/lib/tls/tls_handshake_io.h
+++ b/src/lib/tls/tls_handshake_io.h
@@ -39,7 +39,8 @@ class Handshake_IO
const std::vector<uint8_t>& handshake_msg,
Handshake_Type handshake_type) const = 0;
- virtual void add_record(const std::vector<uint8_t>& record,
+ virtual void add_record(const uint8_t record[],
+ size_t record_len,
Record_Type type,
uint64_t sequence_number) = 0;
@@ -78,7 +79,8 @@ class Stream_Handshake_IO final : public Handshake_IO
const std::vector<uint8_t>& handshake_msg,
Handshake_Type handshake_type) const override;
- void add_record(const std::vector<uint8_t>& record,
+ void add_record(const uint8_t record[],
+ size_t record_len,
Record_Type type,
uint64_t sequence_number) override;
@@ -118,7 +120,8 @@ class Datagram_Handshake_IO final : public Handshake_IO
const std::vector<uint8_t>& handshake_msg,
Handshake_Type handshake_type) const override;
- void add_record(const std::vector<uint8_t>& record,
+ void add_record(const uint8_t record[],
+ size_t record_len,
Record_Type type,
uint64_t sequence_number) override;