diff options
author | lloyd <[email protected]> | 2012-09-07 14:13:18 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-07 14:13:18 +0000 |
commit | 3f877fe296c1959fefa696094314c96f78bb9f7e (patch) | |
tree | a2004d39a05f06c4521086c9880bc8bf73c6e57e | |
parent | 70781697af4a4f6d94f04198b25a556d0a78ee81 (diff) |
In Channel move some checks to after we've verified needed == 0 to
avoid a conditional.
Clean up record checking in the reader.
-rw-r--r-- | src/tls/tls_channel.cpp | 17 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 17 |
2 files changed, 16 insertions, 18 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 91aaae206..12fc564f0 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -155,16 +155,6 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) m_current_version, m_read_cipherstate.get()); - if(needed == 0) // full message decoded - { - if(record.size() > m_max_fragment) - throw TLS_Exception(Alert::RECORD_OVERFLOW, - "Plaintext record is too large"); - - record_number = m_read_seq_no; - m_read_seq_no += 1; - } - BOTAN_ASSERT(consumed <= buf_size, "Record reader consumed sane amount"); @@ -177,6 +167,13 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(buf_size == 0 && needed != 0) return needed; // need more data to complete record + if(record.size() > m_max_fragment) + throw TLS_Exception(Alert::RECORD_OVERFLOW, + "Plaintext record is too large"); + + record_number = m_read_seq_no; + m_read_seq_no += 1; + if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) { if(!m_state) diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index 4a031f626..5ecd226a4 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -344,20 +344,21 @@ size_t read_record(std::vector<byte>& readbuf, " from counterparty"); } - if(version.is_datagram_protocol()) + Protocol_Version record_version(readbuf[1], readbuf[2]); + + if(record_version.is_datagram_protocol()) msg_sequence = load_be<u64bit>(&readbuf[3], 0); const size_t record_len = make_u16bit(readbuf[header_size-2], readbuf[header_size-1]); - if(version.major_version()) + if(version.valid() && record_version != version) { - if(readbuf[1] != version.major_version() || - readbuf[2] != version.minor_version()) - { - throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Got unexpected version from counterparty"); - } + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Got record with version " + + record_version.to_string() + + " expected " + + version.to_string()); } if(record_len > MAX_CIPHERTEXT_SIZE) |