diff options
author | lloyd <[email protected]> | 2012-12-10 17:53:19 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-12-10 17:53:19 +0000 |
commit | b0164f672c53ba10cf24a5f5a502d9aac4746161 (patch) | |
tree | b94baba373f943ad68d860acb939ffb9277139da | |
parent | 71e60dbdb404b715532a9e5d70efdff393602470 (diff) | |
parent | 79d3cfa5fd64ed4cfaa0643bb318edd38f22de92 (diff) |
merge of '2a4d641c566916555a5127b4ba82a1fa9f9e2b0c'
and '59030896322f59cfd47ba0ff17993ccd263174c6'
-rw-r--r-- | src/tls/tls_channel.cpp | 54 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.cpp | 40 | ||||
-rw-r--r-- | src/tls/tls_handshake_io.h | 16 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 53 | ||||
-rw-r--r-- | src/tls/tls_record.h | 50 |
5 files changed, 123 insertions, 90 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 5858f5d90..6bbf60a24 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -254,7 +254,7 @@ bool Channel::heartbeat_sending_allowed() const return false; } -size_t Channel::received_data(const byte buf[], size_t buf_size) +size_t Channel::received_data(const byte input[], size_t input_size) { const auto get_cipherstate = [this](u16bit epoch) { return this->read_cipher_state_epoch(epoch).get(); }; @@ -263,57 +263,49 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) try { - while(!is_closed() && buf_size) + while(!is_closed() && input_size) { - byte rec_type = NO_RECORD; - std::vector<byte> record; - u64bit record_sequence = 0; - Protocol_Version record_version; + Record record; size_t consumed = 0; const size_t needed = read_record(m_readbuf, - buf, - buf_size, + input, + input_size, consumed, - rec_type, record, - record_version, - record_sequence, m_sequence_numbers.get(), get_cipherstate); - BOTAN_ASSERT(consumed <= buf_size, + BOTAN_ASSERT(consumed <= input_size, "Record reader consumed sane amount"); - buf += consumed; - buf_size -= consumed; + input += consumed; + input_size -= consumed; - BOTAN_ASSERT(buf_size == 0 || needed == 0, + BOTAN_ASSERT(input_size == 0 || needed == 0, "Got a full record or consumed all input"); - if(buf_size == 0 && needed != 0) + if(input_size == 0 && needed != 0) return needed; // need more data to complete record - if(rec_type == NO_RECORD) - continue; + BOTAN_ASSERT(record.is_valid(), "Got a full record"); if(record.size() > max_fragment_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); - if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) + if(record.type() == HANDSHAKE || record.type() == CHANGE_CIPHER_SPEC) { if(!m_pending_state) { - create_handshake_state(record_version); - if(record_version.is_datagram_protocol()) - sequence_numbers().read_accept(record_sequence); + create_handshake_state(record.version()); + if(record.version().is_datagram_protocol()) + sequence_numbers().read_accept(record.sequence()); } - m_pending_state->handshake_io().add_input( - rec_type, &record[0], record.size(), record_sequence); + m_pending_state->handshake_io().add_record(record); while(auto pending = m_pending_state.get()) { @@ -326,12 +318,12 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) msg.first, msg.second); } } - else if(rec_type == HEARTBEAT && peer_supports_heartbeats()) + else if(record.type() == HEARTBEAT && peer_supports_heartbeats()) { if(!active_state()) throw Unexpected_Message("Heartbeat sent before handshake done"); - Heartbeat_Message heartbeat(record); + Heartbeat_Message heartbeat(record.contents()); const std::vector<byte>& payload = heartbeat.payload(); @@ -351,7 +343,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) m_proc_fn(&payload[0], payload.size(), Alert(Alert::HEARTBEAT_PAYLOAD)); } } - else if(rec_type == APPLICATION_DATA) + else if(record.type() == APPLICATION_DATA) { if(!active_state()) throw Unexpected_Message("Application data before handshake done"); @@ -362,11 +354,11 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) * following record. Avoid spurious callbacks. */ if(record.size() > 0) - m_proc_fn(&record[0], record.size(), Alert()); + m_proc_fn(record.bits(), record.size(), Alert()); } - else if(rec_type == ALERT) + else if(record.type() == ALERT) { - Alert alert_msg(record); + Alert alert_msg(record.contents()); if(alert_msg.type() == Alert::NO_RENEGOTIATION) m_pending_state.reset(); @@ -392,7 +384,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } else throw Unexpected_Message("Unexpected record type " + - std::to_string(rec_type) + + std::to_string(record.type()) + " from counterparty"); } diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index 1fae7b5b7..c685c80ef 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -7,6 +7,7 @@ #include <botan/internal/tls_handshake_io.h> #include <botan/internal/tls_messages.h> +#include <botan/internal/tls_record.h> #include <botan/internal/tls_seq_numbers.h> #include <botan/exceptn.h> @@ -38,18 +39,15 @@ Protocol_Version Stream_Handshake_IO::initial_record_version() const return Protocol_Version::TLS_V10; } -void Stream_Handshake_IO::add_input(const byte rec_type, - const byte record[], - size_t record_size, - u64bit /*record_number*/) +void Stream_Handshake_IO::add_record(const Record& record) { - if(rec_type == HANDSHAKE) + if(record.type() == HANDSHAKE) { - m_queue.insert(m_queue.end(), record, record + record_size); + m_queue.insert(m_queue.end(), record.bits(), record.bits() + record.size()); } - else if(rec_type == CHANGE_CIPHER_SPEC) + else if(record.type() == CHANGE_CIPHER_SPEC) { - if(record_size != 1 || record[0] != 1) + if(record.size() != 1 || record.bits()[0] != 1) throw Decoding_Error("Invalid ChangeCipherSpec"); // Pretend it's a regular handshake message of zero length @@ -120,14 +118,11 @@ Protocol_Version Datagram_Handshake_IO::initial_record_version() const return Protocol_Version::DTLS_V10; } -void Datagram_Handshake_IO::add_input(const byte rec_type, - const byte record[], - size_t record_size, - u64bit record_number) +void Datagram_Handshake_IO::add_record(const Record& record) { - const u16bit epoch = static_cast<u16bit>(record_number >> 48); + const u16bit epoch = static_cast<u16bit>(record.sequence() >> 48); - if(rec_type == CHANGE_CIPHER_SPEC) + if(record.type() == CHANGE_CIPHER_SPEC) { m_ccs_epochs.insert(epoch); return; @@ -135,16 +130,19 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, const size_t DTLS_HANDSHAKE_HEADER_LEN = 12; + const byte* record_bits = record.bits(); + size_t record_size = record.size(); + while(record_size) { if(record_size < DTLS_HANDSHAKE_HEADER_LEN) return; // completely bogus? at least degenerate/weird - const byte msg_type = record[0]; - const size_t msg_len = load_be24(&record[1]); - const u16bit message_seq = load_be<u16bit>(&record[4], 0); - const size_t fragment_offset = load_be24(&record[6]); - const size_t fragment_length = load_be24(&record[9]); + const byte msg_type = record_bits[0]; + const size_t msg_len = load_be24(&record_bits[1]); + const u16bit message_seq = load_be<u16bit>(&record_bits[4], 0); + const size_t fragment_offset = load_be24(&record_bits[6]); + const size_t fragment_length = load_be24(&record_bits[9]); const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length; @@ -153,7 +151,7 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, if(message_seq >= m_in_message_seq) { - m_messages[message_seq].add_fragment(&record[DTLS_HANDSHAKE_HEADER_LEN], + m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN], fragment_length, fragment_offset, epoch, @@ -161,7 +159,7 @@ void Datagram_Handshake_IO::add_input(const byte rec_type, msg_len); } - record += total_size; + record_bits += total_size; record_size -= total_size; } } diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index 18fde1a83..82d1a8e7e 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -24,6 +24,7 @@ namespace Botan { namespace TLS { class Handshake_Message; +class Record; /** * Handshake IO Interface @@ -39,10 +40,7 @@ class Handshake_IO const std::vector<byte>& handshake_msg, Handshake_Type handshake_type) const = 0; - virtual void add_input(byte record_type, - const byte record[], - size_t record_size, - u64bit record_number) = 0; + virtual void add_record(const Record& record) = 0; /** * Returns (HANDSHAKE_NONE, std::vector<>()) if no message currently available @@ -76,10 +74,7 @@ class Stream_Handshake_IO : public Handshake_IO const std::vector<byte>& handshake_msg, Handshake_Type handshake_type) const override; - void add_input(byte record_type, - const byte record[], - size_t record_size, - u64bit record_number) override; + void add_record(const Record& record) override; std::pair<Handshake_Type, std::vector<byte>> get_next_record(bool expecting_ccs) override; @@ -106,10 +101,7 @@ class Datagram_Handshake_IO : public Handshake_IO const std::vector<byte>& handshake_msg, Handshake_Type handshake_type) const override; - void add_input(const byte rec_type, - const byte record[], - size_t record_size, - u64bit record_number) override; + void add_record(const Record& record) override; std::pair<Handshake_Type, std::vector<byte>> get_next_record(bool expecting_ccs) override; diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index fab966e72..0557c1796 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -270,10 +270,7 @@ size_t read_record(std::vector<byte>& readbuf, const byte input[], size_t input_sz, size_t& consumed, - byte& msg_type, - std::vector<byte>& msg, - Protocol_Version& record_version, - u64bit& record_sequence, + Record& record, Connection_Sequence_Numbers* sequence_numbers, std::function<Connection_Cipher_State* (u16bit)> get_cipherstate) { @@ -309,24 +306,27 @@ size_t read_record(std::vector<byte>& readbuf, BOTAN_ASSERT_EQUAL(readbuf.size(), (record_len + 2), "Have the entire SSLv2 hello"); - msg_type = HANDSHAKE; + // Fake v3-style handshake message wrapper + std::vector<byte> sslv2_hello(4 + readbuf.size() - 2); - msg.resize(record_len + 4); + sslv2_hello[0] = CLIENT_HELLO_SSLV2; + sslv2_hello[1] = 0; + sslv2_hello[2] = readbuf[0] & 0x7F; + sslv2_hello[3] = readbuf[1]; - // Fake v3-style handshake message wrapper - msg[0] = CLIENT_HELLO_SSLV2; - msg[1] = 0; - msg[2] = readbuf[0] & 0x7F; - msg[3] = readbuf[1]; + copy_mem(&sslv2_hello[4], &readbuf[2], readbuf.size() - 2); - copy_mem(&msg[4], &readbuf[2], readbuf.size() - 2); + record = Record(0, + Protocol_Version::TLS_V10, + HANDSHAKE, + std::move(sslv2_hello)); readbuf.clear(); return 0; } } - record_version = Protocol_Version(readbuf[1], readbuf[2]); + Protocol_Version record_version = Protocol_Version(readbuf[1], readbuf[2]); const bool is_dtls = record_version.is_datagram_protocol(); @@ -359,6 +359,9 @@ size_t read_record(std::vector<byte>& readbuf, readbuf.size(), "Have the full record"); + Record_Type record_type = static_cast<Record_Type>(readbuf[0]); + + u64bit record_sequence = 0; u16bit epoch = 0; if(is_dtls) @@ -385,8 +388,11 @@ size_t read_record(std::vector<byte>& readbuf, if(epoch == 0) // Unencrypted initial handshake { - msg_type = readbuf[0]; - msg.assign(&record_contents[0], &record_contents[record_len]); + record = Record(record_sequence, + record_version, + record_type, + &readbuf[header_size], + record_len); readbuf.clear(); return 0; // got a full record @@ -453,7 +459,7 @@ size_t read_record(std::vector<byte>& readbuf, throw Decoding_Error("Record sent with invalid length"); cipherstate->mac()->update_be(record_sequence); - cipherstate->mac()->update(readbuf[0]); // msg_type + cipherstate->mac()->update(static_cast<byte>(record_type)); if(cipherstate->mac_includes_record_version()) { @@ -461,10 +467,11 @@ size_t read_record(std::vector<byte>& readbuf, cipherstate->mac()->update(record_version.minor_version()); } - const u16bit plain_length = record_len - mac_pad_iv_size; + const byte* plaintext_block = &record_contents[iv_size]; + const u16bit plaintext_length = record_len - mac_pad_iv_size; - cipherstate->mac()->update_be(plain_length); - cipherstate->mac()->update(&record_contents[iv_size], plain_length); + cipherstate->mac()->update_be(plaintext_length); + cipherstate->mac()->update(plaintext_block, plaintext_length); std::vector<byte> mac_buf(mac_size); cipherstate->mac()->final(&mac_buf[0]); @@ -481,9 +488,11 @@ size_t read_record(std::vector<byte>& readbuf, if(sequence_numbers) sequence_numbers->read_accept(record_sequence); - msg_type = readbuf[0]; - msg.assign(&record_contents[iv_size], - &record_contents[iv_size + plain_length]); + record = Record(record_sequence, + record_version, + record_type, + plaintext_block, + plaintext_length); readbuf.clear(); return 0; diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index a73efccb1..bc86600fa 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -78,6 +78,51 @@ class Connection_Cipher_State bool m_is_ssl3 = false; }; +class Record + { + public: + Record() {} + + Record(u64bit sequence, + Protocol_Version version, + Record_Type type, + const byte contents[], + size_t contents_size) : + m_sequence(sequence), + m_version(version), + m_type(type), + m_contents(contents, contents + contents_size) {} + + Record(u64bit sequence, + Protocol_Version version, + Record_Type type, + std::vector<byte>&& contents) : + m_sequence(sequence), + m_version(version), + m_type(type), + m_contents(contents) {} + + bool is_valid() const { return m_type != NO_RECORD; } + + u64bit sequence() const { return m_sequence; } + + Record_Type type() const { return m_type; } + + Protocol_Version version() const { return m_version; } + + const std::vector<byte>& contents() const { return m_contents; } + + const byte* bits() const { return &m_contents[0]; } + + size_t size() const { return m_contents.size(); } + + private: + u64bit m_sequence = 0; + Protocol_Version m_version = Protocol_Version(); + Record_Type m_type = NO_RECORD; + std::vector<byte> m_contents; + }; + /** * Create a TLS record * @param write_buffer the output record is placed here @@ -105,10 +150,7 @@ size_t read_record(std::vector<byte>& read_buffer, const byte input[], size_t input_length, size_t& input_consumed, - byte& msg_type, - std::vector<byte>& msg, - Protocol_Version& record_version, - u64bit& record_sequence, + Record& output_record, Connection_Sequence_Numbers* sequence_numbers, std::function<Connection_Cipher_State* (u16bit)> get_cipherstate); |