diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cmd/tls_client.cpp | 24 | ||||
-rw-r--r-- | src/cmd/tls_server.cpp | 13 | ||||
-rw-r--r-- | src/lib/tls/info.txt | 34 | ||||
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 12 | ||||
-rw-r--r-- | src/lib/tls/msg_next_protocol.cpp | 55 | ||||
-rw-r--r-- | src/lib/tls/msg_server_hello.cpp | 12 | ||||
-rw-r--r-- | src/lib/tls/tls_alert.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_alert.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_blocking.cpp | 4 | ||||
-rw-r--r-- | src/lib/tls/tls_blocking.h | 20 | ||||
-rw-r--r-- | src/lib/tls/tls_client.cpp | 63 | ||||
-rw-r--r-- | src/lib/tls/tls_client.h | 18 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.cpp | 39 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.h | 35 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.cpp | 15 | ||||
-rw-r--r-- | src/lib/tls/tls_handshake_state.h | 6 | ||||
-rw-r--r-- | src/lib/tls/tls_magic.h | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_messages.h | 69 | ||||
-rw-r--r-- | src/lib/tls/tls_server.cpp | 60 | ||||
-rw-r--r-- | src/lib/tls/tls_server.h | 6 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 21 |
21 files changed, 169 insertions, 343 deletions
diff --git a/src/cmd/tls_client.cpp b/src/cmd/tls_client.cpp index 903824a78..a1a6c0c5f 100644 --- a/src/cmd/tls_client.cpp +++ b/src/cmd/tls_client.cpp @@ -138,9 +138,6 @@ int tls_client(int argc, char* argv[]) return 1; } - const bool request_protocol = true; - const std::string use_protocol = "http/1.1"; - try { AutoSeeded_RNG rng; @@ -165,13 +162,9 @@ int tls_client(int argc, char* argv[]) const bool use_tcp = (transport == "tcp"); - int sockfd = connect_to_host(host, port, use_tcp); + const std::vector<std::string> protocols_to_offer = { "test/9.9", "http/1.1", "echo/9.1" }; - auto protocol_chooser = [use_protocol](const std::vector<std::string>& protocols) -> std::string { - for(size_t i = 0; i != protocols.size(); ++i) - std::cout << "Server offered protocol " << i << " = " << protocols[i] << "\n"; - return use_protocol; - }; + int sockfd = connect_to_host(host, port, use_tcp); auto socket_write = use_tcp ? @@ -190,7 +183,9 @@ int tls_client(int argc, char* argv[]) rng, TLS::Server_Information(host, port), version, - protocol_chooser); + protocols_to_offer); + + bool first_active = true; while(!client.is_closed()) { @@ -199,7 +194,16 @@ int tls_client(int argc, char* argv[]) FD_SET(sockfd, &readfds); if(client.is_active()) + { FD_SET(STDIN_FILENO, &readfds); + if(first_active && !protocols_to_offer.empty()) + { + std::string app = client.application_protocol(); + if(app != "") + std::cout << "Server choose protocol: " << client.application_protocol() << "\n"; + first_active = false; + } + } struct timeval timeout = { 1, 0 }; diff --git a/src/cmd/tls_server.cpp b/src/cmd/tls_server.cpp index fc8499be1..ee72ba5ac 100644 --- a/src/cmd/tls_server.cpp +++ b/src/cmd/tls_server.cpp @@ -146,12 +146,11 @@ int tls_server(int argc, char* argv[]) Basic_Credentials_Manager creds(rng, server_crt, server_key); - /* - * These are the protocols we advertise to the client, but the - * client will send back whatever it actually plans on talking, - * which may or may not take into account what we advertise. - */ - const std::vector<std::string> protocols = { "echo/1.0", "echo/1.1" }; + auto protocol_chooser = [](const std::vector<std::string>& protocols) -> std::string { + for(size_t i = 0; i != protocols.size(); ++i) + std::cout << "Client offered protocol " << i << " = " << protocols[i] << "\n"; + return "echo/1.0"; // too bad + }; std::cout << "Listening for new connections on " << transport << " port " << port << "\n"; @@ -210,7 +209,7 @@ int tls_server(int argc, char* argv[]) creds, policy, rng, - protocols, + protocol_chooser, !is_tcp); while(!server.is_closed()) diff --git a/src/lib/tls/info.txt b/src/lib/tls/info.txt index f65da5eea..3f3b323f1 100644 --- a/src/lib/tls/info.txt +++ b/src/lib/tls/info.txt @@ -1,4 +1,4 @@ -define TLS 20131128 +define TLS 20150319 load_on auto @@ -32,38 +32,6 @@ tls_seq_numbers.h tls_session_key.h </header:internal> -<source> -msg_cert_req.cpp -msg_cert_verify.cpp -msg_certificate.cpp -msg_client_hello.cpp -msg_client_kex.cpp -msg_finished.cpp -msg_hello_verify.cpp -msg_next_protocol.cpp -msg_server_hello.cpp -msg_server_kex.cpp -msg_session_ticket.cpp -tls_alert.cpp -tls_blocking.cpp -tls_channel.cpp -tls_ciphersuite.cpp -tls_client.cpp -tls_extensions.cpp -tls_handshake_hash.cpp -tls_handshake_io.cpp -tls_handshake_state.cpp -tls_heartbeats.cpp -tls_policy.cpp -tls_server.cpp -tls_session.cpp -tls_session_key.cpp -tls_session_manager_memory.cpp -tls_suite_info.cpp -tls_record.cpp -tls_version.cpp -</source> - <requires> aead aes diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 473d9235f..8b75e93d6 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -72,7 +72,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, const Policy& policy, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, - bool next_protocol, + const std::vector<std::string>& next_protocols, const std::string& hostname, const std::string& srp_identifier) : m_version(version), @@ -96,8 +96,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, if(m_version.is_datagram_protocol()) m_extensions.add(new SRTP_Protection_Profiles(policy.srtp_profiles())); - if(reneg_info.empty() && next_protocol) - m_extensions.add(new Next_Protocol_Notification()); + if(reneg_info.empty() && !next_protocols.empty()) + m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols)); BOTAN_ASSERT(policy.acceptable_protocol_version(version), "Our policy accepts the version we are offering"); @@ -117,7 +117,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, const Session& session, - bool next_protocol) : + const std::vector<std::string>& next_protocols) : m_version(session.version()), m_session_id(session.session_id()), m_random(make_hello_random(rng, policy)), @@ -146,8 +146,8 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(), policy.allowed_signature_methods())); - if(reneg_info.empty() && next_protocol) - m_extensions.add(new Next_Protocol_Notification()); + if(reneg_info.empty() && !next_protocols.empty()) + m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols)); hash.update(io.send(*this)); } diff --git a/src/lib/tls/msg_next_protocol.cpp b/src/lib/tls/msg_next_protocol.cpp deleted file mode 100644 index 6e56917d6..000000000 --- a/src/lib/tls/msg_next_protocol.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* -* Next Protocol Negotiation -* (C) 2012 Jack Lloyd -* -* Botan is released under the Simplified BSD License (see license.txt) -*/ - -#include <botan/internal/tls_messages.h> -#include <botan/internal/tls_extensions.h> -#include <botan/internal/tls_reader.h> -#include <botan/internal/tls_handshake_io.h> - -namespace Botan { - -namespace TLS { - -Next_Protocol::Next_Protocol(Handshake_IO& io, - Handshake_Hash& hash, - const std::string& protocol) : - m_protocol(protocol) - { - hash.update(io.send(*this)); - } - -Next_Protocol::Next_Protocol(const std::vector<byte>& buf) - { - TLS_Data_Reader reader("NextProtocol", buf); - - m_protocol = reader.get_string(1, 0, 255); - - reader.get_range_vector<byte>(1, 0, 255); // padding, ignored - } - -std::vector<byte> Next_Protocol::serialize() const - { - std::vector<byte> buf; - - append_tls_length_value(buf, - reinterpret_cast<const byte*>(m_protocol.data()), - m_protocol.size(), - 1); - - const byte padding_len = 32 - ((m_protocol.size() + 2) % 32); - - buf.push_back(padding_len); - - for(size_t i = 0; i != padding_len; ++i) - buf.push_back(0); - - return buf; - } - -} - -} diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index 73163a73b..0b352f080 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -28,7 +28,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, u16bit ciphersuite, byte compression, bool offer_session_ticket, - const std::vector<std::string>& next_protocols) : + const std::string next_protocol) : m_version(new_session_version), m_session_id(new_session_id), m_random(make_hello_random(rng, policy)), @@ -47,8 +47,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, if(policy.negotiate_heartbeat_support() && client_hello.supports_heartbeats()) m_extensions.add(new Heartbeat_Support_Indicator(true)); - if(client_hello.next_protocol_notification()) - m_extensions.add(new Next_Protocol_Notification(next_protocols)); + if(next_protocol != "" && client_hello.supports_alpn()) + m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol)); if(m_version.is_datagram_protocol()) { @@ -83,7 +83,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, const Client_Hello& client_hello, Session& resumed_session, bool offer_session_ticket, - const std::vector<std::string>& next_protocols) : + const std::string& next_protocol) : m_version(resumed_session.version()), m_session_id(client_hello.session_id()), m_random(make_hello_random(rng, policy)), @@ -102,8 +102,8 @@ Server_Hello::Server_Hello(Handshake_IO& io, if(policy.negotiate_heartbeat_support() && client_hello.supports_heartbeats()) m_extensions.add(new Heartbeat_Support_Indicator(true)); - if(client_hello.next_protocol_notification()) - m_extensions.add(new Next_Protocol_Notification(next_protocols)); + if(next_protocol != "" && client_hello.supports_alpn()) + m_extensions.add(new Application_Layer_Protocol_Notification(next_protocol)); hash.update(io.send(*this)); } diff --git a/src/lib/tls/tls_alert.cpp b/src/lib/tls/tls_alert.cpp index ecda5055c..5cfb1b0b1 100644 --- a/src/lib/tls/tls_alert.cpp +++ b/src/lib/tls/tls_alert.cpp @@ -103,6 +103,8 @@ std::string Alert::type_string() const return "bad_certificate_hash_value"; case UNKNOWN_PSK_IDENTITY: return "unknown_psk_identity"; + case NO_APPLICATION_PROTOCOL: + return "no_application_protocol"; case NULL_ALERT: return "none"; diff --git a/src/lib/tls/tls_alert.h b/src/lib/tls/tls_alert.h index 90bc80d45..81946d9db 100644 --- a/src/lib/tls/tls_alert.h +++ b/src/lib/tls/tls_alert.h @@ -57,6 +57,8 @@ class BOTAN_DLL Alert BAD_CERTIFICATE_HASH_VALUE = 114, UNKNOWN_PSK_IDENTITY = 115, + NO_APPLICATION_PROTOCOL = 120, // RFC 7301 + // pseudo alert values NULL_ALERT = 256, HEARTBEAT_PAYLOAD = 257 diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp index b02c9ede1..b46961f9d 100644 --- a/src/lib/tls/tls_blocking.cpp +++ b/src/lib/tls/tls_blocking.cpp @@ -21,7 +21,7 @@ Blocking_Client::Blocking_Client(read_fn reader, RandomNumberGenerator& rng, const Server_Information& server_info, const Protocol_Version offer_version, - next_protocol_fn npn) : + const std::vector<std::string>& next) : m_read(reader), m_channel(writer, std::bind(&Blocking_Client::data_cb, this, _1, _2), @@ -33,7 +33,7 @@ Blocking_Client::Blocking_Client(read_fn reader, rng, server_info, offer_version, - npn) + next) { } diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h index ca6906545..89421f5f5 100644 --- a/src/lib/tls/tls_blocking.h +++ b/src/lib/tls/tls_blocking.h @@ -32,17 +32,15 @@ class BOTAN_DLL Blocking_Client typedef std::function<size_t (byte[], size_t)> read_fn; typedef std::function<void (const byte[], size_t)> write_fn; - typedef Client::next_protocol_fn next_protocol_fn; - - Blocking_Client(read_fn reader, - write_fn writer, - Session_Manager& session_manager, - Credentials_Manager& creds, - const Policy& policy, - RandomNumberGenerator& rng, - const Server_Information& server_info = Server_Information(), - const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - next_protocol_fn npn = next_protocol_fn()); + Blocking_Client(read_fn reader, + write_fn writer, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), + const std::vector<std::string>& next_protos = {}); /** * Completes full handshake then returns diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index bdc64283c..339e74e71 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -36,9 +36,6 @@ class Client_Handshake_State : public Handshake_State secure_vector<byte> resume_master_secret; std::unique_ptr<Public_Key> server_public_key; - - // Used by client using NPN - Client::next_protocol_fn client_npn_cb; }; } @@ -56,7 +53,7 @@ Client::Client(output_fn output_fn, RandomNumberGenerator& rng, const Server_Information& info, const Protocol_Version offer_version, - next_protocol_fn npn, + const std::vector<std::string>& next_protos, size_t io_buf_sz) : Channel(output_fn, proc_cb, alert_cb, handshake_cb, session_manager, rng, offer_version.is_datagram_protocol(), io_buf_sz), @@ -67,7 +64,7 @@ Client::Client(output_fn output_fn, const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); Handshake_State& state = create_handshake_state(offer_version); - send_client_hello(state, false, offer_version, srp_identifier, npn); + send_client_hello(state, false, offer_version, srp_identifier, next_protos); } Handshake_State* Client::new_handshake_state(Handshake_IO* io) @@ -89,16 +86,14 @@ Client::get_peer_cert_chain(const Handshake_State& state) const void Client::initiate_handshake(Handshake_State& state, bool force_full_renegotiation) { - send_client_hello(state, - force_full_renegotiation, - state.version()); + send_client_hello(state, force_full_renegotiation, state.version()); } void Client::send_client_hello(Handshake_State& state_base, bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier, - next_protocol_fn next_protocol) + const std::vector<std::string>& next_protocols) { Client_Handshake_State& state = dynamic_cast<Client_Handshake_State&>(state_base); @@ -106,10 +101,6 @@ void Client::send_client_hello(Handshake_State& state_base, state.set_expected_next(HELLO_VERIFY_REQUEST); // optional state.set_expected_next(SERVER_HELLO); - state.client_npn_cb = next_protocol; - - const bool send_npn_request = static_cast<bool>(next_protocol); - if(!force_full_renegotiation && !m_info.empty()) { Session session_info; @@ -124,7 +115,7 @@ void Client::send_client_hello(Handshake_State& state_base, rng(), secure_renegotiation_data_for_client_hello(), session_info, - send_npn_request)); + next_protocols)); state.resume_master_secret = session_info.master_secret(); } @@ -140,7 +131,7 @@ void Client::send_client_hello(Handshake_State& state_base, m_policy, rng(), secure_renegotiation_data_for_client_hello(), - send_npn_request, + next_protocols, m_info.hostname(), srp_identifier)); } @@ -247,6 +238,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, } state.set_version(state.server_hello()->version()); + m_application_protocol = state.server_hello()->next_protocol(); secure_renegotiation_check(state.server_hello()); @@ -389,20 +381,15 @@ void Client::process_handshake_msg(const Handshake_State* active_state, else if(type == CERTIFICATE_REQUEST) { state.set_expected_next(SERVER_HELLO_DONE); - state.cert_req( - new Certificate_Req(contents, state.version()) - ); + state.cert_req(new Certificate_Req(contents, state.version())); } else if(type == SERVER_HELLO_DONE) { - state.server_hello_done( - new Server_Hello_Done(contents) - ); + state.server_hello_done(new Server_Hello_Done(contents)); if(state.received_handshake_msg(CERTIFICATE_REQUEST)) { - const std::vector<std::string>& types = - state.cert_req()->acceptable_cert_types(); + const auto& types = state.cert_req()->acceptable_cert_types(); std::vector<X509_Certificate> client_certs = m_creds.cert_chain(types, @@ -449,19 +436,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, change_cipher_spec_writer(CLIENT); - if(state.server_hello()->next_protocol_notification()) - { - const std::string protocol = state.client_npn_cb( - state.server_hello()->next_protocols()); - - state.next_protocol( - new Next_Protocol(state.handshake_io(), state.hash(), protocol) - ); - } - - state.client_finished( - new Finished(state.handshake_io(), state, CLIENT) - ); + state.client_finished(new Finished(state.handshake_io(), state, CLIENT)); if(state.server_hello()->supports_session_ticket()) state.set_expected_next(NEW_SESSION_TICKET); @@ -493,22 +468,8 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(!state.client_finished()) // session resume case { state.handshake_io().send(Change_Cipher_Spec()); - change_cipher_spec_writer(CLIENT); - - if(state.server_hello()->next_protocol_notification()) - { - const std::string protocol = state.client_npn_cb( - state.server_hello()->next_protocols()); - - state.next_protocol( - new Next_Protocol(state.handshake_io(), state.hash(), protocol) - ); - } - - state.client_finished( - new Finished(state.handshake_io(), state, CLIENT) - ); + state.client_finished(new Finished(state.handshake_io(), state, CLIENT)); } std::vector<byte> session_id = state.server_hello()->session_id(); diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index a548a32e0..e4e0dc363 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -46,22 +46,13 @@ class BOTAN_DLL Client : public Channel * @param offer_version specifies which version we will offer * to the TLS server. * - * @param next_protocol allows the client to specify what the next - * protocol will be. For more information read - * http://technotes.googlecode.com/git/nextprotoneg.html. - * - * If the function is not empty, NPN will be negotiated - * and if the server supports NPN the function will be - * called with the list of protocols the server advertised; - * the client should return the protocol it would like to use. + * @param next_protocols specifies protocols to advertise with ALPN * * @param reserved_io_buffer_size This many bytes of memory will * be preallocated for the read and write buffers. Smaller * values just mean reallocations and copies are more likely. */ - typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn; - Client(output_fn out, data_cb app_data_cb, alert_cb alert_cb, @@ -72,9 +63,11 @@ class BOTAN_DLL Client : public Channel RandomNumberGenerator& rng, const Server_Information& server_info = Server_Information(), const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), - next_protocol_fn next_protocol = next_protocol_fn(), + const std::vector<std::string>& next_protocols = {}, size_t reserved_io_buffer_size = 16*1024 ); + + const std::string& application_protocol() const { return m_application_protocol; } private: std::vector<X509_Certificate> get_peer_cert_chain(const Handshake_State& state) const override; @@ -86,7 +79,7 @@ class BOTAN_DLL Client : public Channel bool force_full_renegotiation, Protocol_Version version, const std::string& srp_identifier = "", - next_protocol_fn next_protocol = next_protocol_fn()); + const std::vector<std::string>& next_protocols = {}); void process_handshake_msg(const Handshake_State* active_state, Handshake_State& pending_state, @@ -98,6 +91,7 @@ class BOTAN_DLL Client : public Channel const Policy& m_policy; Credentials_Manager& m_creds; const Server_Information m_info; + std::string m_application_protocol; }; } diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp index 2c3056d9f..b7ba4a917 100644 --- a/src/lib/tls/tls_extensions.cpp +++ b/src/lib/tls/tls_extensions.cpp @@ -42,8 +42,8 @@ Extension* make_extension(TLS_Data_Reader& reader, case TLSEXT_USE_SRTP: return new SRTP_Protection_Profiles(reader, size); - case TLSEXT_NEXT_PROTOCOL: - return new Next_Protocol_Notification(reader, size); + case TLSEXT_ALPN: + return new Application_Layer_Protocol_Notification(reader, size); case TLSEXT_HEARTBEAT_SUPPORT: return new Heartbeat_Support_Indicator(reader, size); @@ -258,20 +258,25 @@ Maximum_Fragment_Length::Maximum_Fragment_Length(TLS_Data_Reader& reader, } } -Next_Protocol_Notification::Next_Protocol_Notification(TLS_Data_Reader& reader, - u16bit extension_size) +Application_Layer_Protocol_Notification::Application_Layer_Protocol_Notification(TLS_Data_Reader& reader, + u16bit extension_size) { if(extension_size == 0) return; // empty extension - size_t bytes_remaining = extension_size; + const u16bit name_bytes = reader.get_u16bit(); + + size_t bytes_remaining = extension_size - 2; + + if(name_bytes != bytes_remaining) + throw Decoding_Error("Bad encoding of ALPN extension, bad length field"); while(bytes_remaining) { const std::string p = reader.get_string(1, 0, 255); if(bytes_remaining < p.size() + 1) - throw Decoding_Error("Bad encoding for next protocol extension"); + throw Decoding_Error("Bad encoding of ALPN, length field too long"); bytes_remaining -= (p.size() + 1); @@ -279,14 +284,23 @@ Next_Protocol_Notification::Next_Protocol_Notification(TLS_Data_Reader& reader, } } -std::vector<byte> Next_Protocol_Notification::serialize() const +const std::string& Application_Layer_Protocol_Notification::single_protocol() const { - std::vector<byte> buf; + if(m_protocols.size() != 1) + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Server sent " + std::to_string(m_protocols.size()) + + " protocols in ALPN extension response"); + return m_protocols[0]; + } - for(size_t i = 0; i != m_protocols.size(); ++i) - { - const std::string p = m_protocols[i]; +std::vector<byte> Application_Layer_Protocol_Notification::serialize() const + { + std::vector<byte> buf(2); + for(auto&& p: m_protocols) + { + if(p.length() >= 256) + throw TLS_Exception(Alert::INTERNAL_ERROR, "ALPN name too long"); if(p != "") append_tls_length_value(buf, reinterpret_cast<const byte*>(p.data()), @@ -294,6 +308,9 @@ std::vector<byte> Next_Protocol_Notification::serialize() const 1); } + buf[0] = get_byte<u16bit>(0, buf.size()-2); + buf[1] = get_byte<u16bit>(1, buf.size()-2); + return buf; } diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h index 393cada12..83e819509 100644 --- a/src/lib/tls/tls_extensions.h +++ b/src/lib/tls/tls_extensions.h @@ -35,11 +35,10 @@ enum Handshake_Extension_Type { TLSEXT_SIGNATURE_ALGORITHMS = 13, TLSEXT_USE_SRTP = 14, TLSEXT_HEARTBEAT_SUPPORT = 15, + TLSEXT_ALPN = 16, TLSEXT_SESSION_TICKET = 35, - TLSEXT_NEXT_PROTOCOL = 13172, - TLSEXT_SAFE_RENEGOTIATION = 65281, }; @@ -181,41 +180,37 @@ class Maximum_Fragment_Length : public Extension }; /** -* Next Protocol Negotiation -* http://technotes.googlecode.com/git/nextprotoneg.html -* -* This implementation requires the semantics defined in the Google -* spec (implemented in Chromium); the internet draft leaves the format -* unspecified. +* ALPN (RFC 7301) */ -class Next_Protocol_Notification : public Extension +class Application_Layer_Protocol_Notification : public Extension { public: - static Handshake_Extension_Type static_type() - { return TLSEXT_NEXT_PROTOCOL; } + static Handshake_Extension_Type static_type() { return TLSEXT_ALPN; } Handshake_Extension_Type type() const { return static_type(); } - const std::vector<std::string>& protocols() const - { return m_protocols; } + const std::vector<std::string>& protocols() const { return m_protocols; } + + const std::string& single_protocol() const; /** - * Empty extension, used by client + * Single protocol, used by server */ - Next_Protocol_Notification() {} + Application_Layer_Protocol_Notification(const std::string& protocol) : + m_protocols(1, protocol) {} /** - * List of protocols, used by server + * List of protocols, used by client */ - Next_Protocol_Notification(const std::vector<std::string>& protocols) : + Application_Layer_Protocol_Notification(const std::vector<std::string>& protocols) : m_protocols(protocols) {} - Next_Protocol_Notification(TLS_Data_Reader& reader, - u16bit extension_size); + Application_Layer_Protocol_Notification(TLS_Data_Reader& reader, + u16bit extension_size); std::vector<byte> serialize() const; - bool empty() const { return false; } + bool empty() const { return m_protocols.empty(); } private: std::vector<std::string> m_protocols; }; diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index f0d80556d..cbbca3a0d 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -58,17 +58,14 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) case CLIENT_KEX: return (1 << 11); - case NEXT_PROTOCOL: - return (1 << 12); - case NEW_SESSION_TICKET: - return (1 << 13); + return (1 << 12); case HANDSHAKE_CCS: - return (1 << 14); + return (1 << 13); case FINISHED: - return (1 << 15); + return (1 << 14); // allow explicitly disabling new handshakes case HANDSHAKE_NONE: @@ -157,12 +154,6 @@ void Handshake_State::client_verify(Certificate_Verify* client_verify) note_message(*m_client_verify); } -void Handshake_State::next_protocol(Next_Protocol* next_protocol) - { - m_next_protocol.reset(next_protocol); - note_message(*m_next_protocol); - } - void Handshake_State::new_session_ticket(New_Session_Ticket* new_session_ticket) { m_new_session_ticket.reset(new_session_ticket); diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 3ad44c613..3b60178b4 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -36,7 +36,6 @@ class Server_Hello_Done; class Certificate; class Client_Key_Exchange; class Certificate_Verify; -class Next_Protocol; class New_Session_Ticket; class Finished; @@ -111,7 +110,6 @@ class Handshake_State void client_certs(Certificate* client_certs); void client_kex(Client_Key_Exchange* client_kex); void client_verify(Certificate_Verify* client_verify); - void next_protocol(Next_Protocol* next_protocol); void new_session_ticket(New_Session_Ticket* new_session_ticket); void server_finished(Finished* server_finished); void client_finished(Finished* client_finished); @@ -143,9 +141,6 @@ class Handshake_State const Certificate_Verify* client_verify() const { return m_client_verify.get(); } - const Next_Protocol* next_protocol() const - { return m_next_protocol.get(); } - const New_Session_Ticket* new_session_ticket() const { return m_new_session_ticket.get(); } @@ -195,7 +190,6 @@ class Handshake_State std::unique_ptr<Certificate> m_client_certs; std::unique_ptr<Client_Key_Exchange> m_client_kex; std::unique_ptr<Certificate_Verify> m_client_verify; - std::unique_ptr<Next_Protocol> m_next_protocol; std::unique_ptr<New_Session_Ticket> m_new_session_ticket; std::unique_ptr<Finished> m_server_finished; std::unique_ptr<Finished> m_client_finished; diff --git a/src/lib/tls/tls_magic.h b/src/lib/tls/tls_magic.h index 4a7237722..882e59158 100644 --- a/src/lib/tls/tls_magic.h +++ b/src/lib/tls/tls_magic.h @@ -53,8 +53,6 @@ enum Handshake_Type { CERTIFICATE_URL = 21, CERTIFICATE_STATUS = 22, - NEXT_PROTOCOL = 67, - HANDSHAKE_CCS = 254, // Not a wire value HANDSHAKE_NONE = 255 // Null value }; diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h index 18cc90c39..befbdb932 100644 --- a/src/lib/tls/tls_messages.h +++ b/src/lib/tls/tls_messages.h @@ -115,11 +115,6 @@ class Client_Hello : public Handshake_Message return std::vector<byte>(); } - bool next_protocol_notification() const - { - return m_extensions.has<Next_Protocol_Notification>(); - } - size_t fragment_size() const { if(Maximum_Fragment_Length* frag = m_extensions.get<Maximum_Fragment_Length>()) @@ -139,6 +134,18 @@ class Client_Hello : public Handshake_Message return std::vector<byte>(); } + bool supports_alpn() const + { + return m_extensions.has<Application_Layer_Protocol_Notification>(); + } + + std::vector<std::string> next_protocols() const + { + if(auto alpn = m_extensions.get<Application_Layer_Protocol_Notification>()) + return alpn->protocols(); + return std::vector<std::string>(); + } + bool supports_heartbeats() const { return m_extensions.has<Heartbeat_Support_Indicator>(); @@ -169,7 +176,7 @@ class Client_Hello : public Handshake_Message const Policy& policy, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, - bool next_protocol = false, + const std::vector<std::string>& next_protocols, const std::string& hostname = "", const std::string& srp_identifier = ""); @@ -179,7 +186,7 @@ class Client_Hello : public Handshake_Message RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, const Session& resumed_session, - bool next_protocol = false); + const std::vector<std::string>& next_protocols); Client_Hello(const std::vector<byte>& buf); @@ -226,18 +233,6 @@ class Server_Hello : public Handshake_Message return std::vector<byte>(); } - bool next_protocol_notification() const - { - return m_extensions.has<Next_Protocol_Notification>(); - } - - std::vector<std::string> next_protocols() const - { - if(Next_Protocol_Notification* npn = m_extensions.get<Next_Protocol_Notification>()) - return npn->protocols(); - return std::vector<std::string>(); - } - size_t fragment_size() const { if(Maximum_Fragment_Length* frag = m_extensions.get<Maximum_Fragment_Length>()) @@ -257,14 +252,14 @@ class Server_Hello : public Handshake_Message bool peer_can_send_heartbeats() const { - if(Heartbeat_Support_Indicator* hb = m_extensions.get<Heartbeat_Support_Indicator>()) + if(auto hb = m_extensions.get<Heartbeat_Support_Indicator>()) return hb->peer_allowed_to_send(); return false; } u16bit srtp_profile() const { - if(SRTP_Protection_Profiles* srtp = m_extensions.get<SRTP_Protection_Profiles>()) + if(auto srtp = m_extensions.get<SRTP_Protection_Profiles>()) { auto prof = srtp->profiles(); if(prof.size() != 1 || prof[0] == 0) @@ -275,6 +270,13 @@ class Server_Hello : public Handshake_Message return 0; } + std::string next_protocol() const + { + if(auto alpn = m_extensions.get<Application_Layer_Protocol_Notification>()) + return alpn->single_protocol(); + return ""; + } + std::set<Handshake_Extension_Type> extension_types() const { return m_extensions.extension_types(); } @@ -289,7 +291,7 @@ class Server_Hello : public Handshake_Message u16bit ciphersuite, byte compression, bool offer_session_ticket, - const std::vector<std::string>& next_protocols); + const std::string next_protocol); Server_Hello(Handshake_IO& io, Handshake_Hash& hash, @@ -299,7 +301,7 @@ class Server_Hello : public Handshake_Message const Client_Hello& client_hello, Session& resumed_session, bool offer_session_ticket, - const std::vector<std::string>& next_protocols); + const std::string& next_protocol); Server_Hello(const std::vector<byte>& buf); private: @@ -534,27 +536,6 @@ class Server_Hello_Done : public Handshake_Message }; /** -* Next Protocol Message -*/ -class Next_Protocol : public Handshake_Message - { - public: - Handshake_Type type() const override { return NEXT_PROTOCOL; } - - std::string protocol() const { return m_protocol; } - - Next_Protocol(Handshake_IO& io, - Handshake_Hash& hash, - const std::string& protocol); - - Next_Protocol(const std::vector<byte>& buf); - private: - std::vector<byte> serialize() const override; - - std::string m_protocol; - }; - -/** * New Session Ticket Message */ class New_Session_Ticket : public Handshake_Message diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 515bd9e17..2f5a0e00d 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -212,14 +212,14 @@ Server::Server(output_fn output, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::vector<std::string>& next_protocols, + next_protocol_fn next_proto, bool is_datagram, size_t io_buf_sz) : Channel(output, data_cb, alert_cb, handshake_cb, session_manager, rng, is_datagram, io_buf_sz), m_policy(policy), m_creds(creds), - m_possible_protocols(next_protocols) + m_choose_next_protocol(next_proto) { } @@ -348,10 +348,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state, "Client signalled fallback SCSV, possible attack"); } - if(!initial_handshake && state.client_hello()->next_protocol_notification()) - throw TLS_Exception(Alert::HANDSHAKE_FAILURE, - "Client included NPN extension for renegotiation"); - secure_renegotiation_check(state.client_hello()); state.set_version(negotiated_version); @@ -374,6 +370,10 @@ void Server::process_handshake_msg(const Handshake_State* active_state, } catch(...) {} + m_next_protocol = ""; + if(state.client_hello()->supports_alpn()) + m_next_protocol = m_choose_next_protocol(state.client_hello()->next_protocols()); + if(resuming) { // Only offer a resuming client a new ticket if they didn't send one this time, @@ -393,7 +393,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, *state.client_hello(), session_info, offer_new_session_ticket, - m_possible_protocols + m_next_protocol )); secure_renegotiation_check(state.server_hello()); @@ -440,10 +440,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, change_cipher_spec_writer(SERVER); - state.server_finished( - new Finished(state.handshake_io(), state, SERVER) - ); - + state.server_finished(new Finished(state.handshake_io(), state, SERVER)); state.set_expected_next(HANDSHAKE_CCS); } else // new session @@ -481,7 +478,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, choose_ciphersuite(m_policy, state.version(), m_creds, cert_chains, state.client_hello()), choose_compression(m_policy, state.client_hello()->compression_methods()), have_session_ticket_key, - m_possible_protocols) + m_next_protocol) ); secure_renegotiation_check(state.server_hello()); @@ -494,11 +491,9 @@ void Server::process_handshake_msg(const Handshake_State* active_state, BOTAN_ASSERT(!cert_chains[sig_algo].empty(), "Attempting to send empty certificate chain"); - state.server_certs( - new Certificate(state.handshake_io(), - state.hash(), - cert_chains[sig_algo]) - ); + state.server_certs(new Certificate(state.handshake_io(), + state.hash(), + cert_chains[sig_algo])); } Private_Key* private_key = nullptr; @@ -538,12 +533,8 @@ void Server::process_handshake_msg(const Handshake_State* active_state, if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "") { state.cert_req( - new Certificate_Req(state.handshake_io(), - state.hash(), - m_policy, - client_auth_CAs, - state.version()) - ); + new Certificate_Req(state.handshake_io(), state.hash(), + m_policy, client_auth_CAs, state.version())); state.set_expected_next(CERTIFICATE); } @@ -555,9 +546,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, */ state.set_expected_next(CLIENT_KEX); - state.server_hello_done( - new Server_Hello_Done(state.handshake_io(), state.hash()) - ); + state.server_hello_done(new Server_Hello_Done(state.handshake_io(), state.hash())); } } else if(type == CERTIFICATE) @@ -614,21 +603,8 @@ void Server::process_handshake_msg(const Handshake_State* active_state, } else if(type == HANDSHAKE_CCS) { - if(state.server_hello()->next_protocol_notification()) - state.set_expected_next(NEXT_PROTOCOL); - else - state.set_expected_next(FINISHED); - - change_cipher_spec_reader(SERVER); - } - else if(type == NEXT_PROTOCOL) - { state.set_expected_next(FINISHED); - - state.next_protocol(new Next_Protocol(contents)); - - // should this be a callback? - m_next_protocol = state.next_protocol()->protocol(); + change_cipher_spec_reader(SERVER); } else if(type == FINISHED) { @@ -694,9 +670,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, change_cipher_spec_writer(SERVER); - state.server_finished( - new Finished(state.handshake_io(), state, SERVER) - ); + state.server_finished(new Finished(state.handshake_io(), state, SERVER)); } activate_session(); diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index 4b15e837b..4f2a11ba4 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -22,6 +22,8 @@ namespace TLS { class BOTAN_DLL Server : public Channel { public: + typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn; + /** * Server initialization */ @@ -33,7 +35,7 @@ class BOTAN_DLL Server : public Channel Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::vector<std::string>& protocols = std::vector<std::string>(), + next_protocol_fn next_proto = next_protocol_fn(), bool is_datagram = false, size_t reserved_io_buffer_size = 16*1024 ); @@ -63,7 +65,7 @@ class BOTAN_DLL Server : public Channel const Policy& m_policy; Credentials_Manager& m_creds; - std::vector<std::string> m_possible_protocols; + next_protocol_fn m_choose_next_protocol; std::string m_next_protocol; }; diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index aeab7cd4e..8e21cc484 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -176,6 +176,15 @@ size_t basic_test_handshake(RandomNumberGenerator& rng, s2c_data.insert(s2c_data.end(), buf, buf+sz); }; + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size\n"; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values\n"; + return "test/3"; + }; + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; + TLS::Server server([&](const byte buf[], size_t sz) { s2c_q.insert(s2c_q.end(), buf, buf+sz); }, save_server_data, @@ -185,15 +194,7 @@ size_t basic_test_handshake(RandomNumberGenerator& rng, creds, policy, rng, - { "test/1", "test/2" }); - - auto next_protocol_chooser = [&](std::vector<std::string> protos) { - if(protos.size() != 2) - std::cout << "Bad protocol size\n"; - if(protos[0] != "test/1" || protos[1] != "test/2") - std::cout << "Bad protocol values\n"; - return "test/3"; - }; + next_protocol_chooser); TLS::Client client([&](const byte buf[], size_t sz) { c2s_q.insert(c2s_q.end(), buf, buf+sz); }, @@ -206,7 +207,7 @@ size_t basic_test_handshake(RandomNumberGenerator& rng, rng, TLS::Server_Information(), offer_version, - next_protocol_chooser); + protocols_offered); while(true) { |