diff options
author | lloyd <[email protected]> | 2012-10-13 20:55:17 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-10-13 20:55:17 +0000 |
commit | 72eb425d699e0571857432e4271d10afb6431a6e (patch) | |
tree | cb8d3174f4844c84f8295a63ea0d44aaefa9b31c /src | |
parent | 8bd5519105f6978e7d937294d2a2e8deadda20ca (diff) | |
parent | 4be75ae1e9e473fc3e939be5e54e51f552d5934b (diff) |
merge of '415e0ca58c566cb2990758c1261d47d6b09fc76c'
and 'e616da4002c659a5f5f6c16aecaafef7c37a5f96'
Diffstat (limited to 'src')
-rw-r--r-- | src/block/aes_ni/aes_ni.cpp | 44 | ||||
-rw-r--r-- | src/tls/info.txt | 1 | ||||
-rw-r--r-- | src/tls/msg_client_hello.cpp | 10 | ||||
-rw-r--r-- | src/tls/msg_server_hello.cpp | 7 | ||||
-rw-r--r-- | src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp | 18 | ||||
-rw-r--r-- | src/tls/sessions_sqlite/tls_session_manager_sqlite.h | 6 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 32 | ||||
-rw-r--r-- | src/tls/tls_client.h | 14 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 1 | ||||
-rw-r--r-- | src/tls/tls_policy.cpp | 5 | ||||
-rw-r--r-- | src/tls/tls_policy.h | 10 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 19 | ||||
-rw-r--r-- | src/tls/tls_server.h | 13 | ||||
-rw-r--r-- | src/tls/tls_server_info.h | 91 | ||||
-rw-r--r-- | src/tls/tls_session.cpp | 26 | ||||
-rw-r--r-- | src/tls/tls_session.h | 14 | ||||
-rw-r--r-- | src/tls/tls_session_manager.cpp | 31 | ||||
-rw-r--r-- | src/tls/tls_session_manager.h | 25 | ||||
-rw-r--r-- | src/tls/tls_version.h | 10 | ||||
-rw-r--r-- | src/utils/assert.h | 6 |
20 files changed, 244 insertions, 139 deletions
diff --git a/src/block/aes_ni/aes_ni.cpp b/src/block/aes_ni/aes_ni.cpp index 3ee0e608c..4dca6c7f2 100644 --- a/src/block/aes_ni/aes_ni.cpp +++ b/src/block/aes_ni/aes_ni.cpp @@ -1,6 +1,6 @@ /* * AES using AES-NI instructions -* (C) 2009 Jack Lloyd +* (C) 2009,2012 Jack Lloyd * * Distributed under the terms of the Botan license */ @@ -485,10 +485,10 @@ void AES_192_NI::key_schedule(const byte key[], size_t) load_le(&EK[0], key, 6); -#define AES_192_key_exp(RCON, EK_OFF) \ - aes_192_key_expansion(&K0, &K1, \ - _mm_aeskeygenassist_si128(K1, RCON), \ - EK + EK_OFF, EK_OFF == 48) + #define AES_192_key_exp(RCON, EK_OFF) \ + aes_192_key_expansion(&K0, &K1, \ + _mm_aeskeygenassist_si128(K1, RCON), \ + &EK[EK_OFF], EK_OFF == 48) AES_192_key_exp(0x01, 6); AES_192_key_exp(0x02, 12); @@ -499,22 +499,25 @@ void AES_192_NI::key_schedule(const byte key[], size_t) AES_192_key_exp(0x40, 42); AES_192_key_exp(0x80, 48); + #undef AES_192_key_exp + // Now generate decryption keys const __m128i* EK_mm = (const __m128i*)&EK[0]; + __m128i* DK_mm = (__m128i*)&DK[0]; - _mm_storeu_si128(DK_mm , EK_mm[12]); - _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(EK_mm[11])); - _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(EK_mm[10])); - _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(EK_mm[9])); - _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(EK_mm[8])); - _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(EK_mm[7])); - _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(EK_mm[6])); - _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(EK_mm[5])); - _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(EK_mm[4])); - _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(EK_mm[3])); - _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(EK_mm[2])); - _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(EK_mm[1])); - _mm_storeu_si128(DK_mm + 12, EK_mm[0]); + _mm_storeu_si128(DK_mm , _mm_loadu_si128(EK_mm + 12)); + _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 11))); + _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 10))); + _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 9))); + _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 8))); + _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 7))); + _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 6))); + _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 5))); + _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 4))); + _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 3))); + _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 2))); + _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 1))); + _mm_storeu_si128(DK_mm + 12, _mm_loadu_si128(EK_mm + 0)); } /* @@ -776,4 +779,9 @@ void AES_256_NI::clear() zeroise(DK); } +#undef AES_ENC_4_ROUNDS +#undef AES_ENC_4_LAST_ROUNDS +#undef AES_DEC_4_ROUNDS +#undef AES_DEC_4_LAST_ROUNDS + } diff --git a/src/tls/info.txt b/src/tls/info.txt index e61b2c0da..47de42598 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -15,6 +15,7 @@ tls_client.h tls_exceptn.h tls_handshake_msg.h tls_magic.h +tls_server_info.h tls_policy.h tls_server.h tls_session.h diff --git a/src/tls/msg_client_hello.cpp b/src/tls/msg_client_hello.cpp index 2149ac5e5..6176ca6bf 100644 --- a/src/tls/msg_client_hello.cpp +++ b/src/tls/msg_client_hello.cpp @@ -74,13 +74,15 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_suites(ciphersuite_list(policy, m_version, (srp_identifier != ""))), m_comp_methods(policy.compression()) { - m_extensions.add(new Heartbeat_Support_Indicator(true)); m_extensions.add(new Renegotiation_Extension(reneg_info)); m_extensions.add(new SRP_Identifier(srp_identifier)); m_extensions.add(new Server_Name_Indicator(hostname)); m_extensions.add(new Session_Ticket()); m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves())); + if(policy.negotiate_heartbeat_support()) + m_extensions.add(new Heartbeat_Support_Indicator(true)); + if(m_version.supports_negotiable_signature_algorithms()) m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(), policy.allowed_signature_methods())); @@ -113,13 +115,15 @@ Client_Hello::Client_Hello(Handshake_IO& io, if(!value_exists(m_comp_methods, session.compression_method())) m_comp_methods.push_back(session.compression_method()); - m_extensions.add(new Heartbeat_Support_Indicator(true)); m_extensions.add(new Renegotiation_Extension(reneg_info)); m_extensions.add(new SRP_Identifier(session.srp_identifier())); - m_extensions.add(new Server_Name_Indicator(session.sni_hostname())); + m_extensions.add(new Server_Name_Indicator(session.server_info().hostname())); m_extensions.add(new Session_Ticket(session.session_ticket())); m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves())); + if(policy.negotiate_heartbeat_support()) + m_extensions.add(new Heartbeat_Support_Indicator(true)); + if(session.fragment_size() != 0) m_extensions.add(new Maximum_Fragment_Length(session.fragment_size())); diff --git a/src/tls/msg_server_hello.cpp b/src/tls/msg_server_hello.cpp index 6ca5e3b30..a775e0b4b 100644 --- a/src/tls/msg_server_hello.cpp +++ b/src/tls/msg_server_hello.cpp @@ -21,6 +21,7 @@ namespace TLS { */ Server_Hello::Server_Hello(Handshake_IO& io, Handshake_Hash& hash, + const Policy& policy, const std::vector<byte>& session_id, Protocol_Version ver, u16bit ciphersuite, @@ -39,9 +40,13 @@ Server_Hello::Server_Hello(Handshake_IO& io, m_ciphersuite(ciphersuite), m_comp_method(compression) { - if(client_has_heartbeat) + if(client_has_heartbeat && policy.negotiate_heartbeat_support()) m_extensions.add(new Heartbeat_Support_Indicator(true)); + /* + * Even a client that offered SSLv3 and sent the SCSV will get an + * extension back. This is probably the right thing to do. + */ if(client_has_secure_renegotiation) m_extensions.add(new Renegotiation_Extension(reneg_info)); diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp index d10366c60..87556ff75 100644 --- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp +++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp @@ -142,16 +142,15 @@ bool Session_Manager_SQLite::load_from_session_id(const std::vector<byte>& sessi return false; } -bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname, - u16bit port, - Session& session) +bool Session_Manager_SQLite::load_from_server_info(const Server_Information& server, + Session& session) { sqlite3_statement stmt(m_db, "select session from tls_sessions" " where hostname = ?1 and hostport = ?2" " order by session_start desc"); - stmt.bind(1, hostname); - stmt.bind(2, port); + stmt.bind(1, server.hostname()); + stmt.bind(2, server.port()); while(stmt.step()) { @@ -167,9 +166,6 @@ bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname, } } - if(port != 0) - return load_from_host_info(hostname, 0, session); - return false; } @@ -182,15 +178,15 @@ void Session_Manager_SQLite::remove_entry(const std::vector<byte>& session_id) stmt.spin(); } -void Session_Manager_SQLite::save(const Session& session, u16bit port) +void Session_Manager_SQLite::save(const Session& session) { sqlite3_statement stmt(m_db, "insert or replace into tls_sessions" " values(?1, ?2, ?3, ?4, ?5)"); stmt.bind(1, hex_encode(session.session_id())); stmt.bind(2, session.start_time()); - stmt.bind(3, session.sni_hostname()); - stmt.bind(4, port); + stmt.bind(3, session.server_info().hostname()); + stmt.bind(4, session.server_info().port()); stmt.bind(5, session.encrypt(m_session_key, m_rng)); stmt.spin(); diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h index db74f54b7..7892ccd6a 100644 --- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h +++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h @@ -50,12 +50,12 @@ class BOTAN_DLL Session_Manager_SQLite : public Session_Manager bool load_from_session_id(const std::vector<byte>& session_id, Session& session) override; - bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) override; + bool load_from_server_info(const Server_Information& info, + Session& session) override; void remove_entry(const std::vector<byte>& session_id) override; - void save(const Session& session_data, u16bit port) override; + void save(const Session& session_data) override; std::chrono::seconds session_lifetime() const override { return m_session_lifetime; } diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index bb6b7a45f..0e1d84bed 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -55,20 +55,18 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::string& hostname, - u16bit port, + const Server_Information& info, + const Protocol_Version offer_version, std::function<std::string (std::vector<std::string>)> next_protocol) : Channel(output_fn, proc_fn, handshake_fn, session_manager, rng), m_policy(policy), m_creds(creds), - m_hostname(hostname), - m_port(port) + m_info(info) { - const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname); + const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); - const Protocol_Version version = m_policy.pref_version(); - Handshake_State& state = create_handshake_state(version); - send_client_hello(state, false, version, srp_identifier, next_protocol); + Handshake_State& state = create_handshake_state(offer_version); + send_client_hello(state, false, offer_version, srp_identifier, next_protocol); } Handshake_State* Client::new_handshake_state(Handshake_IO* io) @@ -111,10 +109,10 @@ void Client::send_client_hello(Handshake_State& state_base, const bool send_npn_request = static_cast<bool>(next_protocol); - if(!force_full_renegotiation && m_hostname != "") + if(!force_full_renegotiation && !m_info.empty()) { Session session_info; - if(session_manager().load_from_host_info(m_hostname, m_port, session_info)) + if(session_manager().load_from_server_info(m_info, session_info)) { if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier) { @@ -142,7 +140,7 @@ void Client::send_client_hello(Handshake_State& state_base, rng(), secure_renegotiation_data_for_client_hello(), send_npn_request, - m_hostname, + m_info.hostname(), srp_identifier)); } @@ -321,7 +319,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, try { - m_creds.verify_certificate_chain("tls-client", m_hostname, server_certs); + m_creds.verify_certificate_chain("tls-client", m_info.hostname(), server_certs); } catch(std::exception& e) { @@ -380,7 +378,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, std::vector<X509_Certificate> client_certs = m_creds.cert_chain(types, "tls-client", - m_hostname); + m_info.hostname()); state.client_certs( new Certificate(state.handshake_io(), @@ -395,7 +393,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, m_policy, m_creds, state.server_public_key.get(), - m_hostname, + m_info.hostname(), rng()) ); @@ -407,7 +405,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, Private_Key* private_key = m_creds.private_key_for(state.client_certs()->cert_chain()[0], "tls-client", - m_hostname); + m_info.hostname()); state.client_verify( new Certificate_Verify(state.handshake_io(), @@ -501,7 +499,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.server_hello()->fragment_size(), get_peer_cert_chain(state), session_ticket, - m_hostname, + m_info, "" ); @@ -510,7 +508,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(!session_id.empty()) { if(should_save) - session_manager().save(session_info, m_port); + session_manager().save(session_info); else session_manager().remove_entry(session_info.session_id()); } diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index d7f17f878..b40896e94 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -39,11 +39,10 @@ class BOTAN_DLL Client : public Channel * * @param rng a random number generator * - * @param servername the server's DNS name, if known + * @param server_info is identifying information about the TLS server * - * @param port specifies the protocol port of the server (eg for - * TCP/UDP). Only used if servername is also specified. - * Use 0 if unknown. + * @param offer_version specifies which version we will offer + * to the TLS server. * * @param next_protocol allows the client to specify what the next * protocol will be. For more information read @@ -61,8 +60,8 @@ class BOTAN_DLL Client : public Channel Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::string& servername = "", - u16bit port = 0, + const Server_Information& server_info = Server_Information(), + const Protocol_Version offer_version = Protocol_Version::latest_tls_version(), std::function<std::string (std::vector<std::string>)> next_protocol = std::function<std::string (std::vector<std::string>)>()); private: @@ -88,8 +87,7 @@ class BOTAN_DLL Client : public Channel const Policy& m_policy; Credentials_Manager& m_creds; - const std::string m_hostname; - const u16bit m_port; + const Server_Information m_info; }; } diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 70745ad9c..f1d4aa887 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -254,6 +254,7 @@ class Server_Hello : public Handshake_Message Server_Hello(Handshake_IO& io, Handshake_Hash& hash, + const Policy& policy, const std::vector<byte>& session_id, Protocol_Version ver, u16bit ciphersuite, diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp index b26bd4225..c76fe30a5 100644 --- a/src/tls/tls_policy.cpp +++ b/src/tls/tls_policy.cpp @@ -136,11 +136,6 @@ bool Policy::acceptable_protocol_version(Protocol_Version version) const version == Protocol_Version::TLS_V12); } -Protocol_Version Policy::pref_version() const - { - return Protocol_Version::TLS_V12; - } - namespace { class Ciphersuite_Preference_Ordering diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h index 8b73fea9d..cc02dd9b1 100644 --- a/src/tls/tls_policy.h +++ b/src/tls/tls_policy.h @@ -74,6 +74,11 @@ class BOTAN_DLL Policy virtual std::string choose_curve(const std::vector<std::string>& curve_names) const; /** + * Attempt to negotiate the use of the heartbeat extension + */ + virtual bool negotiate_heartbeat_support() const { return false; } + + /** * Allow renegotiation even if the counterparty doesn't * support the secure renegotiation extension. * @@ -119,11 +124,6 @@ class BOTAN_DLL Policy */ virtual bool acceptable_protocol_version(Protocol_Version version) const; - /** - * @return the version we would prefer to negotiate - */ - virtual Protocol_Version pref_version() const; - virtual ~Policy() {} }; diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 1a29d317c..1189019bc 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -96,7 +96,7 @@ bool check_for_resume(Session& session_info, // client sent a different SNI hostname if(client_hello->sni_hostname() != "") { - if(client_hello->sni_hostname() != session_info.sni_hostname()) + if(client_hello->sni_hostname() != session_info.server_info().hostname()) return false; } @@ -288,9 +288,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_hello(new Client_Hello(contents, type)); - if(state.client_hello()->sni_hostname() != "") - m_hostname = state.client_hello()->sni_hostname(); - Protocol_Version client_version = state.client_hello()->version(); Protocol_Version negotiated_version; @@ -380,6 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new Server_Hello( state.handshake_io(), state.hash(), + m_policy, state.client_hello()->session_id(), Protocol_Version(session_info.version()), session_info.ciphersuite_code(), @@ -451,9 +449,11 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { std::map<std::string, std::vector<X509_Certificate> > cert_chains; - cert_chains = get_server_certs(m_hostname, m_creds); + const std::string sni_hostname = state.client_hello()->sni_hostname(); + + cert_chains = get_server_certs(sni_hostname, m_creds); - if(m_hostname != "" && cert_chains.empty()) + if(sni_hostname != "" && cert_chains.empty()) { cert_chains = get_server_certs("", m_creds); @@ -472,6 +472,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, new Server_Hello( state.handshake_io(), state.hash(), + m_policy, make_hello_random(rng()), // new session ID state.version(), choose_ciphersuite(m_policy, @@ -517,7 +518,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, private_key = m_creds.private_key_for( state.server_certs()->cert_chain()[0], "tls-server", - m_hostname); + sni_hostname); if(!private_key) throw Internal_Error("No private key located for associated server cert"); @@ -540,7 +541,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, } std::vector<X509_Certificate> client_auth_CAs = - m_creds.trusted_certificate_authorities("tls-server", m_hostname); + m_creds.trusted_certificate_authorities("tls-server", sni_hostname); if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "") { @@ -663,7 +664,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello()->fragment_size(), get_peer_cert_chain(state), std::vector<byte>(), - m_hostname, + Server_Information(state.client_hello()->sni_hostname()), state.srp_identifier() ); diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 761ff6028..25c5b6506 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -36,16 +36,10 @@ class BOTAN_DLL Server : public Channel std::vector<std::string>()); /** - * Return the server name indicator, if sent by the client + * Return the protocol notification set by the client (using the + * NPN extension) for this connection, if any */ - std::string server_name_indicator() const - { return m_hostname; } - - /** - * Return the protocol negotiated with NPN extension - */ - std::string next_protocol() const - { return m_next_protocol; } + std::string next_protocol() const { return m_next_protocol; } private: std::vector<X509_Certificate> @@ -65,7 +59,6 @@ class BOTAN_DLL Server : public Channel Credentials_Manager& m_creds; std::vector<std::string> m_possible_protocols; - std::string m_hostname; std::string m_next_protocol; }; diff --git a/src/tls/tls_server_info.h b/src/tls/tls_server_info.h new file mode 100644 index 000000000..773296eaf --- /dev/null +++ b/src/tls/tls_server_info.h @@ -0,0 +1,91 @@ +/* +* TLS Server Information +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_SERVER_INFO_H__ +#define BOTAN_TLS_SERVER_INFO_H__ + +#include <botan/types.h> +#include <string> + +namespace Botan { + +namespace TLS { + +/** +* Represents information known about a TLS server. +*/ +class BOTAN_DLL Server_Information + { + public: + /** + * An empty server info - nothing known + */ + Server_Information() : m_hostname(""), m_service(""), m_port(0) {} + + /** + * @param hostname the host's DNS name, if known + * @param port specifies the protocol port of the server (eg for + * TCP/UDP). Zero represents unknown. + */ + Server_Information(const std::string& hostname, + u16bit port = 0) : + m_hostname(hostname), m_service(""), m_port(port) {} + + /** + * @param hostname the host's DNS name, if known + * @param service is a text string of the service type + * (eg "https", "tor", or "git") + * @param port specifies the protocol port of the server (eg for + * TCP/UDP). Zero represents unknown. + */ + Server_Information(const std::string& hostname, + const std::string& service, + u16bit port = 0) : + m_hostname(hostname), m_service(service), m_port(port) {} + + std::string hostname() const { return m_hostname; } + + std::string service() const { return m_service; } + + u16bit port() const { return m_port; } + + bool empty() const { return m_hostname.empty(); } + + private: + std::string m_hostname, m_service; + u16bit m_port; + }; + +inline bool operator==(const Server_Information& a, const Server_Information& b) + { + return (a.hostname() == b.hostname()) && + (a.service() == b.service()) && + (a.port() == b.port()); + + } + +inline bool operator!=(const Server_Information& a, const Server_Information& b) + { + return !(a == b); + } + +inline bool operator<(const Server_Information& a, const Server_Information& b) + { + if(a.hostname() != b.hostname()) + return (a.hostname() < b.hostname()); + if(a.service() != b.service()) + return (a.service() < b.service()); + if(a.port() != b.port()) + return (a.port() < b.port()); + return false; // equal + } + +} + +} + +#endif diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp index ae57de0c2..85cb6d69e 100644 --- a/src/tls/tls_session.cpp +++ b/src/tls/tls_session.cpp @@ -27,7 +27,7 @@ Session::Session(const std::vector<byte>& session_identifier, size_t fragment_size, const std::vector<X509_Certificate>& certs, const std::vector<byte>& ticket, - const std::string& sni_hostname, + const Server_Information& server_info, const std::string& srp_identifier) : m_start_time(std::chrono::system_clock::now()), m_identifier(session_identifier), @@ -39,7 +39,7 @@ Session::Session(const std::vector<byte>& session_identifier, m_connection_side(side), m_fragment_size(fragment_size), m_peer_certs(certs), - m_sni_hostname(sni_hostname), + m_server_info(server_info), m_srp_identifier(srp_identifier) { } @@ -54,7 +54,11 @@ Session::Session(const std::string& pem) Session::Session(const byte ber[], size_t ber_len) { byte side_code = 0; - ASN1_String sni_hostname_str; + + ASN1_String server_hostname; + ASN1_String server_service; + size_t server_port; + ASN1_String srp_identifier_str; byte major_version = 0, minor_version = 0; @@ -78,17 +82,23 @@ Session::Session(const byte ber[], size_t ber_len) .decode_integer_type(m_fragment_size) .decode(m_master_secret, OCTET_STRING) .decode(peer_cert_bits, OCTET_STRING) - .decode(sni_hostname_str) + .decode(server_hostname) + .decode(server_service) + .decode(server_port) .decode(srp_identifier_str) .end_cons() .verify_end(); m_version = Protocol_Version(major_version, minor_version); m_start_time = std::chrono::system_clock::from_time_t(start_time); - m_sni_hostname = sni_hostname_str.value(); - m_srp_identifier = srp_identifier_str.value(); m_connection_side = static_cast<Connection_Side>(side_code); + m_server_info = Server_Information(server_hostname.value(), + server_service.value(), + server_port); + + m_srp_identifier = srp_identifier_str.value(); + if(!peer_cert_bits.empty()) { DataSource_Memory certs(&peer_cert_bits[0], peer_cert_bits.size()); @@ -118,7 +128,9 @@ secure_vector<byte> Session::DER_encode() const .encode(static_cast<size_t>(m_fragment_size)) .encode(m_master_secret, OCTET_STRING) .encode(peer_cert_bits, OCTET_STRING) - .encode(ASN1_String(m_sni_hostname, UTF8_STRING)) + .encode(ASN1_String(m_server_info.hostname(), UTF8_STRING)) + .encode(ASN1_String(m_server_info.service(), UTF8_STRING)) + .encode(static_cast<size_t>(m_server_info.port())) .encode(ASN1_String(m_srp_identifier, UTF8_STRING)) .end_cons() .get_contents(); diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h index 206a75081..65154dfce 100644 --- a/src/tls/tls_session.h +++ b/src/tls/tls_session.h @@ -12,6 +12,7 @@ #include <botan/tls_version.h> #include <botan/tls_ciphersuite.h> #include <botan/tls_magic.h> +#include <botan/tls_server_info.h> #include <botan/secmem.h> #include <botan/symkey.h> #include <chrono> @@ -51,8 +52,8 @@ class BOTAN_DLL Session size_t fragment_size, const std::vector<X509_Certificate>& peer_certs, const std::vector<byte>& session_ticket, - const std::string& sni_hostname = "", - const std::string& srp_identifier = ""); + const Server_Information& server_info, + const std::string& srp_identifier); /** * Load a session from DER representation (created by DER_encode) @@ -133,11 +134,6 @@ class BOTAN_DLL Session Connection_Side side() const { return m_connection_side; } /** - * Get the SNI hostname (if sent by the client in the initial handshake) - */ - std::string sni_hostname() const { return m_sni_hostname; } - - /** * Get the SRP identity (if sent by the client in the initial handshake) */ std::string srp_identifier() const { return m_srp_identifier; } @@ -180,6 +176,8 @@ class BOTAN_DLL Session */ const std::vector<byte>& session_ticket() const { return m_session_ticket; } + Server_Information server_info() const { return m_server_info; } + private: enum { TLS_SESSION_PARAM_STRUCT_VERSION = 0x2994e301 }; @@ -197,7 +195,7 @@ class BOTAN_DLL Session size_t m_fragment_size; std::vector<X509_Certificate> m_peer_certs; - std::string m_sni_hostname; // optional + Server_Information m_server_info; // optional std::string m_srp_identifier; // optional }; diff --git a/src/tls/tls_session_manager.cpp b/src/tls/tls_session_manager.cpp index 673ee90ff..ca18231a0 100644 --- a/src/tls/tls_session_manager.cpp +++ b/src/tls/tls_session_manager.cpp @@ -61,27 +61,24 @@ bool Session_Manager_In_Memory::load_from_session_id( return load_from_session_str(hex_encode(session_id), session); } -bool Session_Manager_In_Memory::load_from_host_info( - const std::string& hostname, u16bit port, Session& session) +bool Session_Manager_In_Memory::load_from_server_info( + const Server_Information& info, Session& session) { std::lock_guard<std::mutex> lock(m_mutex); - auto i = m_host_sessions.find(hostname + ":" + std::to_string(port)); + auto i = m_info_sessions.find(info); - if(i == m_host_sessions.end()) - { - if(port > 0) - i = m_host_sessions.find(hostname + ":" + std::to_string(0)); - - if(i == m_host_sessions.end()) - return false; - } + if(i == m_info_sessions.end()) + return false; if(load_from_session_str(i->second, session)) return true; - // was removed from sessions map, remove m_host_sessions entry - m_host_sessions.erase(i); + /* + * It existed at one point but was removed from the sessions map, + * remove m_info_sessions entry as well + */ + m_info_sessions.erase(i); return false; } @@ -97,7 +94,7 @@ void Session_Manager_In_Memory::remove_entry( m_sessions.erase(i); } -void Session_Manager_In_Memory::save(const Session& session, u16bit port) +void Session_Manager_In_Memory::save(const Session& session) { std::lock_guard<std::mutex> lock(m_mutex); @@ -115,10 +112,8 @@ void Session_Manager_In_Memory::save(const Session& session, u16bit port) m_sessions[session_id_str] = session.encrypt(m_session_key, m_rng); - const std::string hostname = session.sni_hostname(); - - if(session.side() == CLIENT && hostname != "") - m_host_sessions[hostname + ":" + std::to_string(port)] = session_id_str; + if(session.side() == CLIENT && !session.server_info().empty()) + m_info_sessions[session.server_info()] = session_id_str; } } diff --git a/src/tls/tls_session_manager.h b/src/tls/tls_session_manager.h index 4efefb6ff..d7c805195 100644 --- a/src/tls/tls_session_manager.h +++ b/src/tls/tls_session_manager.h @@ -30,7 +30,7 @@ class BOTAN_DLL Session_Manager { public: /** - * Try to load a saved session (server side) + * Try to load a saved session (using session ID) * @param session_id the session identifier we are trying to resume * @param session will be set to the saved session data (if found), or not modified if not found @@ -40,15 +40,14 @@ class BOTAN_DLL Session_Manager Session& session) = 0; /** - * Try to load a saved session (client side) - * @param hostname of the host we are connecting to - * @param port the port number if we know it, or 0 if unknown + * Try to load a saved session (using info about server) + * @param info the information about the server * @param session will be set to the saved session data (if found), or not modified if not found * @return true if session was modified */ - virtual bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) = 0; + virtual bool load_from_server_info(const Server_Information& info, + Session& session) = 0; /** * Remove this session id from the cache, if it exists @@ -64,7 +63,7 @@ class BOTAN_DLL Session_Manager * @param session to save * @param port the protocol port (if known) */ - virtual void save(const Session& session, u16bit port = 0) = 0; + virtual void save(const Session& session) = 0; /** * Return the allowed lifetime of a session; beyond this time, @@ -86,12 +85,12 @@ class BOTAN_DLL Session_Manager_Noop : public Session_Manager bool load_from_session_id(const std::vector<byte>&, Session&) override { return false; } - bool load_from_host_info(const std::string&, u16bit, Session&) override + bool load_from_server_info(const Server_Information&, Session&) override { return false; } void remove_entry(const std::vector<byte>&) override {} - void save(const Session&, u16bit) override {} + void save(const Session&) override {} std::chrono::seconds session_lifetime() const override { return std::chrono::seconds(0); } @@ -116,12 +115,12 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager bool load_from_session_id(const std::vector<byte>& session_id, Session& session) override; - bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) override; + bool load_from_server_info(const Server_Information& info, + Session& session) override; void remove_entry(const std::vector<byte>& session_id) override; - void save(const Session& session_data, u16bit port) override; + void save(const Session& session_data) override; std::chrono::seconds session_lifetime() const override { return m_session_lifetime; } @@ -140,7 +139,7 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager SymmetricKey m_session_key; std::map<std::string, std::vector<byte>> m_sessions; // hex(session_id) -> session - std::map<std::string, std::string> m_host_sessions; + std::map<Server_Information, std::string> m_info_sessions; }; } diff --git a/src/tls/tls_version.h b/src/tls/tls_version.h index 651eebafc..39712db27 100644 --- a/src/tls/tls_version.h +++ b/src/tls/tls_version.h @@ -31,6 +31,16 @@ class BOTAN_DLL Protocol_Version DTLS_V12 = 0xFEFD }; + static Protocol_Version latest_tls_version() + { + return Protocol_Version(TLS_V12); + } + + static Protocol_Version latest_dtls_version() + { + return Protocol_Version(DTLS_V12); + } + Protocol_Version() : m_version(0) {} /** diff --git a/src/utils/assert.h b/src/utils/assert.h index 88d514b43..d92b41111 100644 --- a/src/utils/assert.h +++ b/src/utils/assert.h @@ -35,10 +35,10 @@ void assertion_failure(const char* expr_str, /** * Assert that value1 == value2 */ -#define BOTAN_ASSERT_EQUAL(value1, value2, assertion_made) \ +#define BOTAN_ASSERT_EQUAL(expr1, expr2, assertion_made) \ do { \ - if(value1 != value2) \ - Botan::assertion_failure(#value1 " == " #value2, \ + if((expr1) != (expr2)) \ + Botan::assertion_failure(#expr1 " == " #expr2, \ assertion_made, \ __func__, \ __FILE__, \ |