diff options
Diffstat (limited to 'src/lib/tls/tls_record.cpp')
-rw-r--r-- | src/lib/tls/tls_record.cpp | 186 |
1 files changed, 139 insertions, 47 deletions
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 925961764..0b356fad3 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output, } } -} - -size_t read_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, - Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) +size_t read_tls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) { consumed = 0; @@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf, *record_version = Protocol_Version(readbuf[1], readbuf[2]); - const bool is_dtls = record_version->is_datagram_protocol(); + BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS"); - if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE) - { - if(size_t needed = fill_buffer_to(readbuf, - input, input_sz, consumed, - DTLS_HEADER_SIZE)) - return needed; - - BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, - "Have an entire header"); - } - - const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; - - const size_t record_len = make_u16bit(readbuf[header_size-2], - readbuf[header_size-1]); + const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2], + readbuf[TLS_HEADER_SIZE-1]); if(record_len > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, @@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf, if(size_t needed = fill_buffer_to(readbuf, input, input_sz, consumed, - header_size + record_len)) - return needed; // wrong for DTLS? + TLS_HEADER_SIZE + record_len)) + return needed; - BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len, + BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len, readbuf.size(), "Have the full record"); @@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf, u16bit epoch = 0; - if(is_dtls) - { - *record_sequence = load_be<u64bit>(&readbuf[3], 0); - epoch = (*record_sequence >> 48); - } - else if(sequence_numbers) + if(sequence_numbers) { *record_sequence = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); @@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf, epoch = 0; } - if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) - { - readbuf.clear(); - return 0; - } - - byte* record_contents = &readbuf[header_size]; + byte* record_contents = &readbuf[TLS_HEADER_SIZE]; if(epoch == 0) // Unencrypted initial handshake { - record.assign(&readbuf[header_size], &readbuf[header_size + record_len]); + record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]); readbuf.clear(); return 0; // got a full record } @@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf, // Otherwise, decrypt, check MAC, return plaintext auto cipherstate = get_cipherstate(epoch); - // FIXME: DTLS reordering might cause us not to have the cipher state - BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); decrypt_record(record, @@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf, return 0; } +size_t read_dtls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + consumed = 0; + + if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? + { + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE)) + { + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header"); + } + + *record_version = Protocol_Version(readbuf[1], readbuf[2]); + + BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS"); + + const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2], + readbuf[DTLS_HEADER_SIZE-1]); + + if(record_len > MAX_CIPHERTEXT_SIZE) + throw TLS_Exception(Alert::RECORD_OVERFLOW, + "Got message that exceeds maximum size"); + + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len)) + { + // Truncated packet? + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(), + "Have the full record"); + + *record_type = static_cast<Record_Type>(readbuf[0]); + + u16bit epoch = 0; + + *record_sequence = load_be<u64bit>(&readbuf[3], 0); + epoch = (*record_sequence >> 48); + + if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) + { + readbuf.clear(); + return 0; + } + + byte* record_contents = &readbuf[DTLS_HEADER_SIZE]; + + if(epoch == 0) // Unencrypted initial handshake + { + record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]); + readbuf.clear(); + return 0; // got a full record + } + + try + { + // Otherwise, decrypt, check MAC, return plaintext + auto cipherstate = get_cipherstate(epoch); + + BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); + + decrypt_record(record, + record_contents, + record_len, + *record_sequence, + *record_version, + *record_type, + *cipherstate); + } + catch(std::exception) + { + readbuf.clear(); + *record_type = NO_RECORD; + return 0; + } + + if(sequence_numbers) + sequence_numbers->read_accept(*record_sequence); + + readbuf.clear(); + return 0; + } + +} + +size_t read_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + bool is_datagram, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + if(is_datagram) + return read_dtls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + else + return read_tls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + } + } } |