diff options
-rw-r--r-- | src/tls/tls_record.cpp | 74 |
1 files changed, 43 insertions, 31 deletions
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index 5ecd226a4..db74b7268 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -272,32 +272,26 @@ size_t read_record(std::vector<byte>& readbuf, byte& msg_type, std::vector<byte>& msg, u64bit msg_sequence, - Protocol_Version version, + Protocol_Version negotiated_version, Connection_Cipher_State* cipherstate) { consumed = 0; - BOTAN_ASSERT(version.valid(), - "We know what version we are using"); - - const size_t header_size = - (version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; - - if(readbuf_pos < header_size) // header incomplete? + if(readbuf_pos < TLS_HEADER_SIZE) // header incomplete? { if(size_t needed = fill_buffer_to(readbuf, readbuf_pos, input, input_sz, consumed, - header_size)) + TLS_HEADER_SIZE)) return needed; - BOTAN_ASSERT_EQUAL(readbuf_pos, header_size, + BOTAN_ASSERT_EQUAL(readbuf_pos, TLS_HEADER_SIZE, "Have an entire header"); } // Possible SSLv2 format client hello if((!cipherstate) && (readbuf[0] & 0x80) && (readbuf[2] == 1)) { - if(version.is_datagram_protocol()) + if(negotiated_version.is_datagram_protocol()) throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client sent SSLv2-style DTLS hello"); @@ -346,21 +340,35 @@ size_t read_record(std::vector<byte>& readbuf, Protocol_Version record_version(readbuf[1], readbuf[2]); - if(record_version.is_datagram_protocol()) - msg_sequence = load_be<u64bit>(&readbuf[3], 0); - - const size_t record_len = make_u16bit(readbuf[header_size-2], - readbuf[header_size-1]); - - if(version.valid() && record_version != version) + if(negotiated_version.valid() && record_version != negotiated_version) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Got record with version " + record_version.to_string() + " expected " + - version.to_string()); + negotiated_version.to_string()); } + if(record_version.is_datagram_protocol() && readbuf_pos < DTLS_HEADER_SIZE) + { + if(size_t needed = fill_buffer_to(readbuf, readbuf_pos, + input, input_sz, consumed, + DTLS_HEADER_SIZE)) + return needed; + + BOTAN_ASSERT_EQUAL(readbuf_pos, DTLS_HEADER_SIZE, + "Have an entire header"); + } + + const size_t header_size = + (record_version.is_datagram_protocol()) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; + + if(record_version.is_datagram_protocol()) + msg_sequence = load_be<u64bit>(&readbuf[3], 0); + + const size_t record_len = make_u16bit(readbuf[header_size-2], + readbuf[header_size-1]); + if(record_len > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Got message that exceeds maximum size"); @@ -376,7 +384,7 @@ size_t read_record(std::vector<byte>& readbuf, byte* record_contents = &readbuf[header_size]; - if(!cipherstate) // Only handshake messages allowed here + if(!cipherstate) // Only handshake messages allowed during initial handshake { if(readbuf[0] != CHANGE_CIPHER_SPEC && readbuf[0] != ALERT && @@ -392,6 +400,9 @@ size_t read_record(std::vector<byte>& readbuf, return 0; // got a full record } + BOTAN_ASSERT(negotiated_version.valid(), + "We know what version we are using"); + // Otherwise, decrypt, check MAC, return plaintext const size_t block_size = cipherstate->block_size(); const size_t iv_size = cipherstate->iv_size(); @@ -403,31 +414,32 @@ size_t read_record(std::vector<byte>& readbuf, } else if(BlockCipher* bc = cipherstate->block_cipher()) { - secure_vector<byte>& cbc_state = cipherstate->cbc_state(); - BOTAN_ASSERT(record_len % block_size == 0, "Buffer is an even multiple of block size"); - byte* buf = record_contents; - const size_t blocks = record_len / block_size; + BOTAN_ASSERT(blocks > 0, "At least one ciphertext block"); + + byte* buf = record_contents; + secure_vector<byte> last_ciphertext(block_size); copy_mem(&last_ciphertext[0], &buf[0], block_size); bc->decrypt(&buf[0]); - xor_buf(&buf[0], &cbc_state[0], block_size); + xor_buf(&buf[0], &cipherstate->cbc_state()[0], block_size); secure_vector<byte> last_ciphertext2; - for(size_t i = 1; i <= blocks; ++i) + for(size_t i = 1; i < blocks; ++i) { last_ciphertext2.assign(&buf[block_size*i], &buf[block_size*(i+1)]); bc->decrypt(&buf[block_size*i]); xor_buf(&buf[block_size*i], &last_ciphertext[0], block_size); std::swap(last_ciphertext, last_ciphertext2); } - cbc_state = last_ciphertext; + + cipherstate->cbc_state() = last_ciphertext; } else throw Internal_Error("NULL cipher not supported"); @@ -437,7 +449,7 @@ size_t read_record(std::vector<byte>& readbuf, * padding_length fields are padding from our perspective. */ const size_t pad_size = - tls_padding_check(version, block_size, + tls_padding_check(negotiated_version, block_size, record_contents, record_len); const size_t mac_pad_iv_size = mac_size + pad_size + iv_size; @@ -448,10 +460,10 @@ size_t read_record(std::vector<byte>& readbuf, cipherstate->mac()->update_be(msg_sequence); cipherstate->mac()->update(readbuf[0]); // msg_type - if(version != Protocol_Version::SSL_V3) + if(negotiated_version != Protocol_Version::SSL_V3) { - cipherstate->mac()->update(version.major_version()); - cipherstate->mac()->update(version.minor_version()); + cipherstate->mac()->update(negotiated_version.major_version()); + cipherstate->mac()->update(negotiated_version.minor_version()); } const u16bit plain_length = record_len - mac_pad_iv_size; |