diff options
Diffstat (limited to 'src/lib')
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 50 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_magic.h | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 186 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 1 |
10 files changed, 181 insertions, 75 deletions
diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 25307166b..76332f7d2 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -25,7 +25,9 @@ Channel::Channel(std::function<void (const byte[], size_t)> output_fn, std::function<bool (const Session&)> handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + bool is_datagram, size_t reserved_io_buffer_size) : + m_is_datagram(is_datagram), m_handshake_cb(handshake_cb), m_data_cb(data_cb), m_alert_cb(alert_cb), @@ -142,6 +144,8 @@ bool Channel::timeout_check() { if(m_pending_state) return m_pending_state->handshake_io().timeout_check(); + + //FIXME: scan cipher suites and remove epochs older than 2*MSL return false; } @@ -252,11 +256,7 @@ void Channel::activate_session() std::swap(m_active_state, m_pending_state); m_pending_state.reset(); - if(m_active_state->version().is_datagram_protocol()) - { - // FIXME, remove old states when we are sure not needed anymore - } - else + if(!m_active_state->version().is_datagram_protocol()) { // TLS is easy just remove all but the current state auto current_epoch = sequence_numbers().current_write_epoch(); @@ -307,6 +307,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) read_record(m_readbuf, input, input_size, + m_is_datagram, consumed, record, &record_sequence, @@ -340,24 +341,31 @@ size_t Channel::received_data(const byte input[], size_t input_size) { if(record_version.is_datagram_protocol()) { - sequence_numbers().read_accept(record_sequence); - - /* - * Might be a peer retransmit under epoch - 1 in which - * case we must retransmit last flight - */ - - const u16bit epoch = record_sequence >> 48; - - if(epoch == sequence_numbers().current_read_epoch()) + if(m_sequence_numbers) { - create_handshake_state(record_version); + /* + * Might be a peer retransmit under epoch - 1 in which + * case we must retransmit last flight + */ + sequence_numbers().read_accept(record_sequence); + + const u16bit epoch = record_sequence >> 48; + + if(epoch == sequence_numbers().current_read_epoch()) + { + create_handshake_state(record_version); + } + else if(epoch == sequence_numbers().current_read_epoch() - 1) + { + BOTAN_ASSERT(m_active_state, "Have active state here"); + m_active_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + } } - else if(epoch == sequence_numbers().current_read_epoch() - 1) + else if(record_sequence == 0) { - m_active_state->handshake_io().add_record(unlock(record), - record_type, - record_sequence); + create_handshake_state(record_version); } } else @@ -445,7 +453,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) return 0; } } - else + else if(record_type != NO_RECORD) throw Unexpected_Message("Unexpected record type " + std::to_string(record_type) + " from counterparty"); diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 3cdfe3d5e..8aea2dab0 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -164,6 +164,7 @@ class BOTAN_DLL Channel std::function<bool (const Session&)> handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + bool is_datagram, size_t reserved_io_buffer_size); Channel(const Channel&) = delete; @@ -234,6 +235,8 @@ class BOTAN_DLL Channel const Handshake_State* pending_state() const { return m_pending_state.get(); } + bool m_is_datagram; + /* callbacks */ std::function<bool (const Session&)> m_handshake_cb; std::function<void (const byte[], size_t)> m_data_cb; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 86d1998e1..7c3e48ca6 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -59,7 +59,8 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, const Protocol_Version offer_version, std::function<std::string (std::vector<std::string>)> next_protocol, size_t io_buf_sz) : - Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz), + Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, + offer_version.is_datagram_protocol(), io_buf_sz), m_policy(policy), m_creds(creds), m_info(info) diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index da27cc4ce..659818139 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -70,6 +70,7 @@ Stream_Handshake_IO::get_next_record(bool) if(m_queue.size() >= length + 4) { Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]); + BOTAN_ASSERT(type < 250, "Not in reserved range"); std::vector<byte> contents(m_queue.begin() + 4, m_queue.begin() + 4 + length); diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h index 51f1fce47..e22ab7248 100644 --- a/src/lib/tls/tls_magic.h +++ b/src/lib/tls/tls_magic.h @@ -27,13 +27,13 @@ enum Size_Limits { enum Connection_Side { CLIENT = 1, SERVER = 2 }; enum Record_Type { - NO_RECORD = 0, - CHANGE_CIPHER_SPEC = 20, ALERT = 21, HANDSHAKE = 22, APPLICATION_DATA = 23, HEARTBEAT = 24, + + NO_RECORD = 256 }; enum Handshake_Type { diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp index c4867d81a..0f2190562 100644 --- a/src/lib/tls/tls_policy.cpp +++ b/src/lib/tls/tls_policy.cpp @@ -146,10 +146,8 @@ bool Policy::send_fallback_scsv(Protocol_Version version) const bool Policy::acceptable_protocol_version(Protocol_Version version) const { - // By default require TLS to minimize surprise if(version.is_datagram_protocol()) - return false; - + return (version >= Protocol_Version::DTLS_V12); return (version >= Protocol_Version::TLS_V10); } diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 925961764..0b356fad3 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output, } } -} - -size_t read_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, - Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) +size_t read_tls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) { consumed = 0; @@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf, *record_version = Protocol_Version(readbuf[1], readbuf[2]); - const bool is_dtls = record_version->is_datagram_protocol(); + BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS"); - if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE) - { - if(size_t needed = fill_buffer_to(readbuf, - input, input_sz, consumed, - DTLS_HEADER_SIZE)) - return needed; - - BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, - "Have an entire header"); - } - - const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; - - const size_t record_len = make_u16bit(readbuf[header_size-2], - readbuf[header_size-1]); + const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2], + readbuf[TLS_HEADER_SIZE-1]); if(record_len > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, @@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf, if(size_t needed = fill_buffer_to(readbuf, input, input_sz, consumed, - header_size + record_len)) - return needed; // wrong for DTLS? + TLS_HEADER_SIZE + record_len)) + return needed; - BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len, + BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len, readbuf.size(), "Have the full record"); @@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf, u16bit epoch = 0; - if(is_dtls) - { - *record_sequence = load_be<u64bit>(&readbuf[3], 0); - epoch = (*record_sequence >> 48); - } - else if(sequence_numbers) + if(sequence_numbers) { *record_sequence = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); @@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf, epoch = 0; } - if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) - { - readbuf.clear(); - return 0; - } - - byte* record_contents = &readbuf[header_size]; + byte* record_contents = &readbuf[TLS_HEADER_SIZE]; if(epoch == 0) // Unencrypted initial handshake { - record.assign(&readbuf[header_size], &readbuf[header_size + record_len]); + record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]); readbuf.clear(); return 0; // got a full record } @@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf, // Otherwise, decrypt, check MAC, return plaintext auto cipherstate = get_cipherstate(epoch); - // FIXME: DTLS reordering might cause us not to have the cipher state - BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); decrypt_record(record, @@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf, return 0; } +size_t read_dtls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + consumed = 0; + + if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? + { + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE)) + { + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header"); + } + + *record_version = Protocol_Version(readbuf[1], readbuf[2]); + + BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS"); + + const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2], + readbuf[DTLS_HEADER_SIZE-1]); + + if(record_len > MAX_CIPHERTEXT_SIZE) + throw TLS_Exception(Alert::RECORD_OVERFLOW, + "Got message that exceeds maximum size"); + + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len)) + { + // Truncated packet? + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(), + "Have the full record"); + + *record_type = static_cast<Record_Type>(readbuf[0]); + + u16bit epoch = 0; + + *record_sequence = load_be<u64bit>(&readbuf[3], 0); + epoch = (*record_sequence >> 48); + + if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) + { + readbuf.clear(); + return 0; + } + + byte* record_contents = &readbuf[DTLS_HEADER_SIZE]; + + if(epoch == 0) // Unencrypted initial handshake + { + record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]); + readbuf.clear(); + return 0; // got a full record + } + + try + { + // Otherwise, decrypt, check MAC, return plaintext + auto cipherstate = get_cipherstate(epoch); + + BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); + + decrypt_record(record, + record_contents, + record_len, + *record_sequence, + *record_version, + *record_type, + *cipherstate); + } + catch(std::exception) + { + readbuf.clear(); + *record_type = NO_RECORD; + return 0; + } + + if(sequence_numbers) + sequence_numbers->read_accept(*record_sequence); + + readbuf.clear(); + return 0; + } + +} + +size_t read_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + bool is_datagram, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + if(is_datagram) + return read_dtls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + else + return read_tls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + } + } } diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index 8431e68c0..2dae96164 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -122,6 +122,7 @@ void write_record(secure_vector<byte>& write_buffer, size_t read_record(secure_vector<byte>& read_buffer, const byte input[], size_t input_length, + bool is_datagram, size_t& input_consumed, secure_vector<byte>& record, u64bit* record_sequence, diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index ff285881a..9b8a0d811 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -216,8 +216,9 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, const Policy& policy, RandomNumberGenerator& rng, const std::vector<std::string>& next_protocols, + bool is_datagram, size_t io_buf_sz) : - Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz), + Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz), m_policy(policy), m_creds(creds), m_possible_protocols(next_protocols) diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index a514607ba..c0646bdbc 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -34,6 +34,7 @@ class BOTAN_DLL Server : public Channel const Policy& policy, RandomNumberGenerator& rng, const std::vector<std::string>& protocols = std::vector<std::string>(), + bool is_datagram = false, size_t reserved_io_buffer_size = 16*1024 ); |