diff options
Diffstat (limited to 'src/tls')
-rw-r--r-- | src/tls/tls_alert.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_alert.h | 2 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 36 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 226 | ||||
-rw-r--r-- | src/tls/tls_record.h | 50 |
5 files changed, 120 insertions, 196 deletions
diff --git a/src/tls/tls_alert.cpp b/src/tls/tls_alert.cpp index f548bd57b..15bb2a2dc 100644 --- a/src/tls/tls_alert.cpp +++ b/src/tls/tls_alert.cpp @@ -12,7 +12,7 @@ namespace Botan { namespace TLS { -Alert::Alert(const std::vector<byte>& buf) +Alert::Alert(const secure_vector<byte>& buf) { if(buf.size() != 2) throw Decoding_Error("Alert: Bad size " + std::to_string(buf.size()) + diff --git a/src/tls/tls_alert.h b/src/tls/tls_alert.h index 12ab57d6b..bf32178ee 100644 --- a/src/tls/tls_alert.h +++ b/src/tls/tls_alert.h @@ -90,7 +90,7 @@ class BOTAN_DLL Alert * Deserialize an Alert message * @param buf the serialized alert */ - Alert(const std::vector<byte>& buf); + Alert(const secure_vector<byte>& buf); /** * Create a new Alert diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index c00970c49..7c7d65961 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -278,7 +278,10 @@ size_t Channel::received_data(const byte input[], size_t input_size) { while(!is_closed() && input_size) { - Record record; + secure_vector<byte> record; + u64bit record_sequence = 0; + Record_Type record_type = NO_RECORD; + Protocol_Version record_version; size_t consumed = 0; @@ -288,6 +291,9 @@ size_t Channel::received_data(const byte input[], size_t input_size) input_size, consumed, record, + &record_sequence, + &record_version, + &record_type, m_sequence_numbers.get(), get_cipherstate); @@ -303,22 +309,22 @@ size_t Channel::received_data(const byte input[], size_t input_size) if(input_size == 0 && needed != 0) return needed; // need more data to complete record - BOTAN_ASSERT(record.is_valid(), "Got a full record"); - if(record.size() > max_fragment_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); - if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC) + if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) { - create_handshake_state(record.version()); - if(record.version().is_datagram_protocol()) - sequence_numbers().read_accept(record.sequence()); + create_handshake_state(record_version); + if(record_version.is_datagram_protocol()) + sequence_numbers().read_accept(record_sequence); } - m_pending_state->handshake_io().add_record(record.contents(), record.type(), record.sequence()); + m_pending_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); while(auto pending = m_pending_state.get()) { @@ -331,12 +337,12 @@ size_t Channel::received_data(const byte input[], size_t input_size) msg.first, msg.second); } } - else if(record.type() == HEARTBEAT && peer_supports_heartbeats()) + else if(record_type == HEARTBEAT && peer_supports_heartbeats()) { if(!active_state()) throw Unexpected_Message("Heartbeat sent before handshake done"); - Heartbeat_Message heartbeat(record.contents()); + Heartbeat_Message heartbeat(unlock(record)); const std::vector<byte>& payload = heartbeat.payload(); @@ -356,7 +362,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) m_proc_fn(&payload[0], payload.size(), Alert(Alert::HEARTBEAT_PAYLOAD)); } } - else if(record.type() == APPLICATION_DATA) + else if(record_type == APPLICATION_DATA) { if(!active_state()) throw Unexpected_Message("Application data before handshake done"); @@ -367,11 +373,11 @@ size_t Channel::received_data(const byte input[], size_t input_size) * following record. Avoid spurious callbacks. */ if(record.size() > 0) - m_proc_fn(record.bits(), record.size(), Alert()); + m_proc_fn(&record[0], record.size(), Alert()); } - else if(record.type() == ALERT) + else if(record_type == ALERT) { - Alert alert_msg(record.contents()); + Alert alert_msg(record); if(alert_msg.type() == Alert::NO_RENEGOTIATION) m_pending_state.reset(); @@ -395,7 +401,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) } else throw Unexpected_Message("Unexpected record type " + - std::to_string(record.type()) + + std::to_string(record_type) + " from counterparty"); } diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index d9b222a85..2bfe78152 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -297,60 +297,6 @@ size_t fill_buffer_to(secure_vector<byte>& readbuf, } /* -* MAC scheme used in SSLv3/TLSv1 for RC4 and CBC ciphers -*/ -bool traditional_mac_check(Record& output_record, - byte record_contents[], size_t record_len, - size_t pad_size, - volatile bool padding_bad, - u64bit record_sequence, - Protocol_Version record_version, - Record_Type record_type, - Connection_Cipher_State& cipherstate) - { - const size_t mac_size = cipherstate.mac_size(); - const size_t iv_size = cipherstate.iv_size(); - - cipherstate.mac()->update_be(record_sequence); - cipherstate.mac()->update(static_cast<byte>(record_type)); - - if(cipherstate.mac_includes_record_version()) - { - cipherstate.mac()->update(record_version.major_version()); - cipherstate.mac()->update(record_version.minor_version()); - } - - const size_t mac_pad_iv_size = mac_size + pad_size + iv_size; - - if(record_len < mac_pad_iv_size) - throw Decoding_Error("Record sent with invalid length"); - - const byte* plaintext_block = &record_contents[iv_size]; - const u16bit plaintext_length = record_len - mac_pad_iv_size; - - cipherstate.mac()->update_be(plaintext_length); - cipherstate.mac()->update(plaintext_block, plaintext_length); - - std::vector<byte> mac_buf(mac_size); - cipherstate.mac()->final(&mac_buf[0]); - - const size_t mac_offset = record_len - (mac_size + pad_size); - - const bool mac_bad = !same_mem(&record_contents[mac_offset], &mac_buf[0], mac_size); - - if(mac_bad || padding_bad) - throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); - - output_record = Record(record_sequence, - record_version, - record_type, - plaintext_block, - plaintext_length); - - return true; - } - -/* * Checks the TLS padding. Returns 0 if the padding is invalid (we * count the padding_length field as part of the padding size so a * valid padding will always be at least one byte long), or the length @@ -436,7 +382,7 @@ void cbc_decrypt_record(byte record_contents[], size_t record_len, cipherstate.cbc_state() = last_ciphertext; } -bool decrypt_record(Record& output_record, +void decrypt_record(secure_vector<byte>& output, byte record_contents[], size_t record_len, u64bit record_sequence, Protocol_Version record_version, @@ -458,57 +404,76 @@ bool decrypt_record(Record& output_record, cipherstate.format_ad(record_sequence, record_type, record_version, ptext_size) ); - // fixme - making a copy, should steal from Record - secure_vector<byte> buffer; - buffer += aead->start_vec(nonce); + output += aead->start_vec(nonce); - const size_t offset = buffer.size(); - buffer += std::make_pair(&msg[0], msg_length); - aead->finish(buffer, offset); + const size_t offset = output.size(); + output += std::make_pair(&msg[0], msg_length); + aead->finish(output, offset); - BOTAN_ASSERT(buffer.size() == ptext_size + offset, "Produced expected size"); + BOTAN_ASSERT(output.size() == ptext_size + offset, "Produced expected size"); + } + else + { + // GenericBlockCipher / GenericStreamCipher case - output_record = Record(record_sequence, - record_version, - record_type, - &buffer[0], - buffer.size()); + volatile bool padding_bad = false; + size_t pad_size = 0; - return true; - } + if(StreamCipher* sc = cipherstate.stream_cipher()) + { + sc->cipher1(record_contents, record_len); + // no padding to check or remove + } + else if(BlockCipher* bc = cipherstate.block_cipher()) + { + cbc_decrypt_record(record_contents, record_len, cipherstate, *bc); - volatile bool padding_bad = false; - size_t pad_size = 0; + pad_size = tls_padding_check(cipherstate.cipher_padding_single_byte(), + cipherstate.block_size(), + record_contents, record_len); - if(StreamCipher* sc = cipherstate.stream_cipher()) - { - sc->cipher1(record_contents, record_len); - // no padding to check or remove - } - else if(BlockCipher* bc = cipherstate.block_cipher()) - { - cbc_decrypt_record(record_contents, record_len, cipherstate, *bc); + padding_bad = (pad_size == 0); + } + else + { + throw Internal_Error("No cipher state set but needed to decrypt"); + } - pad_size = tls_padding_check(cipherstate.cipher_padding_single_byte(), - cipherstate.block_size(), - record_contents, record_len); + const size_t mac_size = cipherstate.mac_size(); + const size_t iv_size = cipherstate.iv_size(); - padding_bad = (pad_size == 0); - } - else - { - throw Internal_Error("No cipher state set but needed to decrypt"); - } + cipherstate.mac()->update_be(record_sequence); + cipherstate.mac()->update(static_cast<byte>(record_type)); - return traditional_mac_check(output_record, - record_contents, - record_len, - pad_size, - padding_bad, - record_sequence, - record_version, - record_type, - cipherstate); + if(cipherstate.mac_includes_record_version()) + { + cipherstate.mac()->update(record_version.major_version()); + cipherstate.mac()->update(record_version.minor_version()); + } + + const size_t mac_pad_iv_size = mac_size + pad_size + iv_size; + + if(record_len < mac_pad_iv_size) + throw Decoding_Error("Record sent with invalid length"); + + const byte* plaintext_block = &record_contents[iv_size]; + const u16bit plaintext_length = record_len - mac_pad_iv_size; + + cipherstate.mac()->update_be(plaintext_length); + cipherstate.mac()->update(plaintext_block, plaintext_length); + + std::vector<byte> mac_buf(mac_size); + cipherstate.mac()->final(&mac_buf[0]); + + const size_t mac_offset = record_len - (mac_size + pad_size); + + const bool mac_bad = !same_mem(&record_contents[mac_offset], &mac_buf[0], mac_size); + + if(mac_bad || padding_bad) + throw TLS_Exception(Alert::BAD_RECORD_MAC, "Message authentication failure"); + + output.assign(plaintext_block, plaintext_block + plaintext_length); + } } } @@ -517,7 +482,10 @@ size_t read_record(secure_vector<byte>& readbuf, const byte input[], size_t input_sz, size_t& consumed, - Record& record, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, std::function<Connection_Cipher_State* (u16bit)> get_cipherstate) { @@ -554,28 +522,26 @@ size_t read_record(secure_vector<byte>& readbuf, "Have the entire SSLv2 hello"); // Fake v3-style handshake message wrapper - std::vector<byte> sslv2_hello(4 + readbuf.size() - 2); - - sslv2_hello[0] = CLIENT_HELLO_SSLV2; - sslv2_hello[1] = 0; - sslv2_hello[2] = readbuf[0] & 0x7F; - sslv2_hello[3] = readbuf[1]; + *record_version = Protocol_Version::TLS_V10; + *record_sequence = 0; + *record_type = HANDSHAKE; - copy_mem(&sslv2_hello[4], &readbuf[2], readbuf.size() - 2); + record.resize(4 + readbuf.size() - 2); - record = Record(0, - Protocol_Version::TLS_V10, - HANDSHAKE, - std::move(sslv2_hello)); + record[0] = CLIENT_HELLO_SSLV2; + record[1] = 0; + record[2] = readbuf[0] & 0x7F; + record[3] = readbuf[1]; + copy_mem(&record[4], &readbuf[2], readbuf.size() - 2); readbuf.clear(); return 0; } } - Protocol_Version record_version = Protocol_Version(readbuf[1], readbuf[2]); + *record_version = Protocol_Version(readbuf[1], readbuf[2]); - const bool is_dtls = record_version.is_datagram_protocol(); + const bool is_dtls = record_version->is_datagram_protocol(); if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE) { @@ -606,41 +572,35 @@ size_t read_record(secure_vector<byte>& readbuf, readbuf.size(), "Have the full record"); - Record_Type record_type = static_cast<Record_Type>(readbuf[0]); + *record_type = static_cast<Record_Type>(readbuf[0]); - u64bit record_sequence = 0; u16bit epoch = 0; if(is_dtls) { - record_sequence = load_be<u64bit>(&readbuf[3], 0); - epoch = (record_sequence >> 48); + *record_sequence = load_be<u64bit>(&readbuf[3], 0); + epoch = (*record_sequence >> 48); } else if(sequence_numbers) { - record_sequence = sequence_numbers->next_read_sequence(); + *record_sequence = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); } else { // server initial handshake case - record_sequence = 0; + *record_sequence = 0; epoch = 0; } - if(sequence_numbers && sequence_numbers->already_seen(record_sequence)) + if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) return 0; byte* record_contents = &readbuf[header_size]; if(epoch == 0) // Unencrypted initial handshake { - record = Record(record_sequence, - record_version, - record_type, - &readbuf[header_size], - record_len); - + record.assign(&readbuf[header_size], &readbuf[header_size + record_len]); readbuf.clear(); return 0; // got a full record } @@ -652,16 +612,16 @@ size_t read_record(secure_vector<byte>& readbuf, BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); - const bool ok = decrypt_record(record, - record_contents, - record_len, - record_sequence, - record_version, - record_type, - *cipherstate); + decrypt_record(record, + record_contents, + record_len, + *record_sequence, + *record_version, + *record_type, + *cipherstate); - if(ok && sequence_numbers) - sequence_numbers->read_accept(record_sequence); + if(sequence_numbers) + sequence_numbers->read_accept(*record_sequence); readbuf.clear(); return 0; diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 68893af89..ef27a0a02 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -94,51 +94,6 @@ class Connection_Cipher_State bool m_is_ssl3 = false; }; -class Record - { - public: - Record() {} - - Record(u64bit sequence, - Protocol_Version version, - Record_Type type, - const byte contents[], - size_t contents_size) : - m_sequence(sequence), - m_version(version), - m_type(type), - m_contents(contents, contents + contents_size) {} - - Record(u64bit sequence, - Protocol_Version version, - Record_Type type, - std::vector<byte>&& contents) : - m_sequence(sequence), - m_version(version), - m_type(type), - m_contents(contents) {} - - bool is_valid() const { return m_type != NO_RECORD; } - - u64bit sequence() const { return m_sequence; } - - Record_Type type() const { return m_type; } - - Protocol_Version version() const { return m_version; } - - const std::vector<byte>& contents() const { return m_contents; } - - const byte* bits() const { return &m_contents[0]; } - - size_t size() const { return m_contents.size(); } - - private: - u64bit m_sequence = 0; - Protocol_Version m_version = Protocol_Version(); - Record_Type m_type = NO_RECORD; - std::vector<byte> m_contents; - }; - /** * Create a TLS record * @param write_buffer the output record is placed here @@ -166,7 +121,10 @@ size_t read_record(secure_vector<byte>& read_buffer, const byte input[], size_t input_length, size_t& input_consumed, - Record& output_record, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, Connection_Sequence_Numbers* sequence_numbers, std::function<Connection_Cipher_State* (u16bit)> get_cipherstate); |