aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2020-04-23 07:21:30 -0400
committerJack Lloyd <[email protected]>2020-04-24 09:57:43 -0400
commit9b3d1c467b93e27d88bb2da0a90dc0b9b1e95cfd (patch)
tree173085af29bb9ff7426f14473fdf141a29203d20
parent39c5aacdf1572dfe27bb3e58fbceb7854bfca117 (diff)
Small refactorings of TLS record layer
Reduces some code duplication in #2320
-rw-r--r--src/lib/tls/tls_channel.cpp20
-rw-r--r--src/lib/tls/tls_record.cpp61
-rw-r--r--src/lib/tls/tls_record.h18
3 files changed, 67 insertions, 32 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp
index 88b897227..0004bac11 100644
--- a/src/lib/tls/tls_channel.cpp
+++ b/src/lib/tls/tls_channel.cpp
@@ -558,14 +558,18 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch
const Protocol_Version record_version =
(m_pending_state) ? (m_pending_state->version()) : (m_active_state->version());
- TLS::write_record(m_writebuf,
- record_type,
- record_version,
- sequence_numbers().next_write_sequence(epoch),
- input,
- length,
- cipher_state,
- m_rng);
+ const uint64_t next_seq = sequence_numbers().next_write_sequence(epoch);
+
+ if(cipher_state == nullptr)
+ {
+ TLS::write_unencrypted_record(m_writebuf, record_type, record_version, next_seq,
+ input, length);
+ }
+ else
+ {
+ TLS::write_record(m_writebuf, record_type, record_version, next_seq,
+ input, length, *cipher_state, m_rng);
+ }
callbacks().tls_emit_data(m_writebuf.data(), m_writebuf.size());
}
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index c662d1fe5..d9e248487 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -187,16 +187,10 @@ inline void append_u16_len(secure_vector<uint8_t>& output, size_t len_field)
output.push_back(get_byte(1, len16));
}
-}
-
-void write_record(secure_vector<uint8_t>& output,
- uint8_t record_type,
- Protocol_Version version,
- uint64_t record_sequence,
- const uint8_t* message,
- size_t message_len,
- Connection_Cipher_State* cs,
- RandomNumberGenerator& rng)
+void write_record_header(secure_vector<uint8_t>& output,
+ uint8_t record_type,
+ Protocol_Version version,
+ uint64_t record_sequence)
{
output.clear();
@@ -209,33 +203,54 @@ void write_record(secure_vector<uint8_t>& output,
for(size_t i = 0; i != 8; ++i)
output.push_back(get_byte(i, record_sequence));
}
+ }
- if(!cs) // initial unencrypted handshake records
- {
- append_u16_len(output, message_len);
- output.insert(output.end(), message, message + message_len);
- return;
- }
+}
- AEAD_Mode& aead = cs->aead();
- std::vector<uint8_t> aad = cs->format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len));
+void write_unencrypted_record(secure_vector<uint8_t>& output,
+ uint8_t record_type,
+ Protocol_Version version,
+ uint64_t record_sequence,
+ const uint8_t* message,
+ size_t message_len)
+ {
+ if(record_type == APPLICATION_DATA)
+ throw Internal_Error("Writing an unencrypted TLS application data record");
+ write_record_header(output, record_type, version, record_sequence);
+ append_u16_len(output, message_len);
+ output.insert(output.end(), message, message + message_len);
+ }
+
+void write_record(secure_vector<uint8_t>& output,
+ uint8_t record_type,
+ Protocol_Version version,
+ uint64_t record_sequence,
+ const uint8_t* message,
+ size_t message_len,
+ Connection_Cipher_State& cs,
+ RandomNumberGenerator& rng)
+ {
+ write_record_header(output, record_type, version, record_sequence);
+
+ AEAD_Mode& aead = cs.aead();
+ std::vector<uint8_t> aad = cs.format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len));
const size_t ctext_size = aead.output_length(message_len);
- const size_t rec_size = ctext_size + cs->nonce_bytes_from_record();
+ const size_t rec_size = ctext_size + cs.nonce_bytes_from_record();
aead.set_ad(aad);
- const std::vector<uint8_t> nonce = cs->aead_nonce(record_sequence, rng);
+ const std::vector<uint8_t> nonce = cs.aead_nonce(record_sequence, rng);
append_u16_len(output, rec_size);
- if(cs->nonce_bytes_from_record() > 0)
+ if(cs.nonce_bytes_from_record() > 0)
{
- if(cs->nonce_format() == Nonce_Format::CBC_MODE)
+ if(cs.nonce_format() == Nonce_Format::CBC_MODE)
output += nonce;
else
- output += std::make_pair(&nonce[cs->nonce_bytes_from_handshake()], cs->nonce_bytes_from_record());
+ output += std::make_pair(&nonce[cs.nonce_bytes_from_handshake()], cs.nonce_bytes_from_record());
}
const size_t header_size = output.size();
diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h
index 779954439..7593ac970 100644
--- a/src/lib/tls/tls_record.h
+++ b/src/lib/tls/tls_record.h
@@ -129,6 +129,22 @@ class Record_Header final
};
/**
+* Create an initial (unencrypted) TLS handshake record
+* @param write_buffer the output record is placed here
+* @param record_type the record layer type
+* @param record_version the record layer version
+* @param record_sequence the record layer sequence number
+* @param message the record contents
+* @param message_len is size of message
+*/
+void write_unencrypted_record(secure_vector<uint8_t>& output,
+ uint8_t record_type,
+ Protocol_Version version,
+ uint64_t record_sequence,
+ const uint8_t* message,
+ size_t message_len);
+
+/**
* Create a TLS record
* @param write_buffer the output record is placed here
* @param record_type the record layer type
@@ -145,7 +161,7 @@ void write_record(secure_vector<uint8_t>& write_buffer,
uint64_t record_sequence,
const uint8_t* message,
size_t message_len,
- Connection_Cipher_State* cipherstate,
+ Connection_Cipher_State& cipherstate,
RandomNumberGenerator& rng);
// epoch -> cipher state