aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/tls/tls_record.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/tls/tls_record.cpp')
-rw-r--r--src/lib/tls/tls_record.cpp186
1 files changed, 139 insertions, 47 deletions
diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp
index 925961764..0b356fad3 100644
--- a/src/lib/tls/tls_record.cpp
+++ b/src/lib/tls/tls_record.cpp
@@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output,
}
}
-}
-
-size_t read_record(secure_vector<byte>& readbuf,
- const byte input[],
- size_t input_sz,
- size_t& consumed,
- secure_vector<byte>& record,
- u64bit* record_sequence,
- Protocol_Version* record_version,
- Record_Type* record_type,
- Connection_Sequence_Numbers* sequence_numbers,
- std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+size_t read_tls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
{
consumed = 0;
@@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf,
*record_version = Protocol_Version(readbuf[1], readbuf[2]);
- const bool is_dtls = record_version->is_datagram_protocol();
+ BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS");
- if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE)
- {
- if(size_t needed = fill_buffer_to(readbuf,
- input, input_sz, consumed,
- DTLS_HEADER_SIZE))
- return needed;
-
- BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE,
- "Have an entire header");
- }
-
- const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE;
-
- const size_t record_len = make_u16bit(readbuf[header_size-2],
- readbuf[header_size-1]);
+ const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2],
+ readbuf[TLS_HEADER_SIZE-1]);
if(record_len > MAX_CIPHERTEXT_SIZE)
throw TLS_Exception(Alert::RECORD_OVERFLOW,
@@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf,
if(size_t needed = fill_buffer_to(readbuf,
input, input_sz, consumed,
- header_size + record_len))
- return needed; // wrong for DTLS?
+ TLS_HEADER_SIZE + record_len))
+ return needed;
- BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len,
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len,
readbuf.size(),
"Have the full record");
@@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf,
u16bit epoch = 0;
- if(is_dtls)
- {
- *record_sequence = load_be<u64bit>(&readbuf[3], 0);
- epoch = (*record_sequence >> 48);
- }
- else if(sequence_numbers)
+ if(sequence_numbers)
{
*record_sequence = sequence_numbers->next_read_sequence();
epoch = sequence_numbers->current_read_epoch();
@@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf,
epoch = 0;
}
- if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
- {
- readbuf.clear();
- return 0;
- }
-
- byte* record_contents = &readbuf[header_size];
+ byte* record_contents = &readbuf[TLS_HEADER_SIZE];
if(epoch == 0) // Unencrypted initial handshake
{
- record.assign(&readbuf[header_size], &readbuf[header_size + record_len]);
+ record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]);
readbuf.clear();
return 0; // got a full record
}
@@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf,
// Otherwise, decrypt, check MAC, return plaintext
auto cipherstate = get_cipherstate(epoch);
- // FIXME: DTLS reordering might cause us not to have the cipher state
-
BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
decrypt_record(record,
@@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf,
return 0;
}
+size_t read_dtls_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ consumed = 0;
+
+ if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete?
+ {
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header");
+ }
+
+ *record_version = Protocol_Version(readbuf[1], readbuf[2]);
+
+ BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS");
+
+ const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2],
+ readbuf[DTLS_HEADER_SIZE-1]);
+
+ if(record_len > MAX_CIPHERTEXT_SIZE)
+ throw TLS_Exception(Alert::RECORD_OVERFLOW,
+ "Got message that exceeds maximum size");
+
+ if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len))
+ {
+ // Truncated packet?
+ readbuf.clear();
+ return 0;
+ }
+
+ BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(),
+ "Have the full record");
+
+ *record_type = static_cast<Record_Type>(readbuf[0]);
+
+ u16bit epoch = 0;
+
+ *record_sequence = load_be<u64bit>(&readbuf[3], 0);
+ epoch = (*record_sequence >> 48);
+
+ if(sequence_numbers && sequence_numbers->already_seen(*record_sequence))
+ {
+ readbuf.clear();
+ return 0;
+ }
+
+ byte* record_contents = &readbuf[DTLS_HEADER_SIZE];
+
+ if(epoch == 0) // Unencrypted initial handshake
+ {
+ record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]);
+ readbuf.clear();
+ return 0; // got a full record
+ }
+
+ try
+ {
+ // Otherwise, decrypt, check MAC, return plaintext
+ auto cipherstate = get_cipherstate(epoch);
+
+ BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch");
+
+ decrypt_record(record,
+ record_contents,
+ record_len,
+ *record_sequence,
+ *record_version,
+ *record_type,
+ *cipherstate);
+ }
+ catch(std::exception)
+ {
+ readbuf.clear();
+ *record_type = NO_RECORD;
+ return 0;
+ }
+
+ if(sequence_numbers)
+ sequence_numbers->read_accept(*record_sequence);
+
+ readbuf.clear();
+ return 0;
+ }
+
+}
+
+size_t read_record(secure_vector<byte>& readbuf,
+ const byte input[],
+ size_t input_sz,
+ bool is_datagram,
+ size_t& consumed,
+ secure_vector<byte>& record,
+ u64bit* record_sequence,
+ Protocol_Version* record_version,
+ Record_Type* record_type,
+ Connection_Sequence_Numbers* sequence_numbers,
+ std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate)
+ {
+ if(is_datagram)
+ return read_dtls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ else
+ return read_tls_record(readbuf, input, input_sz, consumed,
+ record, record_sequence, record_version, record_type,
+ sequence_numbers, get_cipherstate);
+ }
+
}
}