diff options
author | Jack Lloyd <[email protected]> | 2019-07-05 07:20:52 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2019-07-05 07:20:52 -0400 |
commit | e0e13b7ee7fea7358939116a6e496da81356f0bb (patch) | |
tree | 8c209244b3baf6d49bb42e7e86948b8303d13362 /src/lib | |
parent | 51d9595d6842747c7723a5ebb4ac43054bed4e2a (diff) | |
parent | b6c81dd38d60327d6e6118599f163933d6eee256 (diff) |
Merge GH #2021 TLS record layer cleanups
Diffstat (limited to 'src/lib')
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 60 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 177 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 131 |
4 files changed, 173 insertions, 196 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index ced5dd3f1..eef9270eb 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -305,22 +305,20 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) { while(!is_closed() && input_size) { - secure_vector<uint8_t> record_data; - uint64_t record_sequence = 0; - Record_Type record_type = NO_RECORD; - Protocol_Version record_version; - size_t consumed = 0; - Record_Raw_Input raw_input(input, input_size, consumed, m_is_datagram); - Record record(record_data, &record_sequence, &record_version, &record_type); - const size_t needed = - read_record(m_readbuf, - raw_input, - record, + const Record_Header record = + read_record(m_is_datagram, + m_readbuf, + input, + input_size, + consumed, + m_record_buf, m_sequence_numbers.get(), [this](uint16_t epoch) { return read_cipher_state_epoch(epoch); }); + const size_t needed = record.needed(); + BOTAN_ASSERT(consumed > 0, "Got to eat something"); BOTAN_ASSERT(consumed <= input_size, @@ -332,20 +330,20 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) BOTAN_ASSERT(input_size == 0 || needed == 0, "Got a full record or consumed all input"); - // Ignore invalid records in DTLS - if(m_is_datagram && *record.get_type() == NO_RECORD) - return 0; - if(input_size == 0 && needed != 0) return needed; // need more data to complete record - if(record_data.size() > MAX_PLAINTEXT_SIZE) + // Ignore invalid records in DTLS + if(m_is_datagram && record.type() == NO_RECORD) + return 0; + + if(m_record_buf.size() > MAX_PLAINTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, "TLS plaintext record is larger than allowed maximum"); if(auto pending = pending_state()) { - if(pending->server_hello() != nullptr && record_version != pending->version()) + if(pending->server_hello() != nullptr && record.version() != pending->version()) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Received unexpected record version"); @@ -353,7 +351,7 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) } else if(auto active = active_state()) { - if(record_version != active->version()) + if(record.version() != active->version()) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Received unexpected record version"); @@ -362,31 +360,31 @@ size_t Channel::received_data(const uint8_t input[], size_t input_size) else { // For initial records just check for basic sanity - if(record_version.major_version() != 3 && - record_version.major_version() != 0xFE) + if(record.version().major_version() != 3 && + record.version().major_version() != 0xFE) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Received unexpected record version in initial record"); } } - if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC) + if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC) { - process_handshake_ccs(record_data, record_sequence, record_type, record_version); + process_handshake_ccs(m_record_buf, record.sequence(), record.type(), record.version()); } - else if(record_type == APPLICATION_DATA) + else if(record.type() == APPLICATION_DATA) { if(pending_state() != nullptr) throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "Can't interleave application and handshake data"); - process_application_data(record_sequence, record_data); + process_application_data(record.sequence(), m_record_buf); } - else if(record_type == ALERT) + else if(record.type() == ALERT) { - process_alert(record_data); + process_alert(m_record_buf); } - else if(record_type != NO_RECORD) + else if(record.type() != NO_RECORD) throw Unexpected_Message("Unexpected record type " + - std::to_string(record_type) + + std::to_string(record.type()) + " from counterparty"); } @@ -520,12 +518,12 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, uint16_t epoch const Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); - Record_Message record_message(record_type, 0, input, length); - TLS::write_record(m_writebuf, - record_message, + record_type, record_version, sequence_numbers().next_write_sequence(epoch), + input, + length, cipher_state, m_rng); diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 63cbcf0fc..2a2b74332 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -303,6 +303,7 @@ class BOTAN_PUBLIC_API(2,0) Channel /* I/O buffers */ secure_vector<uint8_t> m_writebuf; secure_vector<uint8_t> m_readbuf; + secure_vector<uint8_t> m_record_buf; }; } diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 27714af0b..3304b70eb 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -189,41 +189,43 @@ inline void append_u16_len(secure_vector<uint8_t>& output, size_t len_field) } void write_record(secure_vector<uint8_t>& output, - Record_Message msg, + uint8_t record_type, Protocol_Version version, - uint64_t seq, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len, Connection_Cipher_State* cs, RandomNumberGenerator& rng) { output.clear(); - output.push_back(msg.get_type()); + output.push_back(record_type); output.push_back(version.major_version()); output.push_back(version.minor_version()); if(version.is_datagram_protocol()) { for(size_t i = 0; i != 8; ++i) - output.push_back(get_byte(i, seq)); + output.push_back(get_byte(i, record_sequence)); } if(!cs) // initial unencrypted handshake records { - append_u16_len(output, msg.get_size()); - output.insert(output.end(), msg.get_data(), msg.get_data() + msg.get_size()); + append_u16_len(output, message_len); + output.insert(output.end(), message, message + message_len); return; } AEAD_Mode& aead = cs->aead(); - std::vector<uint8_t> aad = cs->format_ad(seq, msg.get_type(), version, static_cast<uint16_t>(msg.get_size())); + std::vector<uint8_t> aad = cs->format_ad(record_sequence, record_type, version, static_cast<uint16_t>(message_len)); - const size_t ctext_size = aead.output_length(msg.get_size()); + const size_t ctext_size = aead.output_length(message_len); const size_t rec_size = ctext_size + cs->nonce_bytes_from_record(); aead.set_ad(aad); - const std::vector<uint8_t> nonce = cs->aead_nonce(seq, rng); + const std::vector<uint8_t> nonce = cs->aead_nonce(record_sequence, rng); append_u16_len(output, rec_size); @@ -236,7 +238,7 @@ void write_record(secure_vector<uint8_t>& output, } const size_t header_size = output.size(); - output += std::make_pair(msg.get_data(), msg.get_size()); + output += std::make_pair(message, message_len); aead.start(nonce); aead.finish(output, header_size); @@ -300,35 +302,36 @@ void decrypt_record(secure_vector<uint8_t>& output, aead.start(nonce); - const size_t offset = output.size(); - output += std::make_pair(msg, msg_length); - aead.finish(output, offset); + output.assign(msg, msg + msg_length); + aead.finish(output, 0); } -size_t read_tls_record(secure_vector<uint8_t>& readbuf, - Record_Raw_Input& raw_input, - Record& rec, - Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate) +Record_Header read_tls_record(secure_vector<uint8_t>& readbuf, + const uint8_t input[], + size_t input_len, + size_t& consumed, + secure_vector<uint8_t>& recbuf, + Connection_Sequence_Numbers* sequence_numbers, + get_cipherstate_fn get_cipherstate) { if(readbuf.size() < TLS_HEADER_SIZE) // header incomplete? { - if(size_t needed = fill_buffer_to(readbuf, - raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), - TLS_HEADER_SIZE)) - return needed; + if(size_t needed = fill_buffer_to(readbuf, input, input_len, consumed, TLS_HEADER_SIZE)) + { + return Record_Header(needed); + } BOTAN_ASSERT_EQUAL(readbuf.size(), TLS_HEADER_SIZE, "Have an entire header"); } - *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]); + const Protocol_Version version(readbuf[1], readbuf[2]); - if(rec.get_protocol_version()->is_datagram_protocol()) + if(version.is_datagram_protocol()) throw TLS_Exception(Alert::PROTOCOL_VERSION, "Expected TLS but got a record with DTLS version"); const size_t record_size = make_uint16(readbuf[TLS_HEADER_SIZE-2], - readbuf[TLS_HEADER_SIZE-1]); + readbuf[TLS_HEADER_SIZE-1]); if(record_size > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, @@ -338,38 +341,36 @@ size_t read_tls_record(secure_vector<uint8_t>& readbuf, throw TLS_Exception(Alert::DECODE_ERROR, "Received a completely empty record"); - if(size_t needed = fill_buffer_to(readbuf, - raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), - TLS_HEADER_SIZE + record_size)) - return needed; + if(size_t needed = fill_buffer_to(readbuf, input, input_len, consumed, TLS_HEADER_SIZE + record_size)) + { + return Record_Header(needed); + } BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_size, readbuf.size(), "Have the full record"); - *rec.get_type() = static_cast<Record_Type>(readbuf[0]); + const Record_Type type = static_cast<Record_Type>(readbuf[0]); uint16_t epoch = 0; + uint64_t sequence = 0; if(sequence_numbers) { - *rec.get_sequence() = sequence_numbers->next_read_sequence(); + sequence = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); } else { // server initial handshake case - *rec.get_sequence() = 0; epoch = 0; } - uint8_t* record_contents = &readbuf[TLS_HEADER_SIZE]; - if(epoch == 0) // Unencrypted initial handshake { - rec.get_data().assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_size); + recbuf.assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_size); readbuf.clear(); - return 0; // got a full record + return Record_Header(sequence, version, type); } // Otherwise, decrypt, check MAC, return plaintext @@ -377,45 +378,46 @@ size_t read_tls_record(secure_vector<uint8_t>& readbuf, BOTAN_ASSERT(cs, "Have cipherstate for this epoch"); - decrypt_record(rec.get_data(), - record_contents, + decrypt_record(recbuf, + &readbuf[TLS_HEADER_SIZE], record_size, - *rec.get_sequence(), - *rec.get_protocol_version(), - *rec.get_type(), + sequence, + version, + type, *cs); if(sequence_numbers) - sequence_numbers->read_accept(*rec.get_sequence()); + sequence_numbers->read_accept(sequence); readbuf.clear(); - return 0; + return Record_Header(sequence, version, type); } -size_t read_dtls_record(secure_vector<uint8_t>& readbuf, - Record_Raw_Input& raw_input, - Record& rec, - Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate) +Record_Header read_dtls_record(secure_vector<uint8_t>& readbuf, + const uint8_t input[], + size_t input_len, + size_t& consumed, + secure_vector<uint8_t>& recbuf, + Connection_Sequence_Numbers* sequence_numbers, + get_cipherstate_fn get_cipherstate) { if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? { - if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE)) + if(fill_buffer_to(readbuf, input, input_len, consumed, DTLS_HEADER_SIZE)) { readbuf.clear(); - return 0; + return Record_Header(0); } BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header"); } - *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]); + const Protocol_Version version(readbuf[1], readbuf[2]); - if(rec.get_protocol_version()->is_datagram_protocol() == false) + if(version.is_datagram_protocol() == false) { readbuf.clear(); - *rec.get_type() = NO_RECORD; - return 0; + return Record_Header(0); } const size_t record_size = make_uint16(readbuf[DTLS_HEADER_SIZE-2], @@ -425,44 +427,39 @@ size_t read_dtls_record(secure_vector<uint8_t>& readbuf, { // Too large to be valid, ignore it readbuf.clear(); - *rec.get_type() = NO_RECORD; - return 0; + return Record_Header(0); } - if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE + record_size)) + if(fill_buffer_to(readbuf, input, input_len, consumed, DTLS_HEADER_SIZE + record_size)) { // Truncated packet? readbuf.clear(); - *rec.get_type() = NO_RECORD; - return 0; + return Record_Header(0); } BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_size, readbuf.size(), "Have the full record"); - *rec.get_type() = static_cast<Record_Type>(readbuf[0]); + const Record_Type type = static_cast<Record_Type>(readbuf[0]); uint16_t epoch = 0; - *rec.get_sequence() = load_be<uint64_t>(&readbuf[3], 0); - epoch = (*rec.get_sequence() >> 48); + const uint64_t sequence = load_be<uint64_t>(&readbuf[3], 0); + epoch = (sequence >> 48); - if(sequence_numbers && sequence_numbers->already_seen(*rec.get_sequence())) + if(sequence_numbers && sequence_numbers->already_seen(sequence)) { readbuf.clear(); - *rec.get_type() = NO_RECORD; - return 0; + return Record_Header(0); } - uint8_t* record_contents = &readbuf[DTLS_HEADER_SIZE]; - if(epoch == 0) // Unencrypted initial handshake { - rec.get_data().assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_size); + recbuf.assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_size); readbuf.clear(); if(sequence_numbers) - sequence_numbers->read_accept(*rec.get_sequence()); - return 0; // got a full record + sequence_numbers->read_accept(sequence); + return Record_Header(sequence, version, type); } try @@ -472,42 +469,44 @@ size_t read_dtls_record(secure_vector<uint8_t>& readbuf, BOTAN_ASSERT(cs, "Have cipherstate for this epoch"); - decrypt_record(rec.get_data(), - record_contents, + decrypt_record(recbuf, + &readbuf[DTLS_HEADER_SIZE], record_size, - *rec.get_sequence(), - *rec.get_protocol_version(), - *rec.get_type(), + sequence, + version, + type, *cs); } catch(std::exception&) { readbuf.clear(); - *rec.get_type() = NO_RECORD; - return 0; + return Record_Header(0); } if(sequence_numbers) - sequence_numbers->read_accept(*rec.get_sequence()); + sequence_numbers->read_accept(sequence); readbuf.clear(); - return 0; + return Record_Header(sequence, version, type); } } -size_t read_record(secure_vector<uint8_t>& readbuf, - Record_Raw_Input& raw_input, - Record& rec, - Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate) +Record_Header read_record(bool is_datagram, + secure_vector<uint8_t>& readbuf, + const uint8_t input[], + size_t input_len, + size_t& consumed, + secure_vector<uint8_t>& recbuf, + Connection_Sequence_Numbers* sequence_numbers, + get_cipherstate_fn get_cipherstate) { - if(raw_input.is_datagram()) - return read_dtls_record(readbuf, raw_input, rec, - sequence_numbers, get_cipherstate); + if(is_datagram) + return read_dtls_record(readbuf, input, input_len, consumed, + recbuf, sequence_numbers, get_cipherstate); else - return read_tls_record(readbuf, raw_input, rec, - sequence_numbers, get_cipherstate); + return read_tls_record(readbuf, input, input_len, consumed, + recbuf, sequence_numbers, get_cipherstate); } } diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index d0ffc0270..3e3475c03 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -77,93 +77,69 @@ class Connection_Cipher_State final size_t m_nonce_bytes_from_record; }; -class Record final +class Record_Header final { public: - Record(secure_vector<uint8_t>& data, - uint64_t* sequence, - Protocol_Version* protocol_version, - Record_Type* type) - : m_data(data), m_sequence(sequence), m_protocol_version(protocol_version), - m_type(type), m_size(data.size()) {} - - secure_vector<uint8_t>& get_data() { return m_data; } - - Protocol_Version* get_protocol_version() { return m_protocol_version; } - - uint64_t* get_sequence() { return m_sequence; } - - Record_Type* get_type() { return m_type; } - - size_t& get_size() { return m_size; } - - private: - secure_vector<uint8_t>& m_data; - uint64_t* m_sequence; - Protocol_Version* m_protocol_version; - Record_Type* m_type; - size_t m_size; - }; + Record_Header(uint64_t sequence, + Protocol_Version version, + Record_Type type) : + m_needed(0), + m_sequence(sequence), + m_version(version), + m_type(type) + {} + + Record_Header(size_t needed) : + m_needed(needed), + m_sequence(0), + m_version(Protocol_Version()), + m_type(NO_RECORD) + {} + + size_t needed() const { return m_needed; } + + Protocol_Version version() const + { + BOTAN_ASSERT_NOMSG(m_needed == 0); + return m_version; + } -class Record_Message final - { - public: - Record_Message(const uint8_t* data, size_t size) - : m_type(0), m_sequence(0), m_data(data), m_size(size) {} - Record_Message(uint8_t type, uint64_t sequence, const uint8_t* data, size_t size) - : m_type(type), m_sequence(sequence), m_data(data), - m_size(size) {} + uint64_t sequence() const + { + BOTAN_ASSERT_NOMSG(m_needed == 0); + return m_sequence; + } - uint8_t& get_type() { return m_type; } - uint64_t& get_sequence() { return m_sequence; } - const uint8_t* get_data() { return m_data; } - size_t& get_size() { return m_size; } + Record_Type type() const + { + BOTAN_ASSERT_NOMSG(m_needed == 0); + return m_type; + } private: - uint8_t m_type; + size_t m_needed; uint64_t m_sequence; - const uint8_t* m_data; - size_t m_size; -}; - -class Record_Raw_Input final - { - public: - Record_Raw_Input(const uint8_t* data, size_t size, size_t& consumed, - bool is_datagram) - : m_data(data), m_size(size), m_consumed(consumed), - m_is_datagram(is_datagram) {} - - const uint8_t*& get_data() { return m_data; } - - size_t& get_size() { return m_size; } - - size_t& get_consumed() { return m_consumed; } - void set_consumed(size_t consumed) { m_consumed = consumed; } - - bool is_datagram() { return m_is_datagram; } - - private: - const uint8_t* m_data; - size_t m_size; - size_t& m_consumed; - bool m_is_datagram; + Protocol_Version m_version; + Record_Type m_type; }; - /** * Create a TLS record * @param write_buffer the output record is placed here -* @param rec_msg is the plaintext message -* @param version is the protocol version -* @param msg_sequence is the sequence number +* @param record_type the record layer type +* @param record_version the record layer version +* @param record_sequence the record layer sequence number +* @param message the record contents +* @param message_len is size of message * @param cipherstate is the writing cipher state * @param rng is a random number generator */ void write_record(secure_vector<uint8_t>& write_buffer, - Record_Message rec_msg, - Protocol_Version version, - uint64_t msg_sequence, + uint8_t record_type, + Protocol_Version record_version, + uint64_t record_sequence, + const uint8_t* message, + size_t message_len, Connection_Cipher_State* cipherstate, RandomNumberGenerator& rng); @@ -174,11 +150,14 @@ typedef std::function<std::shared_ptr<Connection_Cipher_State> (uint16_t)> get_c * Decode a TLS record * @return zero if full message, else number of bytes still needed */ -size_t read_record(secure_vector<uint8_t>& read_buffer, - Record_Raw_Input& raw_input, - Record& rec, - Connection_Sequence_Numbers* sequence_numbers, - get_cipherstate_fn get_cipherstate); +Record_Header read_record(bool is_datagram, + secure_vector<uint8_t>& read_buffer, + const uint8_t input[], + size_t input_len, + size_t& consumed, + secure_vector<uint8_t>& record_buf, + Connection_Sequence_Numbers* sequence_numbers, + get_cipherstate_fn get_cipherstate); } |