diff options
author | lloyd <[email protected]> | 2014-11-15 23:39:24 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2014-11-15 23:39:24 +0000 |
commit | 060df7809a64d1b589554169443c48bc428ca726 (patch) | |
tree | 74ca96453ddb4bd3a8abca43fb81d67859c9f6f8 /src | |
parent | 9751f1a9084aadbfebbc7f7e67fcd5806ead6492 (diff) |
A TLS Server can now process either TLS or DTLS but not either,
with the setting set in the constructor. This prevents various surprising
things from happening to applications and simplifies record processing.
Diffstat (limited to 'src')
-rw-r--r-- | src/cmd/tls_server.cpp | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.cpp | 50 | ||||
-rw-r--r-- | src/lib/tls/tls_channel.h | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_io.cpp | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_magic.h | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_policy.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_record.cpp | 186 | ||||
-rw-r--r-- | src/lib/tls/tls_record.h | 1 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 3 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 1 |
11 files changed, 183 insertions, 76 deletions
diff --git a/src/cmd/tls_server.cpp b/src/cmd/tls_server.cpp index fd9f7ef5d..a892835dc 100644 --- a/src/cmd/tls_server.cpp +++ b/src/cmd/tls_server.cpp @@ -205,7 +205,8 @@ int tls_server(int argc, char* argv[]) creds, policy, rng, - protocols); + protocols, + (transport != "tcp")); while(!server.is_closed()) { diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index 25307166b..76332f7d2 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -25,7 +25,9 @@ Channel::Channel(std::function<void (const byte[], size_t)> output_fn, std::function<bool (const Session&)> handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + bool is_datagram, size_t reserved_io_buffer_size) : + m_is_datagram(is_datagram), m_handshake_cb(handshake_cb), m_data_cb(data_cb), m_alert_cb(alert_cb), @@ -142,6 +144,8 @@ bool Channel::timeout_check() { if(m_pending_state) return m_pending_state->handshake_io().timeout_check(); + + //FIXME: scan cipher suites and remove epochs older than 2*MSL return false; } @@ -252,11 +256,7 @@ void Channel::activate_session() std::swap(m_active_state, m_pending_state); m_pending_state.reset(); - if(m_active_state->version().is_datagram_protocol()) - { - // FIXME, remove old states when we are sure not needed anymore - } - else + if(!m_active_state->version().is_datagram_protocol()) { // TLS is easy just remove all but the current state auto current_epoch = sequence_numbers().current_write_epoch(); @@ -307,6 +307,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) read_record(m_readbuf, input, input_size, + m_is_datagram, consumed, record, &record_sequence, @@ -340,24 +341,31 @@ size_t Channel::received_data(const byte input[], size_t input_size) { if(record_version.is_datagram_protocol()) { - sequence_numbers().read_accept(record_sequence); - - /* - * Might be a peer retransmit under epoch - 1 in which - * case we must retransmit last flight - */ - - const u16bit epoch = record_sequence >> 48; - - if(epoch == sequence_numbers().current_read_epoch()) + if(m_sequence_numbers) { - create_handshake_state(record_version); + /* + * Might be a peer retransmit under epoch - 1 in which + * case we must retransmit last flight + */ + sequence_numbers().read_accept(record_sequence); + + const u16bit epoch = record_sequence >> 48; + + if(epoch == sequence_numbers().current_read_epoch()) + { + create_handshake_state(record_version); + } + else if(epoch == sequence_numbers().current_read_epoch() - 1) + { + BOTAN_ASSERT(m_active_state, "Have active state here"); + m_active_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + } } - else if(epoch == sequence_numbers().current_read_epoch() - 1) + else if(record_sequence == 0) { - m_active_state->handshake_io().add_record(unlock(record), - record_type, - record_sequence); + create_handshake_state(record_version); } } else @@ -445,7 +453,7 @@ size_t Channel::received_data(const byte input[], size_t input_size) return 0; } } - else + else if(record_type != NO_RECORD) throw Unexpected_Message("Unexpected record type " + std::to_string(record_type) + " from counterparty"); diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index 3cdfe3d5e..8aea2dab0 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -164,6 +164,7 @@ class BOTAN_DLL Channel std::function<bool (const Session&)> handshake_cb, Session_Manager& session_manager, RandomNumberGenerator& rng, + bool is_datagram, size_t reserved_io_buffer_size); Channel(const Channel&) = delete; @@ -234,6 +235,8 @@ class BOTAN_DLL Channel const Handshake_State* pending_state() const { return m_pending_state.get(); } + bool m_is_datagram; + /* callbacks */ std::function<bool (const Session&)> m_handshake_cb; std::function<void (const byte[], size_t)> m_data_cb; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 86d1998e1..7c3e48ca6 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -59,7 +59,8 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, const Protocol_Version offer_version, std::function<std::string (std::vector<std::string>)> next_protocol, size_t io_buf_sz) : - Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz), + Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, + offer_version.is_datagram_protocol(), io_buf_sz), m_policy(policy), m_creds(creds), m_info(info) diff --git a/src/lib/tls/tls_handshake_io.cpp b/src/lib/tls/tls_handshake_io.cpp index da27cc4ce..659818139 100644 --- a/src/lib/tls/tls_handshake_io.cpp +++ b/src/lib/tls/tls_handshake_io.cpp @@ -70,6 +70,7 @@ Stream_Handshake_IO::get_next_record(bool) if(m_queue.size() >= length + 4) { Handshake_Type type = static_cast<Handshake_Type>(m_queue[0]); + BOTAN_ASSERT(type < 250, "Not in reserved range"); std::vector<byte> contents(m_queue.begin() + 4, m_queue.begin() + 4 + length); diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h index 51f1fce47..e22ab7248 100644 --- a/src/lib/tls/tls_magic.h +++ b/src/lib/tls/tls_magic.h @@ -27,13 +27,13 @@ enum Size_Limits { enum Connection_Side { CLIENT = 1, SERVER = 2 }; enum Record_Type { - NO_RECORD = 0, - CHANGE_CIPHER_SPEC = 20, ALERT = 21, HANDSHAKE = 22, APPLICATION_DATA = 23, HEARTBEAT = 24, + + NO_RECORD = 256 }; enum Handshake_Type { diff --git a/src/lib/tls/tls_policy.cpp b/src/lib/tls/tls_policy.cpp index c4867d81a..0f2190562 100644 --- a/src/lib/tls/tls_policy.cpp +++ b/src/lib/tls/tls_policy.cpp @@ -146,10 +146,8 @@ bool Policy::send_fallback_scsv(Protocol_Version version) const bool Policy::acceptable_protocol_version(Protocol_Version version) const { - // By default require TLS to minimize surprise if(version.is_datagram_protocol()) - return false; - + return (version >= Protocol_Version::DTLS_V12); return (version >= Protocol_Version::TLS_V10); } diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index 925961764..0b356fad3 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -464,18 +464,16 @@ void decrypt_record(secure_vector<byte>& output, } } -} - -size_t read_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, - Connection_Sequence_Numbers* sequence_numbers, - std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) +size_t read_tls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) { consumed = 0; @@ -529,23 +527,10 @@ size_t read_record(secure_vector<byte>& readbuf, *record_version = Protocol_Version(readbuf[1], readbuf[2]); - const bool is_dtls = record_version->is_datagram_protocol(); + BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS"); - if(is_dtls && readbuf.size() < DTLS_HEADER_SIZE) - { - if(size_t needed = fill_buffer_to(readbuf, - input, input_sz, consumed, - DTLS_HEADER_SIZE)) - return needed; - - BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, - "Have an entire header"); - } - - const size_t header_size = (is_dtls) ? DTLS_HEADER_SIZE : TLS_HEADER_SIZE; - - const size_t record_len = make_u16bit(readbuf[header_size-2], - readbuf[header_size-1]); + const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2], + readbuf[TLS_HEADER_SIZE-1]); if(record_len > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, @@ -553,10 +538,10 @@ size_t read_record(secure_vector<byte>& readbuf, if(size_t needed = fill_buffer_to(readbuf, input, input_sz, consumed, - header_size + record_len)) - return needed; // wrong for DTLS? + TLS_HEADER_SIZE + record_len)) + return needed; - BOTAN_ASSERT_EQUAL(static_cast<size_t>(header_size) + record_len, + BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len, readbuf.size(), "Have the full record"); @@ -564,12 +549,7 @@ size_t read_record(secure_vector<byte>& readbuf, u16bit epoch = 0; - if(is_dtls) - { - *record_sequence = load_be<u64bit>(&readbuf[3], 0); - epoch = (*record_sequence >> 48); - } - else if(sequence_numbers) + if(sequence_numbers) { *record_sequence = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); @@ -581,17 +561,11 @@ size_t read_record(secure_vector<byte>& readbuf, epoch = 0; } - if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) - { - readbuf.clear(); - return 0; - } - - byte* record_contents = &readbuf[header_size]; + byte* record_contents = &readbuf[TLS_HEADER_SIZE]; if(epoch == 0) // Unencrypted initial handshake { - record.assign(&readbuf[header_size], &readbuf[header_size + record_len]); + record.assign(&readbuf[TLS_HEADER_SIZE], &readbuf[TLS_HEADER_SIZE + record_len]); readbuf.clear(); return 0; // got a full record } @@ -599,8 +573,6 @@ size_t read_record(secure_vector<byte>& readbuf, // Otherwise, decrypt, check MAC, return plaintext auto cipherstate = get_cipherstate(epoch); - // FIXME: DTLS reordering might cause us not to have the cipher state - BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); decrypt_record(record, @@ -618,6 +590,126 @@ size_t read_record(secure_vector<byte>& readbuf, return 0; } +size_t read_dtls_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + consumed = 0; + + if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? + { + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE)) + { + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header"); + } + + *record_version = Protocol_Version(readbuf[1], readbuf[2]); + + BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS"); + + const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2], + readbuf[DTLS_HEADER_SIZE-1]); + + if(record_len > MAX_CIPHERTEXT_SIZE) + throw TLS_Exception(Alert::RECORD_OVERFLOW, + "Got message that exceeds maximum size"); + + if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len)) + { + // Truncated packet? + readbuf.clear(); + return 0; + } + + BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(), + "Have the full record"); + + *record_type = static_cast<Record_Type>(readbuf[0]); + + u16bit epoch = 0; + + *record_sequence = load_be<u64bit>(&readbuf[3], 0); + epoch = (*record_sequence >> 48); + + if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) + { + readbuf.clear(); + return 0; + } + + byte* record_contents = &readbuf[DTLS_HEADER_SIZE]; + + if(epoch == 0) // Unencrypted initial handshake + { + record.assign(&readbuf[DTLS_HEADER_SIZE], &readbuf[DTLS_HEADER_SIZE + record_len]); + readbuf.clear(); + return 0; // got a full record + } + + try + { + // Otherwise, decrypt, check MAC, return plaintext + auto cipherstate = get_cipherstate(epoch); + + BOTAN_ASSERT(cipherstate, "Have cipherstate for this epoch"); + + decrypt_record(record, + record_contents, + record_len, + *record_sequence, + *record_version, + *record_type, + *cipherstate); + } + catch(std::exception) + { + readbuf.clear(); + *record_type = NO_RECORD; + return 0; + } + + if(sequence_numbers) + sequence_numbers->read_accept(*record_sequence); + + readbuf.clear(); + return 0; + } + +} + +size_t read_record(secure_vector<byte>& readbuf, + const byte input[], + size_t input_sz, + bool is_datagram, + size_t& consumed, + secure_vector<byte>& record, + u64bit* record_sequence, + Protocol_Version* record_version, + Record_Type* record_type, + Connection_Sequence_Numbers* sequence_numbers, + std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cipherstate) + { + if(is_datagram) + return read_dtls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + else + return read_tls_record(readbuf, input, input_sz, consumed, + record, record_sequence, record_version, record_type, + sequence_numbers, get_cipherstate); + } + } } diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index 8431e68c0..2dae96164 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -122,6 +122,7 @@ void write_record(secure_vector<byte>& write_buffer, size_t read_record(secure_vector<byte>& read_buffer, const byte input[], size_t input_length, + bool is_datagram, size_t& input_consumed, secure_vector<byte>& record, u64bit* record_sequence, diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index ff285881a..9b8a0d811 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -216,8 +216,9 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, const Policy& policy, RandomNumberGenerator& rng, const std::vector<std::string>& next_protocols, + bool is_datagram, size_t io_buf_sz) : - Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, io_buf_sz), + Channel(output_fn, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz), m_policy(policy), m_creds(creds), m_possible_protocols(next_protocols) diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index a514607ba..c0646bdbc 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -34,6 +34,7 @@ class BOTAN_DLL Server : public Channel const Policy& policy, RandomNumberGenerator& rng, const std::vector<std::string>& protocols = std::vector<std::string>(), + bool is_datagram = false, size_t reserved_io_buffer_size = 16*1024 ); |