diff options
-rw-r--r-- | src/tls/tls_channel.cpp | 19 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 26 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 17 | ||||
-rw-r--r-- | src/tls/tls_client.h | 2 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 76 | ||||
-rw-r--r-- | src/tls/tls_server.h | 8 |
6 files changed, 85 insertions, 63 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index c06bd3e3a..48f142fb5 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -18,11 +18,13 @@ namespace TLS { Channel::Channel(std::function<void (const byte[], size_t)> socket_output_fn, std::function<void (const byte[], size_t, Alert)> proc_fn, - std::function<bool (const Session&)> handshake_complete) : + std::function<bool (const Session&)> handshake_complete, + Session_Manager& session_manager) : m_proc_fn(proc_fn), m_handshake_fn(handshake_complete), - m_writer(socket_output_fn), m_state(nullptr), + m_session_manager(session_manager), + m_writer(socket_output_fn), m_handshake_completed(false), m_connection_closed(false), m_peer_supports_heartbeats(false), @@ -120,6 +122,13 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) else if(alert_msg.is_fatal()) { // delete state immediately + + if(!m_active_session.empty()) + { + m_session_manager.remove_entry(m_active_session); + m_active_session.clear(); + } + m_connection_closed = true; delete m_state; @@ -236,6 +245,12 @@ void Channel::send_alert(const Alert& alert) catch(...) { /* swallow it */ } } + if(alert.is_fatal() && !m_active_session.empty()) + { + m_session_manager.remove_entry(m_active_session); + m_active_session.clear(); + } + if(!m_connection_closed && (alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal())) { m_connection_closed = true; diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index ca4247c85..1f4075b01 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -12,6 +12,7 @@ #include <botan/tls_record.h> #include <botan/tls_session.h> #include <botan/tls_alert.h> +#include <botan/tls_session_manager.h> #include <botan/x509cert.h> #include <vector> @@ -78,7 +79,8 @@ class BOTAN_DLL Channel Channel(std::function<void (const byte[], size_t)> socket_output_fn, std::function<void (const byte[], size_t, Alert)> proc_fn, - std::function<bool (const Session&)> handshake_complete); + std::function<bool (const Session&)> handshake_complete, + Session_Manager& session_manager); virtual ~Channel(); protected: @@ -99,16 +101,6 @@ class BOTAN_DLL Channel virtual void alert_notify(const Alert& alert) = 0; - std::function<void (const byte[], size_t, Alert)> m_proc_fn; - std::function<bool (const Session&)> m_handshake_fn; - - Record_Writer m_writer; - Record_Reader m_reader; - - std::vector<X509_Certificate> m_peer_certs; - - class Handshake_State* m_state; - class Secure_Renegotiation_State { public: @@ -142,8 +134,20 @@ class BOTAN_DLL Channel std::vector<byte> m_client_verify, m_server_verify; }; + std::function<void (const byte[], size_t, Alert)> m_proc_fn; + std::function<bool (const Session&)> m_handshake_fn; + + class Handshake_State* m_state; + + Session_Manager& m_session_manager; + Record_Writer m_writer; + Record_Reader m_reader; + + std::vector<X509_Certificate> m_peer_certs; + Secure_Renegotiation_State m_secure_renegotiation; + std::vector<byte> m_active_session; bool m_handshake_completed; bool m_connection_closed; bool m_peer_supports_heartbeats; diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 1a515b1f6..3dd6484db 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -27,10 +27,9 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, RandomNumberGenerator& rng, const std::string& hostname, std::function<std::string (std::vector<std::string>)> next_protocol) : - Channel(output_fn, proc_fn, handshake_fn), + Channel(output_fn, proc_fn, handshake_fn, session_manager), m_policy(policy), m_rng(rng), - m_session_manager(session_manager), m_creds(creds), m_hostname(hostname) { @@ -471,14 +470,20 @@ void Client::process_handshake_msg(Handshake_Type type, "" ); - if(m_handshake_fn(session_info)) - m_session_manager.save(session_info); - else - m_session_manager.remove_entry(session_info.session_id()); + const bool should_save = m_handshake_fn(session_info); + + if(!session_id.empty()) + { + if(should_save) + m_session_manager.save(session_info); + else + m_session_manager.remove_entry(session_info.session_id()); + } delete m_state; m_state = nullptr; m_handshake_completed = true; + m_active_session = session_info.session_id(); } else throw Unexpected_Message("Unknown handshake message received"); diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index 7e38465af..c85b528d2 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -9,7 +9,6 @@ #define BOTAN_TLS_CLIENT_H__ #include <botan/tls_channel.h> -#include <botan/tls_session_manager.h> #include <botan/credentials_manager.h> #include <vector> @@ -62,7 +61,6 @@ class BOTAN_DLL Client : public Channel const Policy& m_policy; RandomNumberGenerator& m_rng; - Session_Manager& m_session_manager; Credentials_Manager& m_creds; const std::string m_hostname; }; diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index fcde7a8ce..d1d9463e2 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -192,11 +192,10 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, const Policy& policy, RandomNumberGenerator& rng, const std::vector<std::string>& next_protocols) : - Channel(output_fn, proc_fn, handshake_fn), - policy(policy), - rng(rng), - session_manager(session_manager), - creds(creds), + Channel(output_fn, proc_fn, handshake_fn, session_manager), + m_policy(policy), + m_rng(rng), + m_creds(creds), m_possible_protocols(next_protocols) { } @@ -278,13 +277,13 @@ void Server::process_handshake_msg(Handshake_Type type, Protocol_Version client_version = m_state->client_hello->version(); - if(client_version < policy.min_version()) + if(client_version < m_policy.min_version()) throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client version is unacceptable by policy"); - if(client_version > policy.pref_version()) + if(client_version > m_policy.pref_version()) { - m_state->set_version(policy.pref_version()); + m_state->set_version(m_policy.pref_version()); } else { @@ -319,7 +318,7 @@ void Server::process_handshake_msg(Handshake_Type type, } } - if(!policy.allow_insecure_renegotiation() && + if(!m_policy.allow_insecure_renegotiation() && !(m_secure_renegotiation.initial_handshake() || m_secure_renegotiation.supported())) { delete m_state; @@ -340,17 +339,17 @@ void Server::process_handshake_msg(Handshake_Type type, const bool resuming = m_state->allow_session_resumption && check_for_resume(session_info, - session_manager, - creds, + m_session_manager, + m_creds, m_state->client_hello, - std::chrono::seconds(policy.session_ticket_lifetime())); + std::chrono::seconds(m_policy.session_ticket_lifetime())); bool have_session_ticket_key = false; try { have_session_ticket_key = - creds.psk("tls-server", "session-ticket", "").length() > 0; + m_creds.psk("tls-server", "session-ticket", "").length() > 0; } catch(...) {} @@ -374,7 +373,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_state->client_hello->next_protocol_notification(), m_possible_protocols, m_state->client_hello->supports_heartbeats(), - rng); + m_rng); m_secure_renegotiation.update(m_state->server_hello); @@ -390,7 +389,7 @@ void Server::process_handshake_msg(Handshake_Type type, if(!m_handshake_fn(session_info)) { - session_manager.remove_entry(session_info.session_id()); + m_session_manager.remove_entry(session_info.session_id()); if(m_state->server_hello->supports_session_ticket()) // send an empty ticket m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash); @@ -400,12 +399,12 @@ void Server::process_handshake_msg(Handshake_Type type, { try { - const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", ""); + const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash, - session_info.encrypt(ticket_key, rng), - policy.session_ticket_lifetime()); + session_info.encrypt(ticket_key, m_rng), + m_policy.session_ticket_lifetime()); } catch(...) {} @@ -416,7 +415,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_writer.send(CHANGE_CIPHER_SPEC, 1); m_writer.activate(SERVER, m_state->suite, m_state->keys, - m_state->server_hello->compression_method()); + m_state->server_hello->compression_method()); m_state->server_finished = new Finished(m_writer, m_state, SERVER); @@ -426,11 +425,11 @@ void Server::process_handshake_msg(Handshake_Type type, { std::map<std::string, std::vector<X509_Certificate> > cert_chains; - cert_chains = get_server_certs(m_hostname, creds); + cert_chains = get_server_certs(m_hostname, m_creds); if(m_hostname != "" && cert_chains.empty()) { - cert_chains = get_server_certs("", creds); + cert_chains = get_server_certs("", m_creds); /* * Only send the unrecognized_name alert if we couldn't @@ -446,10 +445,10 @@ void Server::process_handshake_msg(Handshake_Type type, m_state->server_hello = new Server_Hello( m_writer, m_state->hash, - unlock(rng.random_vec(32)), // new session ID + unlock(m_rng.random_vec(32)), // new session ID m_state->version(), - choose_ciphersuite(policy, creds, cert_chains, m_state->client_hello), - choose_compression(policy, m_state->client_hello->compression_methods()), + choose_ciphersuite(m_policy, m_creds, cert_chains, m_state->client_hello), + choose_compression(m_policy, m_state->client_hello->compression_methods()), m_state->client_hello->fragment_size(), m_secure_renegotiation.supported(), m_secure_renegotiation.for_server_hello(), @@ -457,7 +456,7 @@ void Server::process_handshake_msg(Handshake_Type type, m_state->client_hello->next_protocol_notification(), m_possible_protocols, m_state->client_hello->supports_heartbeats(), - rng); + m_rng); m_secure_renegotiation.update(m_state->server_hello); @@ -486,9 +485,10 @@ void Server::process_handshake_msg(Handshake_Type type, if(kex_algo == "RSA" || sig_algo != "") { - private_key = creds.private_key_for(m_state->server_certs->cert_chain()[0], - "tls-server", - m_hostname); + private_key = m_creds.private_key_for( + m_state->server_certs->cert_chain()[0], + "tls-server", + m_hostname); if(!private_key) throw Internal_Error("No private key located for associated server cert"); @@ -501,17 +501,17 @@ void Server::process_handshake_msg(Handshake_Type type, else { m_state->server_kex = - new Server_Key_Exchange(m_writer, m_state, policy, creds, rng, private_key); + new Server_Key_Exchange(m_writer, m_state, m_policy, m_creds, m_rng, private_key); } std::vector<X509_Certificate> client_auth_CAs = - creds.trusted_certificate_authorities("tls-server", m_hostname); + m_creds.trusted_certificate_authorities("tls-server", m_hostname); if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "") { m_state->cert_req = new Certificate_Req(m_writer, m_state->hash, - policy, + m_policy, client_auth_CAs, m_state->version()); @@ -541,7 +541,7 @@ void Server::process_handshake_msg(Handshake_Type type, else m_state->set_expected_next(HANDSHAKE_CCS); - m_state->client_kex = new Client_Key_Exchange(contents, m_state, creds, policy, rng); + m_state->client_kex = new Client_Key_Exchange(contents, m_state, m_creds, m_policy, m_rng); m_state->keys = Session_Keys(m_state, m_state->client_kex->pre_master_secret(), false); } @@ -566,7 +566,7 @@ void Server::process_handshake_msg(Handshake_Type type, try { - creds.verify_certificate_chain("tls-server", "", m_peer_certs); + m_creds.verify_certificate_chain("tls-server", "", m_peer_certs); } catch(std::exception& e) { @@ -630,17 +630,17 @@ void Server::process_handshake_msg(Handshake_Type type, { try { - const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", ""); + const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash, - session_info.encrypt(ticket_key, rng), - policy.session_ticket_lifetime()); + session_info.encrypt(ticket_key, m_rng), + m_policy.session_ticket_lifetime()); } catch(...) {} } else - session_manager.save(session_info); + m_session_manager.save(session_info); } if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket) @@ -657,9 +657,11 @@ void Server::process_handshake_msg(Handshake_Type type, m_secure_renegotiation.update(m_state->client_finished, m_state->server_finished); + m_active_session = m_state->server_hello->session_id(); delete m_state; m_state = nullptr; m_handshake_completed = true; + } else throw Unexpected_Message("Unknown handshake message received"); diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 441e03eb2..9625adcf3 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -9,7 +9,6 @@ #define BOTAN_TLS_SERVER_H__ #include <botan/tls_channel.h> -#include <botan/tls_session_manager.h> #include <botan/credentials_manager.h> #include <vector> @@ -57,10 +56,9 @@ class BOTAN_DLL Server : public Channel void alert_notify(const Alert& alert); - const Policy& policy; - RandomNumberGenerator& rng; - Session_Manager& session_manager; - Credentials_Manager& creds; + const Policy& m_policy; + RandomNumberGenerator& m_rng; + Credentials_Manager& m_creds; std::vector<std::string> m_possible_protocols; std::string m_hostname; |