aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-10-13 20:55:17 +0000
committerlloyd <[email protected]>2012-10-13 20:55:17 +0000
commit72eb425d699e0571857432e4271d10afb6431a6e (patch)
treecb8d3174f4844c84f8295a63ea0d44aaefa9b31c /src
parent8bd5519105f6978e7d937294d2a2e8deadda20ca (diff)
parent4be75ae1e9e473fc3e939be5e54e51f552d5934b (diff)
merge of '415e0ca58c566cb2990758c1261d47d6b09fc76c'
and 'e616da4002c659a5f5f6c16aecaafef7c37a5f96'
Diffstat (limited to 'src')
-rw-r--r--src/block/aes_ni/aes_ni.cpp44
-rw-r--r--src/tls/info.txt1
-rw-r--r--src/tls/msg_client_hello.cpp10
-rw-r--r--src/tls/msg_server_hello.cpp7
-rw-r--r--src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp18
-rw-r--r--src/tls/sessions_sqlite/tls_session_manager_sqlite.h6
-rw-r--r--src/tls/tls_client.cpp32
-rw-r--r--src/tls/tls_client.h14
-rw-r--r--src/tls/tls_messages.h1
-rw-r--r--src/tls/tls_policy.cpp5
-rw-r--r--src/tls/tls_policy.h10
-rw-r--r--src/tls/tls_server.cpp19
-rw-r--r--src/tls/tls_server.h13
-rw-r--r--src/tls/tls_server_info.h91
-rw-r--r--src/tls/tls_session.cpp26
-rw-r--r--src/tls/tls_session.h14
-rw-r--r--src/tls/tls_session_manager.cpp31
-rw-r--r--src/tls/tls_session_manager.h25
-rw-r--r--src/tls/tls_version.h10
-rw-r--r--src/utils/assert.h6
20 files changed, 244 insertions, 139 deletions
diff --git a/src/block/aes_ni/aes_ni.cpp b/src/block/aes_ni/aes_ni.cpp
index 3ee0e608c..4dca6c7f2 100644
--- a/src/block/aes_ni/aes_ni.cpp
+++ b/src/block/aes_ni/aes_ni.cpp
@@ -1,6 +1,6 @@
/*
* AES using AES-NI instructions
-* (C) 2009 Jack Lloyd
+* (C) 2009,2012 Jack Lloyd
*
* Distributed under the terms of the Botan license
*/
@@ -485,10 +485,10 @@ void AES_192_NI::key_schedule(const byte key[], size_t)
load_le(&EK[0], key, 6);
-#define AES_192_key_exp(RCON, EK_OFF) \
- aes_192_key_expansion(&K0, &K1, \
- _mm_aeskeygenassist_si128(K1, RCON), \
- EK + EK_OFF, EK_OFF == 48)
+ #define AES_192_key_exp(RCON, EK_OFF) \
+ aes_192_key_expansion(&K0, &K1, \
+ _mm_aeskeygenassist_si128(K1, RCON), \
+ &EK[EK_OFF], EK_OFF == 48)
AES_192_key_exp(0x01, 6);
AES_192_key_exp(0x02, 12);
@@ -499,22 +499,25 @@ void AES_192_NI::key_schedule(const byte key[], size_t)
AES_192_key_exp(0x40, 42);
AES_192_key_exp(0x80, 48);
+ #undef AES_192_key_exp
+
// Now generate decryption keys
const __m128i* EK_mm = (const __m128i*)&EK[0];
+
__m128i* DK_mm = (__m128i*)&DK[0];
- _mm_storeu_si128(DK_mm , EK_mm[12]);
- _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(EK_mm[11]));
- _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(EK_mm[10]));
- _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(EK_mm[9]));
- _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(EK_mm[8]));
- _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(EK_mm[7]));
- _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(EK_mm[6]));
- _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(EK_mm[5]));
- _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(EK_mm[4]));
- _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(EK_mm[3]));
- _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(EK_mm[2]));
- _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(EK_mm[1]));
- _mm_storeu_si128(DK_mm + 12, EK_mm[0]);
+ _mm_storeu_si128(DK_mm , _mm_loadu_si128(EK_mm + 12));
+ _mm_storeu_si128(DK_mm + 1, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 11)));
+ _mm_storeu_si128(DK_mm + 2, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 10)));
+ _mm_storeu_si128(DK_mm + 3, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 9)));
+ _mm_storeu_si128(DK_mm + 4, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 8)));
+ _mm_storeu_si128(DK_mm + 5, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 7)));
+ _mm_storeu_si128(DK_mm + 6, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 6)));
+ _mm_storeu_si128(DK_mm + 7, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 5)));
+ _mm_storeu_si128(DK_mm + 8, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 4)));
+ _mm_storeu_si128(DK_mm + 9, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 3)));
+ _mm_storeu_si128(DK_mm + 10, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 2)));
+ _mm_storeu_si128(DK_mm + 11, _mm_aesimc_si128(_mm_loadu_si128(EK_mm + 1)));
+ _mm_storeu_si128(DK_mm + 12, _mm_loadu_si128(EK_mm + 0));
}
/*
@@ -776,4 +779,9 @@ void AES_256_NI::clear()
zeroise(DK);
}
+#undef AES_ENC_4_ROUNDS
+#undef AES_ENC_4_LAST_ROUNDS
+#undef AES_DEC_4_ROUNDS
+#undef AES_DEC_4_LAST_ROUNDS
+
}
diff --git a/src/tls/info.txt b/src/tls/info.txt
index e61b2c0da..47de42598 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -15,6 +15,7 @@ tls_client.h
tls_exceptn.h
tls_handshake_msg.h
tls_magic.h
+tls_server_info.h
tls_policy.h
tls_server.h
tls_session.h
diff --git a/src/tls/msg_client_hello.cpp b/src/tls/msg_client_hello.cpp
index 2149ac5e5..6176ca6bf 100644
--- a/src/tls/msg_client_hello.cpp
+++ b/src/tls/msg_client_hello.cpp
@@ -74,13 +74,15 @@ Client_Hello::Client_Hello(Handshake_IO& io,
m_suites(ciphersuite_list(policy, m_version, (srp_identifier != ""))),
m_comp_methods(policy.compression())
{
- m_extensions.add(new Heartbeat_Support_Indicator(true));
m_extensions.add(new Renegotiation_Extension(reneg_info));
m_extensions.add(new SRP_Identifier(srp_identifier));
m_extensions.add(new Server_Name_Indicator(hostname));
m_extensions.add(new Session_Ticket());
m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves()));
+ if(policy.negotiate_heartbeat_support())
+ m_extensions.add(new Heartbeat_Support_Indicator(true));
+
if(m_version.supports_negotiable_signature_algorithms())
m_extensions.add(new Signature_Algorithms(policy.allowed_signature_hashes(),
policy.allowed_signature_methods()));
@@ -113,13 +115,15 @@ Client_Hello::Client_Hello(Handshake_IO& io,
if(!value_exists(m_comp_methods, session.compression_method()))
m_comp_methods.push_back(session.compression_method());
- m_extensions.add(new Heartbeat_Support_Indicator(true));
m_extensions.add(new Renegotiation_Extension(reneg_info));
m_extensions.add(new SRP_Identifier(session.srp_identifier()));
- m_extensions.add(new Server_Name_Indicator(session.sni_hostname()));
+ m_extensions.add(new Server_Name_Indicator(session.server_info().hostname()));
m_extensions.add(new Session_Ticket(session.session_ticket()));
m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves()));
+ if(policy.negotiate_heartbeat_support())
+ m_extensions.add(new Heartbeat_Support_Indicator(true));
+
if(session.fragment_size() != 0)
m_extensions.add(new Maximum_Fragment_Length(session.fragment_size()));
diff --git a/src/tls/msg_server_hello.cpp b/src/tls/msg_server_hello.cpp
index 6ca5e3b30..a775e0b4b 100644
--- a/src/tls/msg_server_hello.cpp
+++ b/src/tls/msg_server_hello.cpp
@@ -21,6 +21,7 @@ namespace TLS {
*/
Server_Hello::Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
+ const Policy& policy,
const std::vector<byte>& session_id,
Protocol_Version ver,
u16bit ciphersuite,
@@ -39,9 +40,13 @@ Server_Hello::Server_Hello(Handshake_IO& io,
m_ciphersuite(ciphersuite),
m_comp_method(compression)
{
- if(client_has_heartbeat)
+ if(client_has_heartbeat && policy.negotiate_heartbeat_support())
m_extensions.add(new Heartbeat_Support_Indicator(true));
+ /*
+ * Even a client that offered SSLv3 and sent the SCSV will get an
+ * extension back. This is probably the right thing to do.
+ */
if(client_has_secure_renegotiation)
m_extensions.add(new Renegotiation_Extension(reneg_info));
diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp
index d10366c60..87556ff75 100644
--- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp
+++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.cpp
@@ -142,16 +142,15 @@ bool Session_Manager_SQLite::load_from_session_id(const std::vector<byte>& sessi
return false;
}
-bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname,
- u16bit port,
- Session& session)
+bool Session_Manager_SQLite::load_from_server_info(const Server_Information& server,
+ Session& session)
{
sqlite3_statement stmt(m_db, "select session from tls_sessions"
" where hostname = ?1 and hostport = ?2"
" order by session_start desc");
- stmt.bind(1, hostname);
- stmt.bind(2, port);
+ stmt.bind(1, server.hostname());
+ stmt.bind(2, server.port());
while(stmt.step())
{
@@ -167,9 +166,6 @@ bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname,
}
}
- if(port != 0)
- return load_from_host_info(hostname, 0, session);
-
return false;
}
@@ -182,15 +178,15 @@ void Session_Manager_SQLite::remove_entry(const std::vector<byte>& session_id)
stmt.spin();
}
-void Session_Manager_SQLite::save(const Session& session, u16bit port)
+void Session_Manager_SQLite::save(const Session& session)
{
sqlite3_statement stmt(m_db, "insert or replace into tls_sessions"
" values(?1, ?2, ?3, ?4, ?5)");
stmt.bind(1, hex_encode(session.session_id()));
stmt.bind(2, session.start_time());
- stmt.bind(3, session.sni_hostname());
- stmt.bind(4, port);
+ stmt.bind(3, session.server_info().hostname());
+ stmt.bind(4, session.server_info().port());
stmt.bind(5, session.encrypt(m_session_key, m_rng));
stmt.spin();
diff --git a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h
index db74f54b7..7892ccd6a 100644
--- a/src/tls/sessions_sqlite/tls_session_manager_sqlite.h
+++ b/src/tls/sessions_sqlite/tls_session_manager_sqlite.h
@@ -50,12 +50,12 @@ class BOTAN_DLL Session_Manager_SQLite : public Session_Manager
bool load_from_session_id(const std::vector<byte>& session_id,
Session& session) override;
- bool load_from_host_info(const std::string& hostname, u16bit port,
- Session& session) override;
+ bool load_from_server_info(const Server_Information& info,
+ Session& session) override;
void remove_entry(const std::vector<byte>& session_id) override;
- void save(const Session& session_data, u16bit port) override;
+ void save(const Session& session_data) override;
std::chrono::seconds session_lifetime() const override
{ return m_session_lifetime; }
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index bb6b7a45f..0e1d84bed 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -55,20 +55,18 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
- const std::string& hostname,
- u16bit port,
+ const Server_Information& info,
+ const Protocol_Version offer_version,
std::function<std::string (std::vector<std::string>)> next_protocol) :
Channel(output_fn, proc_fn, handshake_fn, session_manager, rng),
m_policy(policy),
m_creds(creds),
- m_hostname(hostname),
- m_port(port)
+ m_info(info)
{
- const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname);
+ const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname());
- const Protocol_Version version = m_policy.pref_version();
- Handshake_State& state = create_handshake_state(version);
- send_client_hello(state, false, version, srp_identifier, next_protocol);
+ Handshake_State& state = create_handshake_state(offer_version);
+ send_client_hello(state, false, offer_version, srp_identifier, next_protocol);
}
Handshake_State* Client::new_handshake_state(Handshake_IO* io)
@@ -111,10 +109,10 @@ void Client::send_client_hello(Handshake_State& state_base,
const bool send_npn_request = static_cast<bool>(next_protocol);
- if(!force_full_renegotiation && m_hostname != "")
+ if(!force_full_renegotiation && !m_info.empty())
{
Session session_info;
- if(session_manager().load_from_host_info(m_hostname, m_port, session_info))
+ if(session_manager().load_from_server_info(m_info, session_info))
{
if(srp_identifier == "" || session_info.srp_identifier() == srp_identifier)
{
@@ -142,7 +140,7 @@ void Client::send_client_hello(Handshake_State& state_base,
rng(),
secure_renegotiation_data_for_client_hello(),
send_npn_request,
- m_hostname,
+ m_info.hostname(),
srp_identifier));
}
@@ -321,7 +319,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
try
{
- m_creds.verify_certificate_chain("tls-client", m_hostname, server_certs);
+ m_creds.verify_certificate_chain("tls-client", m_info.hostname(), server_certs);
}
catch(std::exception& e)
{
@@ -380,7 +378,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
std::vector<X509_Certificate> client_certs =
m_creds.cert_chain(types,
"tls-client",
- m_hostname);
+ m_info.hostname());
state.client_certs(
new Certificate(state.handshake_io(),
@@ -395,7 +393,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
m_policy,
m_creds,
state.server_public_key.get(),
- m_hostname,
+ m_info.hostname(),
rng())
);
@@ -407,7 +405,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
Private_Key* private_key =
m_creds.private_key_for(state.client_certs()->cert_chain()[0],
"tls-client",
- m_hostname);
+ m_info.hostname());
state.client_verify(
new Certificate_Verify(state.handshake_io(),
@@ -501,7 +499,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
state.server_hello()->fragment_size(),
get_peer_cert_chain(state),
session_ticket,
- m_hostname,
+ m_info,
""
);
@@ -510,7 +508,7 @@ void Client::process_handshake_msg(const Handshake_State* active_state,
if(!session_id.empty())
{
if(should_save)
- session_manager().save(session_info, m_port);
+ session_manager().save(session_info);
else
session_manager().remove_entry(session_info.session_id());
}
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index d7f17f878..b40896e94 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -39,11 +39,10 @@ class BOTAN_DLL Client : public Channel
*
* @param rng a random number generator
*
- * @param servername the server's DNS name, if known
+ * @param server_info is identifying information about the TLS server
*
- * @param port specifies the protocol port of the server (eg for
- * TCP/UDP). Only used if servername is also specified.
- * Use 0 if unknown.
+ * @param offer_version specifies which version we will offer
+ * to the TLS server.
*
* @param next_protocol allows the client to specify what the next
* protocol will be. For more information read
@@ -61,8 +60,8 @@ class BOTAN_DLL Client : public Channel
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng,
- const std::string& servername = "",
- u16bit port = 0,
+ const Server_Information& server_info = Server_Information(),
+ const Protocol_Version offer_version = Protocol_Version::latest_tls_version(),
std::function<std::string (std::vector<std::string>)> next_protocol =
std::function<std::string (std::vector<std::string>)>());
private:
@@ -88,8 +87,7 @@ class BOTAN_DLL Client : public Channel
const Policy& m_policy;
Credentials_Manager& m_creds;
- const std::string m_hostname;
- const u16bit m_port;
+ const Server_Information m_info;
};
}
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 70745ad9c..f1d4aa887 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -254,6 +254,7 @@ class Server_Hello : public Handshake_Message
Server_Hello(Handshake_IO& io,
Handshake_Hash& hash,
+ const Policy& policy,
const std::vector<byte>& session_id,
Protocol_Version ver,
u16bit ciphersuite,
diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp
index b26bd4225..c76fe30a5 100644
--- a/src/tls/tls_policy.cpp
+++ b/src/tls/tls_policy.cpp
@@ -136,11 +136,6 @@ bool Policy::acceptable_protocol_version(Protocol_Version version) const
version == Protocol_Version::TLS_V12);
}
-Protocol_Version Policy::pref_version() const
- {
- return Protocol_Version::TLS_V12;
- }
-
namespace {
class Ciphersuite_Preference_Ordering
diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h
index 8b73fea9d..cc02dd9b1 100644
--- a/src/tls/tls_policy.h
+++ b/src/tls/tls_policy.h
@@ -74,6 +74,11 @@ class BOTAN_DLL Policy
virtual std::string choose_curve(const std::vector<std::string>& curve_names) const;
/**
+ * Attempt to negotiate the use of the heartbeat extension
+ */
+ virtual bool negotiate_heartbeat_support() const { return false; }
+
+ /**
* Allow renegotiation even if the counterparty doesn't
* support the secure renegotiation extension.
*
@@ -119,11 +124,6 @@ class BOTAN_DLL Policy
*/
virtual bool acceptable_protocol_version(Protocol_Version version) const;
- /**
- * @return the version we would prefer to negotiate
- */
- virtual Protocol_Version pref_version() const;
-
virtual ~Policy() {}
};
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 1a29d317c..1189019bc 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -96,7 +96,7 @@ bool check_for_resume(Session& session_info,
// client sent a different SNI hostname
if(client_hello->sni_hostname() != "")
{
- if(client_hello->sni_hostname() != session_info.sni_hostname())
+ if(client_hello->sni_hostname() != session_info.server_info().hostname())
return false;
}
@@ -288,9 +288,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
state.client_hello(new Client_Hello(contents, type));
- if(state.client_hello()->sni_hostname() != "")
- m_hostname = state.client_hello()->sni_hostname();
-
Protocol_Version client_version = state.client_hello()->version();
Protocol_Version negotiated_version;
@@ -380,6 +377,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
new Server_Hello(
state.handshake_io(),
state.hash(),
+ m_policy,
state.client_hello()->session_id(),
Protocol_Version(session_info.version()),
session_info.ciphersuite_code(),
@@ -451,9 +449,11 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
{
std::map<std::string, std::vector<X509_Certificate> > cert_chains;
- cert_chains = get_server_certs(m_hostname, m_creds);
+ const std::string sni_hostname = state.client_hello()->sni_hostname();
+
+ cert_chains = get_server_certs(sni_hostname, m_creds);
- if(m_hostname != "" && cert_chains.empty())
+ if(sni_hostname != "" && cert_chains.empty())
{
cert_chains = get_server_certs("", m_creds);
@@ -472,6 +472,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
new Server_Hello(
state.handshake_io(),
state.hash(),
+ m_policy,
make_hello_random(rng()), // new session ID
state.version(),
choose_ciphersuite(m_policy,
@@ -517,7 +518,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
private_key = m_creds.private_key_for(
state.server_certs()->cert_chain()[0],
"tls-server",
- m_hostname);
+ sni_hostname);
if(!private_key)
throw Internal_Error("No private key located for associated server cert");
@@ -540,7 +541,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
}
std::vector<X509_Certificate> client_auth_CAs =
- m_creds.trusted_certificate_authorities("tls-server", m_hostname);
+ m_creds.trusted_certificate_authorities("tls-server", sni_hostname);
if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "")
{
@@ -663,7 +664,7 @@ void Server::process_handshake_msg(const Handshake_State* active_state,
state.server_hello()->fragment_size(),
get_peer_cert_chain(state),
std::vector<byte>(),
- m_hostname,
+ Server_Information(state.client_hello()->sni_hostname()),
state.srp_identifier()
);
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 761ff6028..25c5b6506 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -36,16 +36,10 @@ class BOTAN_DLL Server : public Channel
std::vector<std::string>());
/**
- * Return the server name indicator, if sent by the client
+ * Return the protocol notification set by the client (using the
+ * NPN extension) for this connection, if any
*/
- std::string server_name_indicator() const
- { return m_hostname; }
-
- /**
- * Return the protocol negotiated with NPN extension
- */
- std::string next_protocol() const
- { return m_next_protocol; }
+ std::string next_protocol() const { return m_next_protocol; }
private:
std::vector<X509_Certificate>
@@ -65,7 +59,6 @@ class BOTAN_DLL Server : public Channel
Credentials_Manager& m_creds;
std::vector<std::string> m_possible_protocols;
- std::string m_hostname;
std::string m_next_protocol;
};
diff --git a/src/tls/tls_server_info.h b/src/tls/tls_server_info.h
new file mode 100644
index 000000000..773296eaf
--- /dev/null
+++ b/src/tls/tls_server_info.h
@@ -0,0 +1,91 @@
+/*
+* TLS Server Information
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_SERVER_INFO_H__
+#define BOTAN_TLS_SERVER_INFO_H__
+
+#include <botan/types.h>
+#include <string>
+
+namespace Botan {
+
+namespace TLS {
+
+/**
+* Represents information known about a TLS server.
+*/
+class BOTAN_DLL Server_Information
+ {
+ public:
+ /**
+ * An empty server info - nothing known
+ */
+ Server_Information() : m_hostname(""), m_service(""), m_port(0) {}
+
+ /**
+ * @param hostname the host's DNS name, if known
+ * @param port specifies the protocol port of the server (eg for
+ * TCP/UDP). Zero represents unknown.
+ */
+ Server_Information(const std::string& hostname,
+ u16bit port = 0) :
+ m_hostname(hostname), m_service(""), m_port(port) {}
+
+ /**
+ * @param hostname the host's DNS name, if known
+ * @param service is a text string of the service type
+ * (eg "https", "tor", or "git")
+ * @param port specifies the protocol port of the server (eg for
+ * TCP/UDP). Zero represents unknown.
+ */
+ Server_Information(const std::string& hostname,
+ const std::string& service,
+ u16bit port = 0) :
+ m_hostname(hostname), m_service(service), m_port(port) {}
+
+ std::string hostname() const { return m_hostname; }
+
+ std::string service() const { return m_service; }
+
+ u16bit port() const { return m_port; }
+
+ bool empty() const { return m_hostname.empty(); }
+
+ private:
+ std::string m_hostname, m_service;
+ u16bit m_port;
+ };
+
+inline bool operator==(const Server_Information& a, const Server_Information& b)
+ {
+ return (a.hostname() == b.hostname()) &&
+ (a.service() == b.service()) &&
+ (a.port() == b.port());
+
+ }
+
+inline bool operator!=(const Server_Information& a, const Server_Information& b)
+ {
+ return !(a == b);
+ }
+
+inline bool operator<(const Server_Information& a, const Server_Information& b)
+ {
+ if(a.hostname() != b.hostname())
+ return (a.hostname() < b.hostname());
+ if(a.service() != b.service())
+ return (a.service() < b.service());
+ if(a.port() != b.port())
+ return (a.port() < b.port());
+ return false; // equal
+ }
+
+}
+
+}
+
+#endif
diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp
index ae57de0c2..85cb6d69e 100644
--- a/src/tls/tls_session.cpp
+++ b/src/tls/tls_session.cpp
@@ -27,7 +27,7 @@ Session::Session(const std::vector<byte>& session_identifier,
size_t fragment_size,
const std::vector<X509_Certificate>& certs,
const std::vector<byte>& ticket,
- const std::string& sni_hostname,
+ const Server_Information& server_info,
const std::string& srp_identifier) :
m_start_time(std::chrono::system_clock::now()),
m_identifier(session_identifier),
@@ -39,7 +39,7 @@ Session::Session(const std::vector<byte>& session_identifier,
m_connection_side(side),
m_fragment_size(fragment_size),
m_peer_certs(certs),
- m_sni_hostname(sni_hostname),
+ m_server_info(server_info),
m_srp_identifier(srp_identifier)
{
}
@@ -54,7 +54,11 @@ Session::Session(const std::string& pem)
Session::Session(const byte ber[], size_t ber_len)
{
byte side_code = 0;
- ASN1_String sni_hostname_str;
+
+ ASN1_String server_hostname;
+ ASN1_String server_service;
+ size_t server_port;
+
ASN1_String srp_identifier_str;
byte major_version = 0, minor_version = 0;
@@ -78,17 +82,23 @@ Session::Session(const byte ber[], size_t ber_len)
.decode_integer_type(m_fragment_size)
.decode(m_master_secret, OCTET_STRING)
.decode(peer_cert_bits, OCTET_STRING)
- .decode(sni_hostname_str)
+ .decode(server_hostname)
+ .decode(server_service)
+ .decode(server_port)
.decode(srp_identifier_str)
.end_cons()
.verify_end();
m_version = Protocol_Version(major_version, minor_version);
m_start_time = std::chrono::system_clock::from_time_t(start_time);
- m_sni_hostname = sni_hostname_str.value();
- m_srp_identifier = srp_identifier_str.value();
m_connection_side = static_cast<Connection_Side>(side_code);
+ m_server_info = Server_Information(server_hostname.value(),
+ server_service.value(),
+ server_port);
+
+ m_srp_identifier = srp_identifier_str.value();
+
if(!peer_cert_bits.empty())
{
DataSource_Memory certs(&peer_cert_bits[0], peer_cert_bits.size());
@@ -118,7 +128,9 @@ secure_vector<byte> Session::DER_encode() const
.encode(static_cast<size_t>(m_fragment_size))
.encode(m_master_secret, OCTET_STRING)
.encode(peer_cert_bits, OCTET_STRING)
- .encode(ASN1_String(m_sni_hostname, UTF8_STRING))
+ .encode(ASN1_String(m_server_info.hostname(), UTF8_STRING))
+ .encode(ASN1_String(m_server_info.service(), UTF8_STRING))
+ .encode(static_cast<size_t>(m_server_info.port()))
.encode(ASN1_String(m_srp_identifier, UTF8_STRING))
.end_cons()
.get_contents();
diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h
index 206a75081..65154dfce 100644
--- a/src/tls/tls_session.h
+++ b/src/tls/tls_session.h
@@ -12,6 +12,7 @@
#include <botan/tls_version.h>
#include <botan/tls_ciphersuite.h>
#include <botan/tls_magic.h>
+#include <botan/tls_server_info.h>
#include <botan/secmem.h>
#include <botan/symkey.h>
#include <chrono>
@@ -51,8 +52,8 @@ class BOTAN_DLL Session
size_t fragment_size,
const std::vector<X509_Certificate>& peer_certs,
const std::vector<byte>& session_ticket,
- const std::string& sni_hostname = "",
- const std::string& srp_identifier = "");
+ const Server_Information& server_info,
+ const std::string& srp_identifier);
/**
* Load a session from DER representation (created by DER_encode)
@@ -133,11 +134,6 @@ class BOTAN_DLL Session
Connection_Side side() const { return m_connection_side; }
/**
- * Get the SNI hostname (if sent by the client in the initial handshake)
- */
- std::string sni_hostname() const { return m_sni_hostname; }
-
- /**
* Get the SRP identity (if sent by the client in the initial handshake)
*/
std::string srp_identifier() const { return m_srp_identifier; }
@@ -180,6 +176,8 @@ class BOTAN_DLL Session
*/
const std::vector<byte>& session_ticket() const { return m_session_ticket; }
+ Server_Information server_info() const { return m_server_info; }
+
private:
enum { TLS_SESSION_PARAM_STRUCT_VERSION = 0x2994e301 };
@@ -197,7 +195,7 @@ class BOTAN_DLL Session
size_t m_fragment_size;
std::vector<X509_Certificate> m_peer_certs;
- std::string m_sni_hostname; // optional
+ Server_Information m_server_info; // optional
std::string m_srp_identifier; // optional
};
diff --git a/src/tls/tls_session_manager.cpp b/src/tls/tls_session_manager.cpp
index 673ee90ff..ca18231a0 100644
--- a/src/tls/tls_session_manager.cpp
+++ b/src/tls/tls_session_manager.cpp
@@ -61,27 +61,24 @@ bool Session_Manager_In_Memory::load_from_session_id(
return load_from_session_str(hex_encode(session_id), session);
}
-bool Session_Manager_In_Memory::load_from_host_info(
- const std::string& hostname, u16bit port, Session& session)
+bool Session_Manager_In_Memory::load_from_server_info(
+ const Server_Information& info, Session& session)
{
std::lock_guard<std::mutex> lock(m_mutex);
- auto i = m_host_sessions.find(hostname + ":" + std::to_string(port));
+ auto i = m_info_sessions.find(info);
- if(i == m_host_sessions.end())
- {
- if(port > 0)
- i = m_host_sessions.find(hostname + ":" + std::to_string(0));
-
- if(i == m_host_sessions.end())
- return false;
- }
+ if(i == m_info_sessions.end())
+ return false;
if(load_from_session_str(i->second, session))
return true;
- // was removed from sessions map, remove m_host_sessions entry
- m_host_sessions.erase(i);
+ /*
+ * It existed at one point but was removed from the sessions map,
+ * remove m_info_sessions entry as well
+ */
+ m_info_sessions.erase(i);
return false;
}
@@ -97,7 +94,7 @@ void Session_Manager_In_Memory::remove_entry(
m_sessions.erase(i);
}
-void Session_Manager_In_Memory::save(const Session& session, u16bit port)
+void Session_Manager_In_Memory::save(const Session& session)
{
std::lock_guard<std::mutex> lock(m_mutex);
@@ -115,10 +112,8 @@ void Session_Manager_In_Memory::save(const Session& session, u16bit port)
m_sessions[session_id_str] = session.encrypt(m_session_key, m_rng);
- const std::string hostname = session.sni_hostname();
-
- if(session.side() == CLIENT && hostname != "")
- m_host_sessions[hostname + ":" + std::to_string(port)] = session_id_str;
+ if(session.side() == CLIENT && !session.server_info().empty())
+ m_info_sessions[session.server_info()] = session_id_str;
}
}
diff --git a/src/tls/tls_session_manager.h b/src/tls/tls_session_manager.h
index 4efefb6ff..d7c805195 100644
--- a/src/tls/tls_session_manager.h
+++ b/src/tls/tls_session_manager.h
@@ -30,7 +30,7 @@ class BOTAN_DLL Session_Manager
{
public:
/**
- * Try to load a saved session (server side)
+ * Try to load a saved session (using session ID)
* @param session_id the session identifier we are trying to resume
* @param session will be set to the saved session data (if found),
or not modified if not found
@@ -40,15 +40,14 @@ class BOTAN_DLL Session_Manager
Session& session) = 0;
/**
- * Try to load a saved session (client side)
- * @param hostname of the host we are connecting to
- * @param port the port number if we know it, or 0 if unknown
+ * Try to load a saved session (using info about server)
+ * @param info the information about the server
* @param session will be set to the saved session data (if found),
or not modified if not found
* @return true if session was modified
*/
- virtual bool load_from_host_info(const std::string& hostname, u16bit port,
- Session& session) = 0;
+ virtual bool load_from_server_info(const Server_Information& info,
+ Session& session) = 0;
/**
* Remove this session id from the cache, if it exists
@@ -64,7 +63,7 @@ class BOTAN_DLL Session_Manager
* @param session to save
* @param port the protocol port (if known)
*/
- virtual void save(const Session& session, u16bit port = 0) = 0;
+ virtual void save(const Session& session) = 0;
/**
* Return the allowed lifetime of a session; beyond this time,
@@ -86,12 +85,12 @@ class BOTAN_DLL Session_Manager_Noop : public Session_Manager
bool load_from_session_id(const std::vector<byte>&, Session&) override
{ return false; }
- bool load_from_host_info(const std::string&, u16bit, Session&) override
+ bool load_from_server_info(const Server_Information&, Session&) override
{ return false; }
void remove_entry(const std::vector<byte>&) override {}
- void save(const Session&, u16bit) override {}
+ void save(const Session&) override {}
std::chrono::seconds session_lifetime() const override
{ return std::chrono::seconds(0); }
@@ -116,12 +115,12 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager
bool load_from_session_id(const std::vector<byte>& session_id,
Session& session) override;
- bool load_from_host_info(const std::string& hostname, u16bit port,
- Session& session) override;
+ bool load_from_server_info(const Server_Information& info,
+ Session& session) override;
void remove_entry(const std::vector<byte>& session_id) override;
- void save(const Session& session_data, u16bit port) override;
+ void save(const Session& session_data) override;
std::chrono::seconds session_lifetime() const override
{ return m_session_lifetime; }
@@ -140,7 +139,7 @@ class BOTAN_DLL Session_Manager_In_Memory : public Session_Manager
SymmetricKey m_session_key;
std::map<std::string, std::vector<byte>> m_sessions; // hex(session_id) -> session
- std::map<std::string, std::string> m_host_sessions;
+ std::map<Server_Information, std::string> m_info_sessions;
};
}
diff --git a/src/tls/tls_version.h b/src/tls/tls_version.h
index 651eebafc..39712db27 100644
--- a/src/tls/tls_version.h
+++ b/src/tls/tls_version.h
@@ -31,6 +31,16 @@ class BOTAN_DLL Protocol_Version
DTLS_V12 = 0xFEFD
};
+ static Protocol_Version latest_tls_version()
+ {
+ return Protocol_Version(TLS_V12);
+ }
+
+ static Protocol_Version latest_dtls_version()
+ {
+ return Protocol_Version(DTLS_V12);
+ }
+
Protocol_Version() : m_version(0) {}
/**
diff --git a/src/utils/assert.h b/src/utils/assert.h
index 88d514b43..d92b41111 100644
--- a/src/utils/assert.h
+++ b/src/utils/assert.h
@@ -35,10 +35,10 @@ void assertion_failure(const char* expr_str,
/**
* Assert that value1 == value2
*/
-#define BOTAN_ASSERT_EQUAL(value1, value2, assertion_made) \
+#define BOTAN_ASSERT_EQUAL(expr1, expr2, assertion_made) \
do { \
- if(value1 != value2) \
- Botan::assertion_failure(#value1 " == " #value2, \
+ if((expr1) != (expr2)) \
+ Botan::assertion_failure(#expr1 " == " #expr2, \
assertion_made, \
__func__, \
__FILE__, \