diff options
-rw-r--r-- | doc/examples/tls_client.cpp | 3 | ||||
-rw-r--r-- | doc/tls.rst | 41 | ||||
-rw-r--r-- | src/tls/info.txt | 1 | ||||
-rw-r--r-- | src/tls/msg_client_hello.cpp | 2 | ||||
-rw-r--r-- | src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp | 18 | ||||
-rw-r--r-- | src/tls/sessions_sqlite/tls_session_manager_sqlite.h | 6 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 26 | ||||
-rw-r--r-- | src/tls/tls_client.h | 12 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 17 | ||||
-rw-r--r-- | src/tls/tls_server.h | 13 | ||||
-rw-r--r-- | src/tls/tls_server_info.h | 91 | ||||
-rw-r--r-- | src/tls/tls_session.cpp | 26 | ||||
-rw-r--r-- | src/tls/tls_session.h | 14 | ||||
-rw-r--r-- | src/tls/tls_session_manager.cpp | 31 | ||||
-rw-r--r-- | src/tls/tls_session_manager.h | 25 |
15 files changed, 202 insertions, 124 deletions
diff --git a/doc/examples/tls_client.cpp b/doc/examples/tls_client.cpp index d56143a36..a9efe21e1 100644 --- a/doc/examples/tls_client.cpp +++ b/doc/examples/tls_client.cpp @@ -168,8 +168,7 @@ int main(int argc, char* argv[]) creds, policy, rng, - host, - port, + TLS::Server_Information(host, port), protocol_chooser); while(!client.is_closed()) diff --git a/doc/tls.rst b/doc/tls.rst index 2f560b72f..18af678b9 100644 --- a/doc/tls.rst +++ b/doc/tls.rst @@ -165,7 +165,7 @@ TLS Clients Credentials_Manager& credendials_manager, \ const TLS::Policy& policy, \ RandomNumberGenerator& rng, \ - const std::string& servername = "", \ + const Server_Information& server_info = Server_Information(), \ std::function<std::string, std::vector<std::string> > next_protocol) Initialize a new TLS client. The constructor will immediately @@ -206,7 +206,7 @@ TLS Clients retrieve any certificates, secret keys, pre-shared keys, or SRP intformation; see :doc:`credentials_manager` for more information. - Use *servername* to specify the DNS name of the server you are + Use *server_info* to specify the DNS name of the server you are attempting to connect to, if you know it. This helps the server select what certificate to use and helps the client validate the connection. @@ -240,6 +240,16 @@ The first 7 arguments are treated similiarly to the :ref:`client <tls_client>`. The final (optional) argument, protocols, specifies the protocols the server is willing to advertise it supports. +.. cpp:class:: std::string TLS::Server::next_protocol() const + + If a handshake has completed, and if the client indicated a next + protocol (ie, the protocol that it intends to run over this TLS + session) this return value will specify it. The next protocol + extension is somewhat unusual in that it applies to the connection + rather than the session. The next protocol can not change during a + renegotiation, but might change across different connections using + that session. + A TLS server that can handle concurrent connections using asio: .. literalinclude:: examples/asio_tls_server.cpp @@ -270,9 +280,13 @@ information about that session: Returns the :cpp:class:`ciphersuite <TLS::Ciphersuite>` that was negotiated. - .. cpp:function:: std::string sni_hostname() const + .. cpp:function:: Server_Information server_info() const - Returns the hostname the client indicated in the hello message. + Returns information that identifies the server side of the + connection. This is useful for the client in that it + identifies what was originally passed to the constructor. For + the server, it includes the name the client specified in the + server name indicator extension. .. cpp:function:: std::vector<X509_Certificate> peer_certs() const @@ -331,17 +345,12 @@ implementation to the ``TLS::Client`` or ``TLS::Server`` constructor. .. cpp:class:: TLS::Session_Mananger - .. cpp:function:: void save(const Session& session, u16bit port) + .. cpp:function:: void save(const Session& session) Save a new *session*. It is possible that this sessions session ID will replicate a session ID already stored, in which case the new session information should overwrite the previous information. - Clients will specify *port* if they know it (it will be zero if - they do not, or for servers). It specifies the remote port of the - server which is used to assist with looking up the correct - session when using :cpp:func:`load_from_host_info`. - .. cpp:function:: void remove_entry(const std::vector<byte>& session_id) Remove the session identified by *session_id*. Future attempts @@ -355,16 +364,10 @@ implementation to the ``TLS::Client`` or ``TLS::Server`` constructor. to *save*, and ``true`` is returned. Otherwise *session* is not modified and ``false`` is returned. - .. cpp:function:: bool load_from_host_info(const std::string& hostname, \ - u16bit port, \ - Session& session) - - Attempt to resume a session for *hostname* / *port*. + .. cpp:function:: bool load_from_server_info(const Server_Information& server, \ + Session& session) - The session managers included in the library will, if they fail - to find an exact match for *hostname* and *port*, will also - check for a session saved using a matching hostname and a port - of zero. + Attempt to resume a session with a known server. .. cpp:function:: std::chrono::seconds session_lifetime() const diff --git a/src/tls/info.txt b/src/tls/info.txt index e61b2c0da..47de42598 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -15,6 +15,7 @@ tls_client.h tls_exceptn.h tls_handshake_msg.h tls_magic.h +tls_server_info.h tls_policy.h tls_server.h tls_session.h diff --git a/src/tls/msg_client_hello.cpp b/src/tls/msg_client_hello.cpp index 2149ac5e5..3df54cd8f 100644 --- a/src/tls/msg_client_hello.cpp +++ b/src/tls/msg_client_hello.cpp @@ -116,7 +116,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_extensions.add(new Heartbeat_Support_Indicator(true)); m_extensions.add(new Renegotiation_Extension(reneg_info)); m_extensions.add(new SRP_Identifier(session.srp_identifier())); - m_extensions.add(new Server_Name_Indicator(session.sni_hostname())); + m_extensions.add(new Server_Name_Indicator(session.server_info().hostname())); m_extensions.add(new Session_Ticket(session.session_ticket())); m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves())); diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp index d10366c60..87556ff75 100644 --- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp +++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp @@ -142,16 +142,15 @@ bool Session_Manager_SQLite::load_from_session_id(const std::vector<byte>& sessi return false; } -bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname, - u16bit port, - Session& session) +bool Session_Manager_SQLite::load_from_server_info(const Server_Information& server, + Session& session) { sqlite3_statement stmt(m_db, "select session from tls_sessions" " where hostname = ?1 and hostport = ?2" " order by session_start desc"); - stmt.bind(1, hostname); - stmt.bind(2, port); + stmt.bind(1, server.hostname()); + stmt.bind(2, server.port()); while(stmt.step()) { @@ -167,9 +166,6 @@ bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname, } } - if(port != 0) - return load_from_host_info(hostname, 0, session); - return false; } @@ -182,15 +178,15 @@ void Session_Manager_SQLite::remove_entry(const std::vector<byte>& session_id) stmt.spin(); } -void Session_Manager_SQLite::save(const Session& session, u16bit port) +void Session_Manager_SQLite::save(const Session& session) { sqlite3_statement stmt(m_db, "insert or replace into tls_sessions" " values(?1, ?2, ?3, ?4, ?5)"); stmt.bind(1, hex_encode(session.session_id())); stmt.bind(2, session.start_time()); - stmt.bind(3, session.sni_hostname()); - stmt.bind(4, port); + stmt.bind(3, session.server_info().hostname()); + stmt.bind(4, session.server_info().port()); stmt.bind(5, session.encrypt(m_session_key, m_rng)); stmt.spin(); diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h index db74f54b7..7892ccd6a 100644 --- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h +++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h @@ -50,12 +50,12 @@ class BOTAN_DLL Session_Manager_SQLite : public Session_Manager bool load_from_session_id(const std::vector<byte>& session_id, Session& session) override; - bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) override; + bool load_from_server_info(const Server_Information& info, + Session& session) override; void remove_entry(const std::vector<byte>& session_id) override; - void save(const Session& session_data, u16bit port) override; + void save(const Session& session_data) override; std::chrono::seconds session_lifetime() const override { return m_session_lifetime; } diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index bb6b7a45f..3793b7529 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -55,16 +55,14 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::string& hostname, - u16bit port, + const Server_Information& info, std::function<std::string (std::vector<std::string>)> next_protocol) : Channel(output_fn, proc_fn, handshake_fn, session_manager, rng), m_policy(policy), m_creds(creds), - m_hostname(hostname), - m_port(port) + m_info(info) { - const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname); + const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); const Protocol_Version version = m_policy.pref_version(); Handshake_State& state = create_handshake_state(version); @@ -111,10 +109,10 @@ void Client::send_client_hello(Handshake_State& state_base, const bool send_npn_request = static_cast<bool>(next_protocol); - if(!force_full_renegotiation && m_hostname != "") + if(!force_full_renegotiation && !m_info.empty()) { Session session_info; - if(session_manager().load_from_host_info(m_hostname, m_port, session_info)) + if(session_manager().load_from_server_info(m_info, session_info)) { if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier) { @@ -142,7 +140,7 @@ void Client::send_client_hello(Handshake_State& state_base, rng(), secure_renegotiation_data_for_client_hello(), send_npn_request, - m_hostname, + m_info.hostname(), srp_identifier)); } @@ -321,7 +319,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, try { - m_creds.verify_certificate_chain("tls-client", m_hostname, server_certs); + m_creds.verify_certificate_chain("tls-client", m_info.hostname(), server_certs); } catch(std::exception& e) { @@ -380,7 +378,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, std::vector<X509_Certificate> client_certs = m_creds.cert_chain(types, "tls-client", - m_hostname); + m_info.hostname()); state.client_certs( new Certificate(state.handshake_io(), @@ -395,7 +393,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, m_policy, m_creds, state.server_public_key.get(), - m_hostname, + m_info.hostname(), rng()) ); @@ -407,7 +405,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, Private_Key* private_key = m_creds.private_key_for(state.client_certs()->cert_chain()[0], "tls-client", - m_hostname); + m_info.hostname()); state.client_verify( new Certificate_Verify(state.handshake_io(), @@ -501,7 +499,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, state.server_hello()->fragment_size(), get_peer_cert_chain(state), session_ticket, - m_hostname, + m_info, "" ); @@ -510,7 +508,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state, if(!session_id.empty()) { if(should_save) - session_manager().save(session_info, m_port); + session_manager().save(session_info); else session_manager().remove_entry(session_info.session_id()); } diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index d7f17f878..e42322f7d 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -39,11 +39,7 @@ class BOTAN_DLL Client : public Channel * * @param rng a random number generator * - * @param servername the server's DNS name, if known - * - * @param port specifies the protocol port of the server (eg for - * TCP/UDP). Only used if servername is also specified. - * Use 0 if unknown. + * @param info is identifying information about the TLS server * * @param next_protocol allows the client to specify what the next * protocol will be. For more information read @@ -61,8 +57,7 @@ class BOTAN_DLL Client : public Channel Credentials_Manager& creds, const Policy& policy, RandomNumberGenerator& rng, - const std::string& servername = "", - u16bit port = 0, + const Server_Information& info = Server_Information(), std::function<std::string (std::vector<std::string>)> next_protocol = std::function<std::string (std::vector<std::string>)>()); private: @@ -88,8 +83,7 @@ class BOTAN_DLL Client : public Channel const Policy& m_policy; Credentials_Manager& m_creds; - const std::string m_hostname; - const u16bit m_port; + const Server_Information m_info; }; } diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 1a29d317c..4468854a4 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -96,7 +96,7 @@ bool check_for_resume(Session& session_info, // client sent a different SNI hostname if(client_hello->sni_hostname() != "") { - if(client_hello->sni_hostname() != session_info.sni_hostname()) + if(client_hello->sni_hostname() != session_info.server_info().hostname()) return false; } @@ -288,9 +288,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.client_hello(new Client_Hello(contents, type)); - if(state.client_hello()->sni_hostname() != "") - m_hostname = state.client_hello()->sni_hostname(); - Protocol_Version client_version = state.client_hello()->version(); Protocol_Version negotiated_version; @@ -451,9 +448,11 @@ void Server::process_handshake_msg(const Handshake_State* active_state, { std::map<std::string, std::vector<X509_Certificate> > cert_chains; - cert_chains = get_server_certs(m_hostname, m_creds); + const std::string sni_hostname = state.client_hello()->sni_hostname(); + + cert_chains = get_server_certs(sni_hostname, m_creds); - if(m_hostname != "" && cert_chains.empty()) + if(sni_hostname != "" && cert_chains.empty()) { cert_chains = get_server_certs("", m_creds); @@ -517,7 +516,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, private_key = m_creds.private_key_for( state.server_certs()->cert_chain()[0], "tls-server", - m_hostname); + sni_hostname); if(!private_key) throw Internal_Error("No private key located for associated server cert"); @@ -540,7 +539,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, } std::vector<X509_Certificate> client_auth_CAs = - m_creds.trusted_certificate_authorities("tls-server", m_hostname); + m_creds.trusted_certificate_authorities("tls-server", sni_hostname); if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "") { @@ -663,7 +662,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state, state.server_hello()->fragment_size(), get_peer_cert_chain(state), std::vector<byte>(), - m_hostname, + Server_Information(state.client_hello()->sni_hostname()), state.srp_identifier() ); diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 761ff6028..25c5b6506 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -36,16 +36,10 @@ class BOTAN_DLL Server : public Channel std::vector<std::string>()); /** - * Return the server name indicator, if sent by the client + * Return the protocol notification set by the client (using the + * NPN extension) for this connection, if any */ - std::string server_name_indicator() const - { return m_hostname; } - - /** - * Return the protocol negotiated with NPN extension - */ - std::string next_protocol() const - { return m_next_protocol; } + std::string next_protocol() const { return m_next_protocol; } private: std::vector<X509_Certificate> @@ -65,7 +59,6 @@ class BOTAN_DLL Server : public Channel Credentials_Manager& m_creds; std::vector<std::string> m_possible_protocols; - std::string m_hostname; std::string m_next_protocol; }; diff --git a/src/tls/tls_server_info.h b/src/tls/tls_server_info.h new file mode 100644 index 000000000..773296eaf --- /dev/null +++ b/src/tls/tls_server_info.h @@ -0,0 +1,91 @@ +/* +* TLS Server Information +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_SERVER_INFO_H__ +#define BOTAN_TLS_SERVER_INFO_H__ + +#include <botan/types.h> +#include <string> + +namespace Botan { + +namespace TLS { + +/** +* Represents information known about a TLS server. +*/ +class BOTAN_DLL Server_Information + { + public: + /** + * An empty server info - nothing known + */ + Server_Information() : m_hostname(""), m_service(""), m_port(0) {} + + /** + * @param hostname the host's DNS name, if known + * @param port specifies the protocol port of the server (eg for + * TCP/UDP). Zero represents unknown. + */ + Server_Information(const std::string& hostname, + u16bit port = 0) : + m_hostname(hostname), m_service(""), m_port(port) {} + + /** + * @param hostname the host's DNS name, if known + * @param service is a text string of the service type + * (eg "https", "tor", or "git") + * @param port specifies the protocol port of the server (eg for + * TCP/UDP). Zero represents unknown. + */ + Server_Information(const std::string& hostname, + const std::string& service, + u16bit port = 0) : + m_hostname(hostname), m_service(service), m_port(port) {} + + std::string hostname() const { return m_hostname; } + + std::string service() const { return m_service; } + + u16bit port() const { return m_port; } + + bool empty() const { return m_hostname.empty(); } + + private: + std::string m_hostname, m_service; + u16bit m_port; + }; + +inline bool operator==(const Server_Information& a, const Server_Information& b) + { + return (a.hostname() == b.hostname()) && + (a.service() == b.service()) && + (a.port() == b.port()); + + } + +inline bool operator!=(const Server_Information& a, const Server_Information& b) + { + return !(a == b); + } + +inline bool operator<(const Server_Information& a, const Server_Information& b) + { + if(a.hostname() != b.hostname()) + return (a.hostname() < b.hostname()); + if(a.service() != b.service()) + return (a.service() < b.service()); + if(a.port() != b.port()) + return (a.port() < b.port()); + return false; // equal + } + +} + +} + +#endif diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp index ae57de0c2..85cb6d69e 100644 --- a/src/tls/tls_session.cpp +++ b/src/tls/tls_session.cpp @@ -27,7 +27,7 @@ Session::Session(const std::vector<byte>& session_identifier, size_t fragment_size, const std::vector<X509_Certificate>& certs, const std::vector<byte>& ticket, - const std::string& sni_hostname, + const Server_Information& server_info, const std::string& srp_identifier) : m_start_time(std::chrono::system_clock::now()), m_identifier(session_identifier), @@ -39,7 +39,7 @@ Session::Session(const std::vector<byte>& session_identifier, m_connection_side(side), m_fragment_size(fragment_size), m_peer_certs(certs), - m_sni_hostname(sni_hostname), + m_server_info(server_info), m_srp_identifier(srp_identifier) { } @@ -54,7 +54,11 @@ Session::Session(const std::string& pem) Session::Session(const byte ber[], size_t ber_len) { byte side_code = 0; - ASN1_String sni_hostname_str; + + ASN1_String server_hostname; + ASN1_String server_service; + size_t server_port; + ASN1_String srp_identifier_str; byte major_version = 0, minor_version = 0; @@ -78,17 +82,23 @@ Session::Session(const byte ber[], size_t ber_len) .decode_integer_type(m_fragment_size) .decode(m_master_secret, OCTET_STRING) .decode(peer_cert_bits, OCTET_STRING) - .decode(sni_hostname_str) + .decode(server_hostname) + .decode(server_service) + .decode(server_port) .decode(srp_identifier_str) .end_cons() .verify_end(); m_version = Protocol_Version(major_version, minor_version); m_start_time = std::chrono::system_clock::from_time_t(start_time); - m_sni_hostname = sni_hostname_str.value(); - m_srp_identifier = srp_identifier_str.value(); m_connection_side = static_cast<Connection_Side>(side_code); + m_server_info = Server_Information(server_hostname.value(), + server_service.value(), + server_port); + + m_srp_identifier = srp_identifier_str.value(); + if(!peer_cert_bits.empty()) { DataSource_Memory certs(&peer_cert_bits[0], peer_cert_bits.size()); @@ -118,7 +128,9 @@ secure_vector<byte> Session::DER_encode() const .encode(static_cast<size_t>(m_fragment_size)) .encode(m_master_secret, OCTET_STRING) .encode(peer_cert_bits, OCTET_STRING) - .encode(ASN1_String(m_sni_hostname, UTF8_STRING)) + .encode(ASN1_String(m_server_info.hostname(), UTF8_STRING)) + .encode(ASN1_String(m_server_info.service(), UTF8_STRING)) + .encode(static_cast<size_t>(m_server_info.port())) .encode(ASN1_String(m_srp_identifier, UTF8_STRING)) .end_cons() .get_contents(); diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h index 206a75081..65154dfce 100644 --- a/src/tls/tls_session.h +++ b/src/tls/tls_session.h @@ -12,6 +12,7 @@ #include <botan/tls_version.h> #include <botan/tls_ciphersuite.h> #include <botan/tls_magic.h> +#include <botan/tls_server_info.h> #include <botan/secmem.h> #include <botan/symkey.h> #include <chrono> @@ -51,8 +52,8 @@ class BOTAN_DLL Session size_t fragment_size, const std::vector<X509_Certificate>& peer_certs, const std::vector<byte>& session_ticket, - const std::string& sni_hostname = "", - const std::string& srp_identifier = ""); + const Server_Information& server_info, + const std::string& srp_identifier); /** * Load a session from DER representation (created by DER_encode) @@ -133,11 +134,6 @@ class BOTAN_DLL Session Connection_Side side() const { return m_connection_side; } /** - * Get the SNI hostname (if sent by the client in the initial handshake) - */ - std::string sni_hostname() const { return m_sni_hostname; } - - /** * Get the SRP identity (if sent by the client in the initial handshake) */ std::string srp_identifier() const { return m_srp_identifier; } @@ -180,6 +176,8 @@ class BOTAN_DLL Session */ const std::vector<byte>& session_ticket() const { return m_session_ticket; } + Server_Information server_info() const { return m_server_info; } + private: enum { TLS_SESSION_PARAM_STRUCT_VERSION = 0x2994e301 }; @@ -197,7 +195,7 @@ class BOTAN_DLL Session size_t m_fragment_size; std::vector<X509_Certificate> m_peer_certs; - std::string m_sni_hostname; // optional + Server_Information m_server_info; // optional std::string m_srp_identifier; // optional }; diff --git a/src/tls/tls_session_manager.cpp b/src/tls/tls_session_manager.cpp index 673ee90ff..ca18231a0 100644 --- a/src/tls/tls_session_manager.cpp +++ b/src/tls/tls_session_manager.cpp @@ -61,27 +61,24 @@ bool Session_Manager_In_Memory::load_from_session_id( return load_from_session_str(hex_encode(session_id), session); } -bool Session_Manager_In_Memory::load_from_host_info( - const std::string& hostname, u16bit port, Session& session) +bool Session_Manager_In_Memory::load_from_server_info( + const Server_Information& info, Session& session) { std::lock_guard<std::mutex> lock(m_mutex); - auto i = m_host_sessions.find(hostname + ":" + std::to_string(port)); + auto i = m_info_sessions.find(info); - if(i == m_host_sessions.end()) - { - if(port > 0) - i = m_host_sessions.find(hostname + ":" + std::to_string(0)); - - if(i == m_host_sessions.end()) - return false; - } + if(i == m_info_sessions.end()) + return false; if(load_from_session_str(i->second, session)) return true; - // was removed from sessions map, remove m_host_sessions entry - m_host_sessions.erase(i); + /* + * It existed at one point but was removed from the sessions map, + * remove m_info_sessions entry as well + */ + m_info_sessions.erase(i); return false; } @@ -97,7 +94,7 @@ void Session_Manager_In_Memory::remove_entry( m_sessions.erase(i); } -void Session_Manager_In_Memory::save(const Session& session, u16bit port) +void Session_Manager_In_Memory::save(const Session& session) { std::lock_guard<std::mutex> lock(m_mutex); @@ -115,10 +112,8 @@ void Session_Manager_In_Memory::save(const Session& session, u16bit port) m_sessions[session_id_str] = session.encrypt(m_session_key, m_rng); - const std::string hostname = session.sni_hostname(); - - if(session.side() == CLIENT && hostname != "") - m_host_sessions[hostname + ":" + std::to_string(port)] = session_id_str; + if(session.side() == CLIENT && !session.server_info().empty()) + m_info_sessions[session.server_info()] = session_id_str; } } diff --git a/src/tls/tls_session_manager.h b/src/tls/tls_session_manager.h index 4efefb6ff..d7c805195 100644 --- a/src/tls/tls_session_manager.h +++ b/src/tls/tls_session_manager.h @@ -30,7 +30,7 @@ class BOTAN_DLL Session_Manager { public: /** - * Try to load a saved session (server side) + * Try to load a saved session (using session ID) * @param session_id the session identifier we are trying to resume * @param session will be set to the saved session data (if found), or not modified if not found @@ -40,15 +40,14 @@ class BOTAN_DLL Session_Manager Session& session) = 0; /** - * Try to load a saved session (client side) - * @param hostname of the host we are connecting to - * @param port the port number if we know it, or 0 if unknown + * Try to load a saved session (using info about server) + * @param info the information about the server * @param session will be set to the saved session data (if found), or not modified if not found * @return true if session was modified */ - virtual bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) = 0; + virtual bool load_from_server_info(const Server_Information& info, + Session& session) = 0; /** * Remove this session id from the cache, if it exists @@ -64,7 +63,7 @@ class BOTAN_DLL Session_Manager * @param session to save * @param port the protocol port (if known) */ - virtual void save(const Session& session, u16bit port = 0) = 0; + virtual void save(const Session& session) = 0; /** * Return the allowed lifetime of a session; beyond this time, @@ -86,12 +85,12 @@ class BOTAN_DLL Session_Manager_Noop : public Session_Manager bool load_from_session_id(const std::vector<byte>&, Session&) override { return false; } - bool load_from_host_info(const std::string&, u16bit, Session&) override + bool load_from_server_info(const Server_Information&, Session&) override { return false; } void remove_entry(const std::vector<byte>&) override {} - void save(const Session&, u16bit) override {} + void save(const Session&) override {} std::chrono::seconds session_lifetime() const override { return std::chrono::seconds(0); } @@ -116,12 +115,12 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager bool load_from_session_id(const std::vector<byte>& session_id, Session& session) override; - bool load_from_host_info(const std::string& hostname, u16bit port, - Session& session) override; + bool load_from_server_info(const Server_Information& info, + Session& session) override; void remove_entry(const std::vector<byte>& session_id) override; - void save(const Session& session_data, u16bit port) override; + void save(const Session& session_data) override; std::chrono::seconds session_lifetime() const override { return m_session_lifetime; } @@ -140,7 +139,7 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager SymmetricKey m_session_key; std::map<std::string, std::vector<byte>> m_sessions; // hex(session_id) -> session - std::map<std::string, std::string> m_host_sessions; + std::map<Server_Information, std::string> m_info_sessions; }; } |