aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/tls_channel.cpp19
-rw-r--r--src/tls/tls_channel.h26
-rw-r--r--src/tls/tls_client.cpp17
-rw-r--r--src/tls/tls_client.h2
-rw-r--r--src/tls/tls_server.cpp76
-rw-r--r--src/tls/tls_server.h8
6 files changed, 85 insertions, 63 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index c06bd3e3a..48f142fb5 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -18,11 +18,13 @@ namespace TLS {
Channel::Channel(std::function<void (const byte[], size_t)> socket_output_fn,
std::function<void (const byte[], size_t, Alert)> proc_fn,
- std::function<bool (const Session&)> handshake_complete) :
+ std::function<bool (const Session&)> handshake_complete,
+ Session_Manager& session_manager) :
m_proc_fn(proc_fn),
m_handshake_fn(handshake_complete),
- m_writer(socket_output_fn),
m_state(nullptr),
+ m_session_manager(session_manager),
+ m_writer(socket_output_fn),
m_handshake_completed(false),
m_connection_closed(false),
m_peer_supports_heartbeats(false),
@@ -120,6 +122,13 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
else if(alert_msg.is_fatal())
{
// delete state immediately
+
+ if(!m_active_session.empty())
+ {
+ m_session_manager.remove_entry(m_active_session);
+ m_active_session.clear();
+ }
+
m_connection_closed = true;
delete m_state;
@@ -236,6 +245,12 @@ void Channel::send_alert(const Alert& alert)
catch(...) { /* swallow it */ }
}
+ if(alert.is_fatal() && !m_active_session.empty())
+ {
+ m_session_manager.remove_entry(m_active_session);
+ m_active_session.clear();
+ }
+
if(!m_connection_closed && (alert.type() == Alert::CLOSE_NOTIFY || alert.is_fatal()))
{
m_connection_closed = true;
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index ca4247c85..1f4075b01 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -12,6 +12,7 @@
#include <botan/tls_record.h>
#include <botan/tls_session.h>
#include <botan/tls_alert.h>
+#include <botan/tls_session_manager.h>
#include <botan/x509cert.h>
#include <vector>
@@ -78,7 +79,8 @@ class BOTAN_DLL Channel
Channel(std::function<void (const byte[], size_t)> socket_output_fn,
std::function<void (const byte[], size_t, Alert)> proc_fn,
- std::function<bool (const Session&)> handshake_complete);
+ std::function<bool (const Session&)> handshake_complete,
+ Session_Manager& session_manager);
virtual ~Channel();
protected:
@@ -99,16 +101,6 @@ class BOTAN_DLL Channel
virtual void alert_notify(const Alert& alert) = 0;
- std::function<void (const byte[], size_t, Alert)> m_proc_fn;
- std::function<bool (const Session&)> m_handshake_fn;
-
- Record_Writer m_writer;
- Record_Reader m_reader;
-
- std::vector<X509_Certificate> m_peer_certs;
-
- class Handshake_State* m_state;
-
class Secure_Renegotiation_State
{
public:
@@ -142,8 +134,20 @@ class BOTAN_DLL Channel
std::vector<byte> m_client_verify, m_server_verify;
};
+ std::function<void (const byte[], size_t, Alert)> m_proc_fn;
+ std::function<bool (const Session&)> m_handshake_fn;
+
+ class Handshake_State* m_state;
+
+ Session_Manager& m_session_manager;
+ Record_Writer m_writer;
+ Record_Reader m_reader;
+
+ std::vector<X509_Certificate> m_peer_certs;
+
Secure_Renegotiation_State m_secure_renegotiation;
+ std::vector<byte> m_active_session;
bool m_handshake_completed;
bool m_connection_closed;
bool m_peer_supports_heartbeats;
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 1a515b1f6..3dd6484db 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -27,10 +27,9 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
RandomNumberGenerator& rng,
const std::string& hostname,
std::function<std::string (std::vector<std::string>)> next_protocol) :
- Channel(output_fn, proc_fn, handshake_fn),
+ Channel(output_fn, proc_fn, handshake_fn, session_manager),
m_policy(policy),
m_rng(rng),
- m_session_manager(session_manager),
m_creds(creds),
m_hostname(hostname)
{
@@ -471,14 +470,20 @@ void Client::process_handshake_msg(Handshake_Type type,
""
);
- if(m_handshake_fn(session_info))
- m_session_manager.save(session_info);
- else
- m_session_manager.remove_entry(session_info.session_id());
+ const bool should_save = m_handshake_fn(session_info);
+
+ if(!session_id.empty())
+ {
+ if(should_save)
+ m_session_manager.save(session_info);
+ else
+ m_session_manager.remove_entry(session_info.session_id());
+ }
delete m_state;
m_state = nullptr;
m_handshake_completed = true;
+ m_active_session = session_info.session_id();
}
else
throw Unexpected_Message("Unknown handshake message received");
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index 7e38465af..c85b528d2 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -9,7 +9,6 @@
#define BOTAN_TLS_CLIENT_H__
#include <botan/tls_channel.h>
-#include <botan/tls_session_manager.h>
#include <botan/credentials_manager.h>
#include <vector>
@@ -62,7 +61,6 @@ class BOTAN_DLL Client : public Channel
const Policy& m_policy;
RandomNumberGenerator& m_rng;
- Session_Manager& m_session_manager;
Credentials_Manager& m_creds;
const std::string m_hostname;
};
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index fcde7a8ce..d1d9463e2 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -192,11 +192,10 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
const Policy& policy,
RandomNumberGenerator& rng,
const std::vector<std::string>& next_protocols) :
- Channel(output_fn, proc_fn, handshake_fn),
- policy(policy),
- rng(rng),
- session_manager(session_manager),
- creds(creds),
+ Channel(output_fn, proc_fn, handshake_fn, session_manager),
+ m_policy(policy),
+ m_rng(rng),
+ m_creds(creds),
m_possible_protocols(next_protocols)
{
}
@@ -278,13 +277,13 @@ void Server::process_handshake_msg(Handshake_Type type,
Protocol_Version client_version = m_state->client_hello->version();
- if(client_version < policy.min_version())
+ if(client_version < m_policy.min_version())
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Client version is unacceptable by policy");
- if(client_version > policy.pref_version())
+ if(client_version > m_policy.pref_version())
{
- m_state->set_version(policy.pref_version());
+ m_state->set_version(m_policy.pref_version());
}
else
{
@@ -319,7 +318,7 @@ void Server::process_handshake_msg(Handshake_Type type,
}
}
- if(!policy.allow_insecure_renegotiation() &&
+ if(!m_policy.allow_insecure_renegotiation() &&
!(m_secure_renegotiation.initial_handshake() || m_secure_renegotiation.supported()))
{
delete m_state;
@@ -340,17 +339,17 @@ void Server::process_handshake_msg(Handshake_Type type,
const bool resuming =
m_state->allow_session_resumption &&
check_for_resume(session_info,
- session_manager,
- creds,
+ m_session_manager,
+ m_creds,
m_state->client_hello,
- std::chrono::seconds(policy.session_ticket_lifetime()));
+ std::chrono::seconds(m_policy.session_ticket_lifetime()));
bool have_session_ticket_key = false;
try
{
have_session_ticket_key =
- creds.psk("tls-server", "session-ticket", "").length() > 0;
+ m_creds.psk("tls-server", "session-ticket", "").length() > 0;
}
catch(...) {}
@@ -374,7 +373,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_state->client_hello->next_protocol_notification(),
m_possible_protocols,
m_state->client_hello->supports_heartbeats(),
- rng);
+ m_rng);
m_secure_renegotiation.update(m_state->server_hello);
@@ -390,7 +389,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(!m_handshake_fn(session_info))
{
- session_manager.remove_entry(session_info.session_id());
+ m_session_manager.remove_entry(session_info.session_id());
if(m_state->server_hello->supports_session_ticket()) // send an empty ticket
m_state->new_session_ticket = new New_Session_Ticket(m_writer, m_state->hash);
@@ -400,12 +399,12 @@ void Server::process_handshake_msg(Handshake_Type type,
{
try
{
- const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", "");
+ const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
new New_Session_Ticket(m_writer, m_state->hash,
- session_info.encrypt(ticket_key, rng),
- policy.session_ticket_lifetime());
+ session_info.encrypt(ticket_key, m_rng),
+ m_policy.session_ticket_lifetime());
}
catch(...) {}
@@ -416,7 +415,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_writer.send(CHANGE_CIPHER_SPEC, 1);
m_writer.activate(SERVER, m_state->suite, m_state->keys,
- m_state->server_hello->compression_method());
+ m_state->server_hello->compression_method());
m_state->server_finished = new Finished(m_writer, m_state, SERVER);
@@ -426,11 +425,11 @@ void Server::process_handshake_msg(Handshake_Type type,
{
std::map<std::string, std::vector<X509_Certificate> > cert_chains;
- cert_chains = get_server_certs(m_hostname, creds);
+ cert_chains = get_server_certs(m_hostname, m_creds);
if(m_hostname != "" && cert_chains.empty())
{
- cert_chains = get_server_certs("", creds);
+ cert_chains = get_server_certs("", m_creds);
/*
* Only send the unrecognized_name alert if we couldn't
@@ -446,10 +445,10 @@ void Server::process_handshake_msg(Handshake_Type type,
m_state->server_hello = new Server_Hello(
m_writer,
m_state->hash,
- unlock(rng.random_vec(32)), // new session ID
+ unlock(m_rng.random_vec(32)), // new session ID
m_state->version(),
- choose_ciphersuite(policy, creds, cert_chains, m_state->client_hello),
- choose_compression(policy, m_state->client_hello->compression_methods()),
+ choose_ciphersuite(m_policy, m_creds, cert_chains, m_state->client_hello),
+ choose_compression(m_policy, m_state->client_hello->compression_methods()),
m_state->client_hello->fragment_size(),
m_secure_renegotiation.supported(),
m_secure_renegotiation.for_server_hello(),
@@ -457,7 +456,7 @@ void Server::process_handshake_msg(Handshake_Type type,
m_state->client_hello->next_protocol_notification(),
m_possible_protocols,
m_state->client_hello->supports_heartbeats(),
- rng);
+ m_rng);
m_secure_renegotiation.update(m_state->server_hello);
@@ -486,9 +485,10 @@ void Server::process_handshake_msg(Handshake_Type type,
if(kex_algo == "RSA" || sig_algo != "")
{
- private_key = creds.private_key_for(m_state->server_certs->cert_chain()[0],
- "tls-server",
- m_hostname);
+ private_key = m_creds.private_key_for(
+ m_state->server_certs->cert_chain()[0],
+ "tls-server",
+ m_hostname);
if(!private_key)
throw Internal_Error("No private key located for associated server cert");
@@ -501,17 +501,17 @@ void Server::process_handshake_msg(Handshake_Type type,
else
{
m_state->server_kex =
- new Server_Key_Exchange(m_writer, m_state, policy, creds, rng, private_key);
+ new Server_Key_Exchange(m_writer, m_state, m_policy, m_creds, m_rng, private_key);
}
std::vector<X509_Certificate> client_auth_CAs =
- creds.trusted_certificate_authorities("tls-server", m_hostname);
+ m_creds.trusted_certificate_authorities("tls-server", m_hostname);
if(!client_auth_CAs.empty() && m_state->suite.sig_algo() != "")
{
m_state->cert_req = new Certificate_Req(m_writer,
m_state->hash,
- policy,
+ m_policy,
client_auth_CAs,
m_state->version());
@@ -541,7 +541,7 @@ void Server::process_handshake_msg(Handshake_Type type,
else
m_state->set_expected_next(HANDSHAKE_CCS);
- m_state->client_kex = new Client_Key_Exchange(contents, m_state, creds, policy, rng);
+ m_state->client_kex = new Client_Key_Exchange(contents, m_state, m_creds, m_policy, m_rng);
m_state->keys = Session_Keys(m_state, m_state->client_kex->pre_master_secret(), false);
}
@@ -566,7 +566,7 @@ void Server::process_handshake_msg(Handshake_Type type,
try
{
- creds.verify_certificate_chain("tls-server", "", m_peer_certs);
+ m_creds.verify_certificate_chain("tls-server", "", m_peer_certs);
}
catch(std::exception& e)
{
@@ -630,17 +630,17 @@ void Server::process_handshake_msg(Handshake_Type type,
{
try
{
- const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", "");
+ const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
m_state->new_session_ticket =
new New_Session_Ticket(m_writer, m_state->hash,
- session_info.encrypt(ticket_key, rng),
- policy.session_ticket_lifetime());
+ session_info.encrypt(ticket_key, m_rng),
+ m_policy.session_ticket_lifetime());
}
catch(...) {}
}
else
- session_manager.save(session_info);
+ m_session_manager.save(session_info);
}
if(m_state->server_hello->supports_session_ticket() && !m_state->new_session_ticket)
@@ -657,9 +657,11 @@ void Server::process_handshake_msg(Handshake_Type type,
m_secure_renegotiation.update(m_state->client_finished,
m_state->server_finished);
+ m_active_session = m_state->server_hello->session_id();
delete m_state;
m_state = nullptr;
m_handshake_completed = true;
+
}
else
throw Unexpected_Message("Unknown handshake message received");
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 441e03eb2..9625adcf3 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -9,7 +9,6 @@
#define BOTAN_TLS_SERVER_H__
#include <botan/tls_channel.h>
-#include <botan/tls_session_manager.h>
#include <botan/credentials_manager.h>
#include <vector>
@@ -57,10 +56,9 @@ class BOTAN_DLL Server : public Channel
void alert_notify(const Alert& alert);
- const Policy& policy;
- RandomNumberGenerator& rng;
- Session_Manager& session_manager;
- Credentials_Manager& creds;
+ const Policy& m_policy;
+ RandomNumberGenerator& m_rng;
+ Credentials_Manager& m_creds;
std::vector<std::string> m_possible_protocols;
std::string m_hostname;