From 9b3d1c467b93e27d88bb2da0a90dc0b9b1e95cfd Mon Sep 17 00:00:00 2001 From: Jack Lloyd Date: Thu, 23 Apr 2020 07:21:30 -0400 Subject: Small refactorings of TLS record layer Reduces some code duplication in #2320 --- src/lib/tls/tls_channel.cpp | 20 +++++++++------ src/lib/tls/tls_record.cpp | 61 ++++++++++++++++++++++++++++----------------- src/lib/tls/tls_record.h | 18 ++++++++++++- 3 files changed, 67 insertions(+), 32 deletions(-) diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 88b897227..0004bac11 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -558,14 +558,18 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch const Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); - TLS::write_record(m_writebuf, - record_type, - record_version, - sequence_numbers().next_write_sequence(epoch), - input, - length, - cipher_state, - m_rng); + const uint64_t next_seq = sequence_numbers().next_write_sequence(epoch); + + if(cipher_state == nullptr) + { + TLS::write_unencrypted_record(m_writebuf, record_type, record_version, next_seq, + input, length); + } + else + { + TLS::write_record(m_writebuf, record_type, record_version, next_seq, + input, length, *cipher_state, m_rng); + } callbacks().tls_emit_data(m_writebuf.data(), m_writebuf.size()); } diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index c662d1fe5..d9e248487 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -187,16 +187,10 @@ inline void append_u16_len(secure_vector& output, size_t len_field) output.push_back(get_byte(1, len16)); } -} - -void write_record(secure_vector& output, - uint8_t record_type, - Protocol_Version version, - uint64_t record_sequence, - const uint8_t* message, - size_t message_len, - Connection_Cipher_State* cs, - RandomNumberGenerator& rng) +void write_record_header(secure_vector& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence) { output.clear(); @@ -209,33 +203,54 @@ void write_record(secure_vector& output, for(size_t i = 0; i != 8; ++i) output.push_back(get_byte(i, record_sequence)); } + } - if(!cs) // initial unencrypted handshake records - { - append_u16_len(output, message_len); - output.insert(output.end(), message, message + message_len); - return; - } +} - AEAD_Mode& aead = cs->aead(); - std::vector aad = cs->format_ad(record_sequence, record_type, version, static_cast(message_len)); +void write_unencrypted_record(secure_vector& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len) + { + if(record_type == APPLICATION_DATA) + throw Internal_Error("Writing an unencrypted TLS application data record"); + write_record_header(output, record_type, version, record_sequence); + append_u16_len(output, message_len); + output.insert(output.end(), message, message + message_len); + } + +void write_record(secure_vector& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len, + Connection_Cipher_State& cs, + RandomNumberGenerator& rng) + { + write_record_header(output, record_type, version, record_sequence); + + AEAD_Mode& aead = cs.aead(); + std::vector aad = cs.format_ad(record_sequence, record_type, version, static_cast(message_len)); const size_t ctext_size = aead.output_length(message_len); - const size_t rec_size = ctext_size + cs->nonce_bytes_from_record(); + const size_t rec_size = ctext_size + cs.nonce_bytes_from_record(); aead.set_ad(aad); - const std::vector nonce = cs->aead_nonce(record_sequence, rng); + const std::vector nonce = cs.aead_nonce(record_sequence, rng); append_u16_len(output, rec_size); - if(cs->nonce_bytes_from_record() > 0) + if(cs.nonce_bytes_from_record() > 0) { - if(cs->nonce_format() == Nonce_Format::CBC_MODE) + if(cs.nonce_format() == Nonce_Format::CBC_MODE) output += nonce; else - output += std::make_pair(&nonce[cs->nonce_bytes_from_handshake()], cs->nonce_bytes_from_record()); + output += std::make_pair(&nonce[cs.nonce_bytes_from_handshake()], cs.nonce_bytes_from_record()); } const size_t header_size = output.size(); diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index 779954439..7593ac970 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -128,6 +128,22 @@ class Record_Header final Record_Type m_type; }; +/** +* Create an initial (unencrypted) TLS handshake record +* @param write_buffer the output record is placed here +* @param record_type the record layer type +* @param record_version the record layer version +* @param record_sequence the record layer sequence number +* @param message the record contents +* @param message_len is size of message +*/ +void write_unencrypted_record(secure_vector& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len); + /** * Create a TLS record * @param write_buffer the output record is placed here @@ -145,7 +161,7 @@ void write_record(secure_vector& write_buffer, uint64_t record_sequence, const uint8_t* message, size_t message_len, - Connection_Cipher_State* cipherstate, + Connection_Cipher_State& cipherstate, RandomNumberGenerator& rng); // epoch -> cipher state -- cgit v1.2.3