diff options
author | Jack Lloyd <[email protected]> | 2020-04-23 07:21:30 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2020-04-24 09:57:43 -0400 |
commit | 9b3d1c467b93e27d88bb2da0a90dc0b9b1e95cfd (patch) | |
tree | 173085af29bb9ff7426f14473fdf141a29203d20 /src/lib/tls | |
parent | 39c5aacdf1572dfe27bb3e58fbceb7854bfca117 (diff) |
Small refactorings of TLS record layer
Reduces some code duplication in #2320
Diffstat (limited to 'src/lib/tls')
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 20 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 61 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 18 |
3 files changed, 67 insertions, 32 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 88b897227..0004bac11 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -558,14 +558,18 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch const Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); - TLS::write_record(m_writebuf, - record_type, - record_version, - sequence_numbers().next_write_sequence(epoch), - input, - length, - cipher_state, - m_rng); + const uint64_t next_seq = sequence_numbers().next_write_sequence(epoch); + + if(cipher_state == nullptr) + { + TLS::write_unencrypted_record(m_writebuf, record_type, record_version, next_seq, + input, length); + } + else + { + TLS::write_record(m_writebuf, record_type, record_version, next_seq, + input, length, *cipher_state, m_rng); + } callbacks().tls_emit_data(m_writebuf.data(), m_writebuf.size()); } diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index c662d1fe5..d9e248487 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -187,16 +187,10 @@ inline void append_u16_len(secure_vector<uint8_t>& output, size_t len_field) output.push_back(get_byte(1, len16)); } -} - -void write_record(secure_vector<uint8_t>& output, - uint8_t record_type, - Protocol_Version version, - uint64_t record_sequence, - const uint8_t* message, - size_t message_len, - Connection_Cipher_State* cs, - RandomNumberGenerator& rng) +void write_record_header(secure_vector<uint8_t>& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence) { output.clear(); @@ -209,33 +203,54 @@ void write_record(secure_vector<uint8_t>& output, for(size_t i = 0; i != 8; ++i) output.push_back(get_byte(i, record_sequence)); } + } - if(!cs) // initial unencrypted handshake records - { - append_u16_len(output, message_len); - output.insert(output.end(), message, message + message_len); - return; - } +} - AEAD_Mode& aead = cs->aead(); - std::vector<uint8_t> aad = cs->format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len)); +void write_unencrypted_record(secure_vector<uint8_t>& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len) + { + if(record_type == APPLICATION_DATA) + throw Internal_Error("Writing an unencrypted TLS application data record"); + write_record_header(output, record_type, version, record_sequence); + append_u16_len(output, message_len); + output.insert(output.end(), message, message + message_len); + } + +void write_record(secure_vector<uint8_t>& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len, + Connection_Cipher_State& cs, + RandomNumberGenerator& rng) + { + write_record_header(output, record_type, version, record_sequence); + + AEAD_Mode& aead = cs.aead(); + std::vector<uint8_t> aad = cs.format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len)); const size_t ctext_size = aead.output_length(message_len); - const size_t rec_size = ctext_size + cs->nonce_bytes_from_record(); + const size_t rec_size = ctext_size + cs.nonce_bytes_from_record(); aead.set_ad(aad); - const std::vector<uint8_t> nonce = cs->aead_nonce(record_sequence, rng); + const std::vector<uint8_t> nonce = cs.aead_nonce(record_sequence, rng); append_u16_len(output, rec_size); - if(cs->nonce_bytes_from_record() > 0) + if(cs.nonce_bytes_from_record() > 0) { - if(cs->nonce_format() == Nonce_Format::CBC_MODE) + if(cs.nonce_format() == Nonce_Format::CBC_MODE) output += nonce; else - output += std::make_pair(&nonce[cs->nonce_bytes_from_handshake()], cs->nonce_bytes_from_record()); + output += std::make_pair(&nonce[cs.nonce_bytes_from_handshake()], cs.nonce_bytes_from_record()); } const size_t header_size = output.size(); diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index 779954439..7593ac970 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -129,6 +129,22 @@ class Record_Header final }; /** +* Create an initial (unencrypted) TLS handshake record +* @param write_buffer the output record is placed here +* @param record_type the record layer type +* @param record_version the record layer version +* @param record_sequence the record layer sequence number +* @param message the record contents +* @param message_len is size of message +*/ +void write_unencrypted_record(secure_vector<uint8_t>& output, + uint8_t record_type, + Protocol_Version version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len); + +/** * Create a TLS record * @param write_buffer the output record is placed here * @param record_type the record layer type @@ -145,7 +161,7 @@ void write_record(secure_vector<uint8_t>& write_buffer, uint64_t record_sequence, const uint8_t* message, size_t message_len, - Connection_Cipher_State* cipherstate, + Connection_Cipher_State& cipherstate, RandomNumberGenerator& rng); // epoch -> cipher state |