aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls')
-rw-r--r--src/tls/c_hello.cpp86
-rw-r--r--src/tls/c_kex.cpp7
-rw-r--r--src/tls/cert_req.cpp40
-rw-r--r--src/tls/cert_ver.cpp7
-rw-r--r--src/tls/finished.cpp7
-rw-r--r--src/tls/hello_verify.cpp61
-rw-r--r--src/tls/info.txt6
-rw-r--r--src/tls/next_protocol.cpp3
-rw-r--r--src/tls/rec_wri.cpp44
-rw-r--r--src/tls/s_hello.cpp26
-rw-r--r--src/tls/s_kex.cpp3
-rw-r--r--src/tls/session_ticket.cpp57
-rw-r--r--src/tls/sessions_sqlite/info.txt11
-rw-r--r--src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp343
-rw-r--r--src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h68
-rw-r--r--src/tls/tls_channel.cpp37
-rw-r--r--src/tls/tls_channel.h4
-rw-r--r--src/tls/tls_ciphersuite.cpp2
-rw-r--r--src/tls/tls_client.cpp87
-rw-r--r--src/tls/tls_extensions.cpp9
-rw-r--r--src/tls/tls_extensions.h35
-rw-r--r--src/tls/tls_handshake_reader.cpp66
-rw-r--r--src/tls/tls_handshake_reader.h58
-rw-r--r--src/tls/tls_handshake_state.cpp54
-rw-r--r--src/tls/tls_handshake_state.h16
-rw-r--r--src/tls/tls_magic.h35
-rw-r--r--src/tls/tls_messages.h72
-rw-r--r--src/tls/tls_policy.cpp10
-rw-r--r--src/tls/tls_policy.h6
-rw-r--r--src/tls/tls_reader.h9
-rw-r--r--src/tls/tls_record.h2
-rw-r--r--src/tls/tls_server.cpp153
-rw-r--r--src/tls/tls_session.cpp137
-rw-r--r--src/tls/tls_session.h40
-rw-r--r--src/tls/tls_session_key.cpp4
35 files changed, 1358 insertions, 247 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp
index 6743254a5..b08e1abe2 100644
--- a/src/tls/c_hello.cpp
+++ b/src/tls/c_hello.cpp
@@ -30,34 +30,11 @@ MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng)
}
/*
-* Encode and send a Handshake message
-*/
-void Handshake_Message::send(Record_Writer& writer, Handshake_Hash& hash) const
- {
- MemoryVector<byte> buf = serialize();
- MemoryVector<byte> send_buf(4);
-
- const size_t buf_size = buf.size();
-
- send_buf[0] = type();
-
- for(size_t i = 1; i != 4; ++i)
- send_buf[i] = get_byte<u32bit>(i, buf_size);
-
- send_buf += buf;
-
- hash.update(send_buf);
-
- writer.send(HANDSHAKE, &send_buf[0], send_buf.size());
- }
-
-/*
* Create a new Hello Request message
*/
Hello_Request::Hello_Request(Record_Writer& writer)
{
- Handshake_Hash dummy; // FIXME: *UGLY*
- send(writer, dummy);
+ writer.send(*this);
}
/*
@@ -66,7 +43,7 @@ Hello_Request::Hello_Request(Record_Writer& writer)
Hello_Request::Hello_Request(const MemoryRegion<byte>& buf)
{
if(buf.size())
- throw Decoding_Error("Hello_Request: Must be empty, and is not");
+ throw Decoding_Error("Bad Hello_Request, has non-zero size");
}
/*
@@ -97,49 +74,67 @@ Client_Hello::Client_Hello(Record_Writer& writer,
m_next_protocol(next_protocol),
m_fragment_size(0),
m_secure_renegotiation(true),
- m_renegotiation_info(reneg_info)
+ m_renegotiation_info(reneg_info),
+ m_supported_curves(policy.allowed_ecc_curves()),
+ m_supports_session_ticket(true)
{
std::vector<std::string> hashes = policy.allowed_hashes();
std::vector<std::string> sigs = policy.allowed_signature_methods();
- m_supported_curves = policy.allowed_ecc_curves();
-
for(size_t i = 0; i != hashes.size(); ++i)
for(size_t j = 0; j != sigs.size(); ++j)
m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/*
-* Create a new Client Hello message
+* Create a new Client Hello message (session resumption case)
*/
Client_Hello::Client_Hello(Record_Writer& writer,
Handshake_Hash& hash,
+ const Policy& policy,
RandomNumberGenerator& rng,
const Session& session,
bool next_protocol) :
m_version(session.version()),
m_session_id(session.session_id()),
m_random(make_hello_random(rng)),
+ m_suites(policy.ciphersuite_list(session.srp_identifier() != "")),
+ m_comp_methods(policy.compression()),
m_hostname(session.sni_hostname()),
m_srp_identifier(session.srp_identifier()),
m_next_protocol(next_protocol),
m_fragment_size(session.fragment_size()),
- m_secure_renegotiation(session.secure_renegotiation())
+ m_secure_renegotiation(session.secure_renegotiation()),
+ m_supported_curves(policy.allowed_ecc_curves()),
+ m_supports_session_ticket(true),
+ m_session_ticket(session.session_ticket())
{
- m_suites.push_back(session.ciphersuite_code());
- m_comp_methods.push_back(session.compression_method());
+ if(!value_exists(m_suites, session.ciphersuite_code()))
+ m_suites.push_back(session.ciphersuite_code());
- // set m_supported_algos + m_supported_curves here?
+ if(!value_exists(m_comp_methods, session.compression_method()))
+ m_comp_methods.push_back(session.compression_method());
- send(writer, hash);
+ std::vector<std::string> hashes = policy.allowed_hashes();
+ std::vector<std::string> sigs = policy.allowed_signature_methods();
+
+ for(size_t i = 0; i != hashes.size(); ++i)
+ for(size_t j = 0; j != sigs.size(); ++j)
+ m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
+
+ hash.update(writer.send(*this));
}
+/*
+* Read a counterparty client hello
+*/
Client_Hello::Client_Hello(const MemoryRegion<byte>& buf, Handshake_Type type)
{
m_next_protocol = false;
m_secure_renegotiation = false;
+ m_supports_session_ticket = false;
m_fragment_size = 0;
if(type == CLIENT_HELLO)
@@ -185,11 +180,14 @@ MemoryVector<byte> Client_Hello::serialize() const
if(m_next_protocol)
extensions.add(new Next_Protocol_Notification());
+
+ extensions.add(new Session_Ticket(m_session_ticket));
}
else
{
// renegotiation
extensions.add(new Renegotation_Extension(m_renegotiation_info));
+ extensions.add(new Session_Ticket(m_session_ticket));
}
buf += extensions.serialize();
@@ -328,6 +326,24 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf)
}
}
+ if(Maximum_Fragment_Length* frag = extensions.get<Maximum_Fragment_Length>())
+ {
+ m_fragment_size = frag->fragment_size();
+ }
+
+ if(Session_Ticket* ticket = extensions.get<Session_Ticket>())
+ {
+ m_supports_session_ticket = true;
+ m_session_ticket = ticket->contents();
+ }
+
+ if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>())
+ {
+ // checked by TLS_Client / TLS_Server as they know the handshake state
+ m_secure_renegotiation = true;
+ m_renegotiation_info = reneg->renegotiation_info();
+ }
+
if(value_exists(m_suites, static_cast<u16bit>(TLS_EMPTY_RENEGOTIATION_INFO_SCSV)))
{
/*
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index 6728b4877..0e7eba23a 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -8,6 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/tls_record.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/pubkey.h>
@@ -201,7 +202,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
MemoryVector<byte> encrypted_key = encryptor.encrypt(pre_master, rng);
- if(state->version == Protocol_Version::SSL_V3)
+ if(state->version() == Protocol_Version::SSL_V3)
key_material = encrypted_key; // no length field
else
append_tls_length_value(key_material, encrypted_key, 2);
@@ -212,7 +213,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
pub_key->algo_name());
}
- send(writer, state->hash);
+ state->hash.update(writer.send(*this));
}
/*
@@ -245,7 +246,7 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
try
{
- if(state->version == Protocol_Version::SSL_V3)
+ if(state->version() == Protocol_Version::SSL_V3)
{
pre_master = decryptor.decrypt(contents);
}
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index f400a36d2..df70cc43d 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -8,10 +8,10 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/tls_record.h>
#include <botan/der_enc.h>
#include <botan/ber_dec.h>
#include <botan/loadstor.h>
-#include <botan/secqueue.h>
namespace Botan {
@@ -74,7 +74,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j]));
}
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/**
@@ -173,10 +173,10 @@ MemoryVector<byte> Certificate_Req::serialize() const
*/
Certificate::Certificate(Record_Writer& writer,
Handshake_Hash& hash,
- const std::vector<X509_Certificate>& cert_list)
+ const std::vector<X509_Certificate>& cert_list) :
+ m_certs(cert_list)
{
- certs = cert_list;
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/**
@@ -189,27 +189,25 @@ Certificate::Certificate(const MemoryRegion<byte>& buf)
const size_t total_size = make_u32bit(0, buf[0], buf[1], buf[2]);
- SecureQueue queue;
- queue.write(&buf[3], buf.size() - 3);
-
- if(queue.size() != total_size)
+ if(total_size != buf.size() - 3)
throw Decoding_Error("Certificate: Message malformed");
- while(queue.size())
+ const byte* certs = &buf[3];
+
+ while(certs != buf.end())
{
- if(queue.size() < 3)
+ if(buf.end() - certs < 3)
throw Decoding_Error("Certificate: Message malformed");
- byte len[3];
- queue.read(len, 3);
-
- const size_t cert_size = make_u32bit(0, len[0], len[1], len[2]);
- const size_t original_size = queue.size();
+ const size_t cert_size = make_u32bit(0, certs[0], certs[1], certs[2]);
- X509_Certificate cert(queue);
- if(queue.size() + cert_size != original_size)
+ if(buf.end() - certs < (3 + cert_size))
throw Decoding_Error("Certificate: Message malformed");
- certs.push_back(cert);
+
+ DataSource_Memory cert_buf(&certs[3], cert_size);
+ m_certs.push_back(X509_Certificate(cert_buf));
+
+ certs += cert_size + 3;
}
}
@@ -220,9 +218,9 @@ MemoryVector<byte> Certificate::serialize() const
{
MemoryVector<byte> buf(3);
- for(size_t i = 0; i != certs.size(); ++i)
+ for(size_t i = 0; i != m_certs.size(); ++i)
{
- MemoryVector<byte> raw_cert = certs[i].BER_encode();
+ MemoryVector<byte> raw_cert = m_certs[i].BER_encode();
const size_t cert_size = raw_cert.size();
for(size_t i = 0; i != 3; ++i)
buf.push_back(get_byte<u32bit>(i+1, cert_size));
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index ffcba84d8..0a377b35f 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -8,6 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/tls_record.h>
#include <botan/internal/assert.h>
#include <memory>
@@ -30,7 +31,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
PK_Signer signer(*priv_key, format.first, format.second);
- if(state->version == Protocol_Version::SSL_V3)
+ if(state->version() == Protocol_Version::SSL_V3)
{
SecureVector<byte> md5_sha = state->hash.final_ssl3(
state->keys.master_secret());
@@ -45,7 +46,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer,
signature = signer.sign_message(state->hash.get_contents(), rng);
}
- send(writer, state->hash);
+ state->hash.update(writer.send(*this));
}
/*
@@ -99,7 +100,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert,
PK_Verifier verifier(*key, format.first, format.second);
- if(state->version == Protocol_Version::SSL_V3)
+ if(state->version() == Protocol_Version::SSL_V3)
{
SecureVector<byte> md5_sha = state->hash.final_ssl3(
state->keys.master_secret());
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index b4c9bdc3d..bb6e8d20e 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -6,6 +6,7 @@
*/
#include <botan/internal/tls_messages.h>
+#include <botan/tls_record.h>
#include <memory>
namespace Botan {
@@ -20,7 +21,7 @@ namespace {
MemoryVector<byte> finished_compute_verify(Handshake_State* state,
Connection_Side side)
{
- if(state->version == Protocol_Version::SSL_V3)
+ if(state->version() == Protocol_Version::SSL_V3)
{
const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 };
const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 };
@@ -54,7 +55,7 @@ MemoryVector<byte> finished_compute_verify(Handshake_State* state,
else
input += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL));
- input += state->hash.final(state->version, state->suite.mac_algo());
+ input += state->hash.final(state->version(), state->suite.mac_algo());
return prf->derive_key(12, state->keys.master_secret(), input);
}
@@ -70,7 +71,7 @@ Finished::Finished(Record_Writer& writer,
Connection_Side side)
{
verification_data = finished_compute_verify(state, side);
- send(writer, state->hash);
+ state->hash.update(writer.send(*this));
}
/*
diff --git a/src/tls/hello_verify.cpp b/src/tls/hello_verify.cpp
new file mode 100644
index 000000000..c7aae94a1
--- /dev/null
+++ b/src/tls/hello_verify.cpp
@@ -0,0 +1,61 @@
+/*
+* DTLS Hello Verify Request
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/internal/tls_messages.h>
+#include <botan/lookup.h>
+#include <memory>
+
+namespace Botan {
+
+namespace TLS {
+
+Hello_Verify_Request::Hello_Verify_Request(const MemoryRegion<byte>& buf)
+ {
+ if(buf.size() < 3)
+ throw Decoding_Error("Hello verify request too small");
+
+ if(buf[0] != 254 || (buf[1] != 255 && buf[1] != 253))
+ throw Decoding_Error("Unknown version from server in hello verify request");
+
+ m_cookie.resize(buf.size() - 2);
+ copy_mem(&m_cookie[0], &buf[2], buf.size() - 2);
+ }
+
+Hello_Verify_Request::Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits,
+ const std::string& client_identity,
+ const SymmetricKey& secret_key)
+ {
+ std::auto_ptr<MessageAuthenticationCode> hmac(get_mac("HMAC(SHA-256)"));
+ hmac->set_key(secret_key);
+
+ hmac->update_be(client_hello_bits.size());
+ hmac->update(client_hello_bits);
+ hmac->update_be(client_identity.size());
+ hmac->update(client_identity);
+
+ m_cookie = hmac->final();
+ }
+
+MemoryVector<byte> Hello_Verify_Request::serialize() const
+ {
+ /* DTLS 1.2 server implementations SHOULD use DTLS version 1.0
+ regardless of the version of TLS that is expected to be
+ negotiated (RFC 6347, section 4.2.1)
+ */
+
+ Protocol_Version format_version(Protocol_Version::TLS_V11);
+
+ MemoryVector<byte> bits;
+ bits.push_back(format_version.major_version());
+ bits.push_back(format_version.minor_version());
+ bits += m_cookie;
+ return bits;
+ }
+
+}
+
+}
diff --git a/src/tls/info.txt b/src/tls/info.txt
index 68ca026d5..b19eedb20 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -23,6 +23,7 @@ tls_version.h
<header:internal>
tls_extensions.h
tls_handshake_hash.h
+tls_handshake_reader.h
tls_handshake_state.h
tls_messages.h
tls_reader.h
@@ -36,15 +37,18 @@ c_kex.cpp
cert_req.cpp
cert_ver.cpp
finished.cpp
+hello_verify.cpp
next_protocol.cpp
rec_read.cpp
rec_wri.cpp
s_hello.cpp
s_kex.cpp
+session_ticket.cpp
tls_channel.cpp
tls_client.cpp
tls_extensions.cpp
tls_handshake_hash.cpp
+tls_handshake_reader.cpp
tls_handshake_state.cpp
tls_policy.cpp
tls_server.cpp
@@ -59,6 +63,7 @@ tls_version.cpp
aes
arc4
asn1
+credentials
des
dh
dsa
@@ -68,6 +73,7 @@ eme_pkcs
emsa3
filters
hmac
+kdf2
md5
prf_ssl3
prf_tls
diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp
index 97b072440..17b77fb6e 100644
--- a/src/tls/next_protocol.cpp
+++ b/src/tls/next_protocol.cpp
@@ -8,6 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_extensions.h>
#include <botan/internal/tls_reader.h>
+#include <botan/tls_record.h>
namespace Botan {
@@ -18,7 +19,7 @@ Next_Protocol::Next_Protocol(Record_Writer& writer,
const std::string& protocol) :
m_protocol(protocol)
{
- send(writer, hash);
+ hash.update(writer.send(*this));
}
Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf)
diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp
index 633b63720..3a54d7931 100644
--- a/src/tls/rec_wri.cpp
+++ b/src/tls/rec_wri.cpp
@@ -6,6 +6,7 @@
*/
#include <botan/tls_record.h>
+#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_session_key.h>
#include <botan/internal/tls_handshake_hash.h>
#include <botan/lookup.h>
@@ -22,9 +23,10 @@ namespace TLS {
* Record_Writer Constructor
*/
Record_Writer::Record_Writer(std::function<void (const byte[], size_t)> out) :
- m_output_fn(out), m_writebuf(TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE)
+ m_output_fn(out),
+ m_writebuf(TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE),
+ m_mac(0)
{
- m_mac = 0;
reset();
set_maximum_fragment_size(0);
}
@@ -144,6 +146,25 @@ void Record_Writer::activate(Connection_Side side,
throw Invalid_Argument("Record_Writer: Unknown hash " + mac_algo);
}
+MemoryVector<byte> Record_Writer::send(Handshake_Message& msg)
+ {
+ const MemoryVector<byte> buf = msg.serialize();
+ MemoryVector<byte> send_buf(4);
+
+ const size_t buf_size = buf.size();
+
+ send_buf[0] = msg.type();
+
+ for(size_t i = 1; i != 4; ++i)
+ send_buf[i] = get_byte<u32bit>(i, buf_size);
+
+ send_buf += buf;
+
+ send(HANDSHAKE, &send_buf[0], send_buf.size());
+
+ return send_buf;
+ }
+
/*
* Send one or more records to the other side
*/
@@ -190,16 +211,15 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length)
if(m_mac_size == 0) // initial unencrypted handshake records
{
- const byte header[TLS_HEADER_SIZE] = {
- type,
- m_version.major_version(),
- m_version.minor_version(),
- get_byte<u16bit>(0, length),
- get_byte<u16bit>(1, length)
- };
-
- m_output_fn(header, TLS_HEADER_SIZE);
- m_output_fn(input, length);
+ m_writebuf[0] = type;
+ m_writebuf[1] = m_version.major_version();
+ m_writebuf[2] = m_version.minor_version();
+ m_writebuf[3] = get_byte<u16bit>(0, length);
+ m_writebuf[4] = get_byte<u16bit>(1, length);
+
+ copy_mem(&m_writebuf[TLS_HEADER_SIZE], input, length);
+
+ m_output_fn(&m_writebuf[0], TLS_HEADER_SIZE + length);
return;
}
diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp
index 0ad78fc5b..7da9fdc57 100644
--- a/src/tls/s_hello.cpp
+++ b/src/tls/s_hello.cpp
@@ -25,6 +25,7 @@ Server_Hello::Server_Hello(Record_Writer& writer,
const Client_Hello& c_hello,
const std::vector<std::string>& available_cert_types,
const Policy& policy,
+ bool have_session_ticket_key,
bool client_has_secure_renegotiation,
const MemoryRegion<byte>& reneg_info,
bool client_has_npn,
@@ -37,7 +38,9 @@ Server_Hello::Server_Hello(Record_Writer& writer,
m_secure_renegotiation(client_has_secure_renegotiation),
m_renegotiation_info(reneg_info),
m_next_protocol(client_has_npn),
- m_next_protocols(next_protocols)
+ m_next_protocols(next_protocols),
+ m_supports_session_ticket(have_session_ticket_key &&
+ c_hello.supports_session_ticket())
{
suite = policy.choose_suite(
c_hello.ciphersuites(),
@@ -51,7 +54,7 @@ Server_Hello::Server_Hello(Record_Writer& writer,
comp_method = policy.choose_compression(c_hello.compression_methods());
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/*
@@ -66,6 +69,7 @@ Server_Hello::Server_Hello(Record_Writer& writer,
size_t max_fragment_size,
bool client_has_secure_renegotiation,
const MemoryRegion<byte>& reneg_info,
+ bool client_supports_session_tickets,
bool client_has_npn,
const std::vector<std::string>& next_protocols,
RandomNumberGenerator& rng) :
@@ -78,9 +82,10 @@ Server_Hello::Server_Hello(Record_Writer& writer,
m_secure_renegotiation(client_has_secure_renegotiation),
m_renegotiation_info(reneg_info),
m_next_protocol(client_has_npn),
- m_next_protocols(next_protocols)
+ m_next_protocols(next_protocols),
+ m_supports_session_ticket(client_supports_session_tickets)
{
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/*
@@ -89,6 +94,7 @@ Server_Hello::Server_Hello(Record_Writer& writer,
Server_Hello::Server_Hello(const MemoryRegion<byte>& buf)
{
m_secure_renegotiation = false;
+ m_supports_session_ticket = false;
m_next_protocol = false;
if(buf.size() < 38)
@@ -132,6 +138,13 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf)
m_next_protocols = npn->protocols();
m_next_protocol = true;
}
+
+ if(Session_Ticket* ticket = extensions.get<Session_Ticket>())
+ {
+ if(!ticket->contents().empty())
+ throw Decoding_Error("TLS server sent non-empty session ticket extension");
+ m_supports_session_ticket = true;
+ }
}
/*
@@ -163,6 +176,9 @@ MemoryVector<byte> Server_Hello::serialize() const
if(m_next_protocol)
extensions.add(new Next_Protocol_Notification(m_next_protocols));
+ if(m_supports_session_ticket)
+ extensions.add(new Session_Ticket());
+
buf += extensions.serialize();
return buf;
@@ -174,7 +190,7 @@ MemoryVector<byte> Server_Hello::serialize() const
Server_Hello_Done::Server_Hello_Done(Record_Writer& writer,
Handshake_Hash& hash)
{
- send(writer, hash);
+ hash.update(writer.send(*this));
}
/*
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 945c574b9..0890cac49 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -8,6 +8,7 @@
#include <botan/internal/tls_messages.h>
#include <botan/internal/tls_reader.h>
#include <botan/internal/tls_extensions.h>
+#include <botan/tls_record.h>
#include <botan/internal/assert.h>
#include <botan/credentials_manager.h>
#include <botan/loadstor.h>
@@ -105,7 +106,7 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer,
m_signature = signer.signature(rng);
}
- send(writer, state->hash);
+ state->hash.update(writer.send(*this));
}
/**
diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp
new file mode 100644
index 000000000..273996a16
--- /dev/null
+++ b/src/tls/session_ticket.cpp
@@ -0,0 +1,57 @@
+/*
+* Session Tickets
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/internal/tls_messages.h>
+#include <botan/internal/tls_extensions.h>
+#include <botan/internal/tls_reader.h>
+#include <botan/tls_record.h>
+#include <botan/loadstor.h>
+
+namespace Botan {
+
+namespace TLS {
+
+New_Session_Ticket::New_Session_Ticket(Record_Writer& writer,
+ Handshake_Hash& hash,
+ const MemoryRegion<byte>& ticket,
+ u32bit lifetime) :
+ m_ticket_lifetime_hint(lifetime),
+ m_ticket(ticket)
+ {
+ hash.update(writer.send(*this));
+ }
+
+New_Session_Ticket::New_Session_Ticket(Record_Writer& writer,
+ Handshake_Hash& hash) :
+ m_ticket_lifetime_hint(0)
+ {
+ hash.update(writer.send(*this));
+ }
+
+New_Session_Ticket::New_Session_Ticket(const MemoryRegion<byte>& buf) :
+ m_ticket_lifetime_hint(0)
+ {
+ if(buf.size() < 6)
+ throw Decoding_Error("Session ticket message too short to be valid");
+
+ TLS_Data_Reader reader(buf);
+
+ m_ticket_lifetime_hint = reader.get_u32bit();
+ m_ticket = reader.get_range<byte>(2, 0, 65535);
+ }
+
+MemoryVector<byte> New_Session_Ticket::serialize() const
+ {
+ MemoryVector<byte> buf(4);
+ store_be(m_ticket_lifetime_hint, &buf[0]);
+ append_tls_length_value(buf, m_ticket, 2);
+ return buf;
+ }
+
+}
+
+}
diff --git a/src/tls/sessions_sqlite/info.txt b/src/tls/sessions_sqlite/info.txt
new file mode 100644
index 000000000..c5fc35952
--- /dev/null
+++ b/src/tls/sessions_sqlite/info.txt
@@ -0,0 +1,11 @@
+define TLS_SQLITE_SESSION_MANAGER
+
+load_on request
+
+<libs>
+all -> sqlite3
+</libs>
+
+<requires>
+pbkdf2
+</requires>
diff --git a/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp
new file mode 100644
index 000000000..4d78a5365
--- /dev/null
+++ b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp
@@ -0,0 +1,343 @@
+/*
+* SQLite TLS Session Manager
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/tls_sqlite_sess_mgr.h>
+#include <botan/internal/assert.h>
+#include <botan/lookup.h>
+#include <botan/hex.h>
+#include <botan/time.h>
+#include <botan/loadstor.h>
+#include <memory>
+
+#include <sqlite3.h>
+
+namespace Botan {
+
+namespace TLS {
+
+namespace {
+
+class sqlite3_statement
+ {
+ public:
+ sqlite3_statement(sqlite3* db, const std::string& base_sql)
+ {
+ int rc = sqlite3_prepare_v2(db, base_sql.c_str(), -1, &m_stmt, 0);
+
+ if(rc != SQLITE_OK)
+ throw std::runtime_error("sqlite3_prepare failed " + base_sql + ", code " + to_string(rc));
+ }
+
+ void bind(int column, const std::string& val)
+ {
+ int rc = sqlite3_bind_text(m_stmt, column, val.c_str(), -1, SQLITE_TRANSIENT);
+ if(rc != SQLITE_OK)
+ throw std::runtime_error("sqlite3_bind_text failed, code " + to_string(rc));
+ }
+
+ void bind(int column, int val)
+ {
+ int rc = sqlite3_bind_int(m_stmt, column, val);
+ if(rc != SQLITE_OK)
+ throw std::runtime_error("sqlite3_bind_int failed, code " + to_string(rc));
+ }
+
+ void bind(int column, const MemoryRegion<byte>& val)
+ {
+ int rc = sqlite3_bind_blob(m_stmt, column, &val[0], val.size(), SQLITE_TRANSIENT);
+ if(rc != SQLITE_OK)
+ throw std::runtime_error("sqlite3_bind_text failed, code " + to_string(rc));
+ }
+
+ std::pair<const byte*, size_t> get_blob(int column)
+ {
+ BOTAN_ASSERT(sqlite3_column_type(m_stmt, 0) == SQLITE_BLOB,
+ "Return value is a blob");
+
+ const void* session_blob = sqlite3_column_blob(m_stmt, column);
+ const int session_blob_size = sqlite3_column_bytes(m_stmt, column);
+
+ BOTAN_ASSERT(session_blob_size >= 0, "Blob size is non-negative");
+
+ return std::make_pair(static_cast<const byte*>(session_blob),
+ static_cast<size_t>(session_blob_size));
+ }
+
+ size_t get_size_t(int column)
+ {
+ BOTAN_ASSERT(sqlite3_column_type(m_stmt, column) == SQLITE_INTEGER,
+ "Return count is an integer");
+
+ const int sessions_int = sqlite3_column_int(m_stmt, column);
+
+ BOTAN_ASSERT(sessions_int >= 0, "Expected size_t is non-negative");
+
+ return static_cast<size_t>(sessions_int);
+ }
+
+ void spin()
+ {
+ while(sqlite3_step(m_stmt) == SQLITE_ROW)
+ {}
+ }
+
+ int step()
+ {
+ return sqlite3_step(m_stmt);
+ }
+
+ sqlite3_stmt* stmt() { return m_stmt; }
+
+ ~sqlite3_statement() { sqlite3_finalize(m_stmt); }
+ private:
+ sqlite3_stmt* m_stmt;
+ };
+
+size_t row_count(sqlite3* db, const std::string& table_name)
+ {
+ sqlite3_statement stmt(db, "select count(*) from " + table_name);
+
+ if(stmt.step() == SQLITE_ROW)
+ return stmt.get_size_t(0);
+ else
+ throw std::runtime_error("Querying size of table " + table_name + " failed");
+ }
+
+void create_table(sqlite3* db, const char* table_schema)
+ {
+ char* errmsg = 0;
+ int rc = sqlite3_exec(db, table_schema, 0, 0, &errmsg);
+
+ if(rc != SQLITE_OK)
+ {
+ const std::string err_msg = errmsg;
+ sqlite3_free(errmsg);
+ sqlite3_close(db);
+ throw std::runtime_error("sqlite3_exec for table failed - " + err_msg);
+ }
+ }
+
+
+SymmetricKey derive_key(const std::string& passphrase,
+ const byte salt[],
+ size_t salt_len,
+ size_t iterations,
+ size_t& check_val)
+ {
+ std::auto_ptr<PBKDF> pbkdf(get_pbkdf("PBKDF2(SHA-512)"));
+
+ SecureVector<byte> x = pbkdf->derive_key(32 + 3,
+ passphrase,
+ salt, salt_len,
+ iterations).bits_of();
+
+ check_val = make_u32bit(0, x[0], x[1], x[2]);
+ return SymmetricKey(&x[3], x.size() - 3);
+ }
+
+}
+
+Session_Manager_SQLite::Session_Manager_SQLite(const std::string& passphrase,
+ RandomNumberGenerator& rng,
+ const std::string& db_filename,
+ size_t max_sessions,
+ size_t session_lifetime) :
+ m_rng(rng),
+ m_max_sessions(max_sessions),
+ m_session_lifetime(session_lifetime)
+ {
+ int rc = sqlite3_open(db_filename.c_str(), &m_db);
+
+ if(rc)
+ {
+ const std::string err_msg = sqlite3_errmsg(m_db);
+ sqlite3_close(m_db);
+ throw std::runtime_error("sqlite3_open failed - " + err_msg);
+ }
+
+ create_table(m_db,
+ "create table if not exists tls_sessions "
+ "("
+ "session_id TEXT PRIMARY KEY, "
+ "session_start INTEGER, "
+ "hostname TEXT, "
+ "hostport INTEGER, "
+ "session BLOB"
+ ")");
+
+ create_table(m_db,
+ "create table if not exists tls_sessions_metadata "
+ "("
+ "passphrase_salt BLOB, "
+ "passphrase_iterations INTEGER, "
+ "passphrase_check INTEGER "
+ ")");
+
+ const size_t salts = row_count(m_db, "tls_sessions_metadata");
+
+ if(salts == 1)
+ {
+ // existing db
+ sqlite3_statement stmt(m_db, "select * from tls_sessions_metadata");
+
+ int rc = stmt.step();
+ if(rc == SQLITE_ROW)
+ {
+ std::pair<const byte*, size_t> salt = stmt.get_blob(0);
+ const size_t iterations = stmt.get_size_t(1);
+ const size_t check_val_db = stmt.get_size_t(2);
+
+ size_t check_val_created;
+ m_session_key = derive_key(passphrase,
+ salt.first,
+ salt.second,
+ iterations,
+ check_val_created);
+
+ if(check_val_created != check_val_db)
+ throw std::runtime_error("Session database password not valid");
+ }
+ }
+ else
+ {
+ // maybe just zap the salts + sessions tables in this case?
+ if(salts != 0)
+ throw std::runtime_error("Seemingly corrupted database, multiple salts found");
+
+ // new database case
+
+ MemoryVector<byte> salt = rng.random_vec(16);
+ const size_t iterations = 64 * 1024;
+ size_t check_val = 0;
+
+ m_session_key = derive_key(passphrase, &salt[0], salt.size(),
+ iterations, check_val);
+
+ sqlite3_statement stmt(m_db, "insert into tls_sessions_metadata"
+ " values(?1, ?2, ?3)");
+
+ stmt.bind(1, salt);
+ stmt.bind(2, iterations);
+ stmt.bind(3, check_val);
+
+ stmt.spin();
+ }
+ }
+
+Session_Manager_SQLite::~Session_Manager_SQLite()
+ {
+ sqlite3_close(m_db);
+ }
+
+bool Session_Manager_SQLite::load_from_session_id(const MemoryRegion<byte>& session_id,
+ Session& session)
+ {
+ sqlite3_statement stmt(m_db, "select session from tls_sessions where session_id = ?1");
+
+ stmt.bind(1, hex_encode(session_id));
+
+ int rc = stmt.step();
+
+ while(rc == SQLITE_ROW)
+ {
+ std::pair<const byte*, size_t> blob = stmt.get_blob(0);
+
+ try
+ {
+ session = Session::decrypt(blob.first, blob.second, m_session_key);
+ return true;
+ }
+ catch(...)
+ {
+ }
+
+ rc = stmt.step();
+ }
+
+ return false;
+ }
+
+bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname,
+ u16bit port,
+ Session& session)
+ {
+ sqlite3_statement stmt(m_db, "select session from tls_sessions"
+ " where hostname = ?1 and hostport = ?2"
+ " order by session_start desc");
+
+ stmt.bind(1, hostname);
+ stmt.bind(2, port);
+
+ int rc = stmt.step();
+
+ while(rc == SQLITE_ROW)
+ {
+ std::pair<const byte*, size_t> blob = stmt.get_blob(0);
+
+ try
+ {
+ session = Session::decrypt(blob.first, blob.second, m_session_key);
+ return true;
+ }
+ catch(...)
+ {
+ }
+
+ rc = stmt.step();
+ }
+
+ return false;
+ }
+
+void Session_Manager_SQLite::remove_entry(const MemoryRegion<byte>& session_id)
+ {
+ sqlite3_statement stmt(m_db, "delete from tls_sessions where session_id = ?1");
+
+ stmt.bind(1, hex_encode(session_id));
+
+ stmt.spin();
+ }
+
+void Session_Manager_SQLite::save(const Session& session)
+ {
+ sqlite3_statement stmt(m_db, "insert or replace into tls_sessions"
+ " values(?1, ?2, ?3, ?4, ?5)");
+
+ stmt.bind(1, hex_encode(session.session_id()));
+ stmt.bind(2, session.start_time());
+ stmt.bind(3, session.sni_hostname());
+ stmt.bind(4, 0);
+ stmt.bind(5, session.encrypt(m_session_key, m_rng));
+
+ stmt.spin();
+
+ prune_session_cache();
+ }
+
+void Session_Manager_SQLite::prune_session_cache()
+ {
+ sqlite3_statement remove_expired(m_db, "delete from tls_sessions where session_start <= ?1");
+
+ remove_expired.bind(1, system_time() - m_session_lifetime);
+
+ remove_expired.spin();
+
+ const size_t sessions = row_count(m_db, "tls_sessions");
+
+ if(sessions > m_max_sessions)
+ {
+ sqlite3_statement remove_some(m_db, "delete from tls_sessions where session_id in "
+ "(select session_id from tls_sessions limit ?1)");
+
+ remove_some.bind(1, sessions - m_max_sessions);
+ remove_some.spin();
+ }
+ }
+
+}
+
+}
diff --git a/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h
new file mode 100644
index 000000000..424db24e5
--- /dev/null
+++ b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h
@@ -0,0 +1,68 @@
+/*
+* SQLite TLS Session Manager
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef TLS_SQLITE_SESSION_MANAGER_H__
+#define TLS_SQLITE_SESSION_MANAGER_H__
+
+#include <botan/tls_session_manager.h>
+#include <botan/rng.h>
+
+class sqlite3;
+
+namespace Botan {
+
+namespace TLS {
+
+/**
+*/
+class BOTAN_DLL Session_Manager_SQLite : public Session_Manager
+ {
+ public:
+ /**
+ * @param passphrase used to encrypt the session data
+ * @param db_filename filename of the SQLite database file.
+ The table names tls_sessions and tls_sessions_metadata
+ will be used
+ * @param max_sessions a hint on the maximum number of sessions
+ * to keep in memory at any one time. (If zero, don't cap)
+ * @param session_lifetime sessions are expired after this many
+ * seconds have elapsed from initial handshake.
+ */
+ Session_Manager_SQLite(const std::string& passphrase,
+ RandomNumberGenerator& rng,
+ const std::string& db_filename,
+ size_t max_sessions = 1000,
+ size_t session_lifetime = 7200);
+
+ ~Session_Manager_SQLite();
+
+ bool load_from_session_id(const MemoryRegion<byte>& session_id,
+ Session& session);
+
+ bool load_from_host_info(const std::string& hostname, u16bit port,
+ Session& session);
+
+ void remove_entry(const MemoryRegion<byte>& session_id);
+
+ void save(const Session& session_data);
+ private:
+ Session_Manager_SQLite(const Session_Manager_SQLite&);
+ Session_Manager_SQLite& operator=(const Session_Manager_SQLite&);
+
+ void prune_session_cache();
+
+ SymmetricKey m_session_key;
+ RandomNumberGenerator& m_rng;
+ size_t m_max_sessions, m_session_lifetime;
+ class sqlite3* m_db;
+ };
+
+}
+
+}
+
+#endif
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index fa240cc23..736b37654 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -142,49 +142,38 @@ void Channel::read_handshake(byte rec_type,
if(rec_type == HANDSHAKE)
{
if(!state)
- state = new Handshake_State;
- state->queue.write(&rec_buf[0], rec_buf.size());
+ state = new Handshake_State(new Stream_Handshake_Reader);
+ state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size());
}
+ BOTAN_ASSERT(state, "Handshake message recieved without state in place");
+
while(true)
{
Handshake_Type type = HANDSHAKE_NONE;
- MemoryVector<byte> contents;
if(rec_type == HANDSHAKE)
{
- if(state->queue.size() >= 4)
+ if(state->handshake_reader()->have_full_record())
{
- byte head[4] = { 0 };
- state->queue.peek(head, 4);
-
- const size_t length = make_u32bit(0, head[1], head[2], head[3]);
-
- if(state->queue.size() >= length + 4)
- {
- type = static_cast<Handshake_Type>(head[0]);
- contents.resize(length);
- state->queue.read(head, 4);
- state->queue.read(&contents[0], contents.size());
- }
+ std::pair<Handshake_Type, MemoryVector<byte> > msg =
+ state->handshake_reader()->get_next_record();
+ process_handshake_msg(msg.first, msg.second);
}
+ else
+ break;
}
else if(rec_type == CHANGE_CIPHER_SPEC)
{
- if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1)
- type = HANDSHAKE_CCS;
+ if(state->handshake_reader()->empty() && rec_buf.size() == 1 && rec_buf[0] == 1)
+ process_handshake_msg(HANDSHAKE_CCS, MemoryVector<byte>());
else
throw Decoding_Error("Malformed ChangeCipherSpec message");
}
else
throw Decoding_Error("Unknown message type in handshake processing");
- if(type == HANDSHAKE_NONE)
- break;
-
- process_handshake_msg(type, contents);
-
- if(type == HANDSHAKE_CCS || !state)
+ if(type == HANDSHAKE_CCS || !state || !state->handshake_reader()->have_full_record())
break;
}
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index c85b32ba0..9ab89e5f7 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -63,8 +63,8 @@ class BOTAN_DLL Channel
std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; }
Channel(std::function<void (const byte[], size_t)> socket_output_fn,
- std::function<void (const byte[], size_t, Alert)> proc_fn,
- std::function<bool (const Session&)> handshake_complete);
+ std::function<void (const byte[], size_t, Alert)> proc_fn,
+ std::function<bool (const Session&)> handshake_complete);
virtual ~Channel();
protected:
diff --git a/src/tls/tls_ciphersuite.cpp b/src/tls/tls_ciphersuite.cpp
index 22815a048..89daaf679 100644
--- a/src/tls/tls_ciphersuite.cpp
+++ b/src/tls/tls_ciphersuite.cpp
@@ -309,7 +309,7 @@ std::string Ciphersuite::to_string() const
{
if(cipher_algo() == "3DES")
out << "3DES_EDE";
- if(cipher_algo() == "Camellia")
+ else if(cipher_algo() == "Camellia")
out << "CAMELLIA_" << std::to_string(8*cipher_keylen());
else
out << replace_char(cipher_algo(), '-', '_');
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 2dd30819f..77e87e9bc 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -19,14 +19,14 @@ namespace TLS {
* TLS Client Constructor
*/
Client::Client(std::function<void (const byte[], size_t)> output_fn,
- std::function<void (const byte[], size_t, Alert)> proc_fn,
- std::function<bool (const Session&)> handshake_fn,
- Session_Manager& session_manager,
- Credentials_Manager& creds,
- const Policy& policy,
- RandomNumberGenerator& rng,
- const std::string& hostname,
- std::function<std::string (std::vector<std::string>)> next_protocol) :
+ std::function<void (const byte[], size_t, Alert)> proc_fn,
+ std::function<bool (const Session&)> handshake_fn,
+ Session_Manager& session_manager,
+ Credentials_Manager& creds,
+ const Policy& policy,
+ RandomNumberGenerator& rng,
+ const std::string& hostname,
+ std::function<std::string (std::vector<std::string>)> next_protocol) :
Channel(output_fn, proc_fn, handshake_fn),
policy(policy),
rng(rng),
@@ -35,7 +35,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
{
writer.set_version(Protocol_Version::SSL_V3);
- state = new Handshake_State;
+ state = new Handshake_State(new Stream_Handshake_Reader);
state->set_expected_next(SERVER_HELLO);
state->client_npn_cb = next_protocol;
@@ -54,6 +54,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
state->client_hello = new Client_Hello(
writer,
state->hash,
+ policy,
rng,
session_info,
send_npn_request);
@@ -87,7 +88,7 @@ void Client::renegotiate()
if(state)
return; // currently in handshake
- state = new Handshake_State;
+ state = new Handshake_State(new Stream_Handshake_Reader);
state->set_expected_next(SERVER_HELLO);
state->client_hello = new Client_Hello(writer, state->hash, policy, rng,
@@ -112,7 +113,7 @@ void Client::alert_notify(const Alert& alert)
* Process a handshake message
*/
void Client::process_handshake_msg(Handshake_Type type,
- const MemoryRegion<byte>& contents)
+ const MemoryRegion<byte>& contents)
{
if(state == 0)
throw Unexpected_Message("Unexpected handshake message from server");
@@ -135,12 +136,12 @@ void Client::process_handshake_msg(Handshake_Type type,
return;
}
- state->set_expected_next(SERVER_HELLO);
state->client_hello = new Client_Hello(writer, state->hash, policy, rng,
secure_renegotiation.for_client_hello());
-
secure_renegotiation.update(state->client_hello);
+ state->set_expected_next(SERVER_HELLO);
+
return;
}
@@ -173,17 +174,27 @@ void Client::process_handshake_msg(Handshake_Type type,
"Server sent next protocol but we didn't request it");
}
- state->version = state->server_hello->version();
+ if(state->server_hello->supports_session_ticket())
+ {
+ if(!state->client_hello->supports_session_ticket())
+ throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
+ "Server sent session ticket extension but we did not");
+ }
+
+ state->set_version(state->server_hello->version());
- writer.set_version(state->version);
- reader.set_version(state->version);
+ writer.set_version(state->version());
+ reader.set_version(state->version());
secure_renegotiation.update(state->server_hello);
state->suite = Ciphersuite::by_id(state->server_hello->ciphersuite());
- if(!state->server_hello->session_id().empty() &&
- (state->server_hello->session_id() == state->client_hello->session_id()))
+ const bool server_returned_same_session_id =
+ !state->server_hello->session_id().empty() &&
+ (state->server_hello->session_id() == state->client_hello->session_id());
+
+ if(server_returned_same_session_id)
{
// successful resumption
@@ -199,19 +210,22 @@ void Client::process_handshake_msg(Handshake_Type type,
state->resume_master_secret,
true);
- state->set_expected_next(HANDSHAKE_CCS);
+ if(state->server_hello->supports_session_ticket())
+ state->set_expected_next(NEW_SESSION_TICKET);
+ else
+ state->set_expected_next(HANDSHAKE_CCS);
}
else
{
// new session
- if(state->version > state->client_hello->version())
+ if(state->version() > state->client_hello->version())
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Client: Server replied with bad version");
}
- if(state->version < policy.min_version())
+ if(state->version() < policy.min_version())
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Client: Server is too old for specified policy");
@@ -288,7 +302,7 @@ void Client::process_handshake_msg(Handshake_Type type,
state->server_kex = new Server_Key_Exchange(contents,
state->suite.kex_algo(),
state->suite.sig_algo(),
- state->version);
+ state->version());
if(state->suite.sig_algo() != "")
{
@@ -302,12 +316,10 @@ void Client::process_handshake_msg(Handshake_Type type,
else if(type == CERTIFICATE_REQUEST)
{
state->set_expected_next(SERVER_HELLO_DONE);
- state->cert_req = new Certificate_Req(contents, state->version);
+ state->cert_req = new Certificate_Req(contents, state->version());
}
else if(type == SERVER_HELLO_DONE)
{
- state->set_expected_next(HANDSHAKE_CCS);
-
state->server_hello_done = new Server_Hello_Done(contents);
if(state->received_handshake_msg(CERTIFICATE_REQUEST))
@@ -364,6 +376,17 @@ void Client::process_handshake_msg(Handshake_Type type,
}
state->client_finished = new Finished(writer, state, CLIENT);
+
+ if(state->server_hello->supports_session_ticket())
+ state->set_expected_next(NEW_SESSION_TICKET);
+ else
+ state->set_expected_next(HANDSHAKE_CCS);
+ }
+ else if(type == NEW_SESSION_TICKET)
+ {
+ state->new_session_ticket = new New_Session_Ticket(contents);
+
+ state->set_expected_next(HANDSHAKE_CCS);
}
else if(type == HANDSHAKE_CCS)
{
@@ -394,8 +417,17 @@ void Client::process_handshake_msg(Handshake_Type type,
state->client_finished = new Finished(writer, state, CLIENT);
}
+ secure_renegotiation.update(state->client_finished, state->server_finished);
+
+ MemoryVector<byte> session_id = state->server_hello->session_id();
+
+ const MemoryRegion<byte>& session_ticket = state->session_ticket();
+
+ if(session_id.empty() && !session_ticket.empty())
+ session_id = make_hello_random(rng);
+
Session session_info(
- state->server_hello->session_id(),
+ session_id,
state->keys.master_secret(),
state->server_hello->version(),
state->server_hello->ciphersuite(),
@@ -404,6 +436,7 @@ void Client::process_handshake_msg(Handshake_Type type,
secure_renegotiation.supported(),
state->server_hello->fragment_size(),
peer_certs,
+ session_ticket,
state->client_hello->sni_hostname(),
""
);
@@ -413,8 +446,6 @@ void Client::process_handshake_msg(Handshake_Type type,
else
session_manager.remove_entry(session_info.session_id());
- secure_renegotiation.update(state->client_finished, state->server_finished);
-
delete state;
state = 0;
handshake_completed = true;
diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp
index 7162dcf40..c0de24bfe 100644
--- a/src/tls/tls_extensions.cpp
+++ b/src/tls/tls_extensions.cpp
@@ -42,6 +42,9 @@ Extension* make_extension(TLS_Data_Reader& reader,
case TLSEXT_NEXT_PROTOCOL:
return new Next_Protocol_Notification(reader, size);
+ case TLSEXT_SESSION_TICKET:
+ return new Session_Ticket(reader, size);
+
default:
return 0; // not known
}
@@ -501,6 +504,12 @@ Signature_Algorithms::Signature_Algorithms(TLS_Data_Reader& reader,
}
}
+Session_Ticket::Session_Ticket(TLS_Data_Reader& reader,
+ u16bit extension_size)
+ {
+ m_ticket = reader.get_elem<byte, MemoryVector<byte> >(extension_size);
+ }
+
}
}
diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h
index 180216b8b..6a97d2560 100644
--- a/src/tls/tls_extensions.h
+++ b/src/tls/tls_extensions.h
@@ -1,6 +1,6 @@
/*
* TLS Extensions
-* (C) 2011 Jack Lloyd
+* (C) 2011-2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -210,6 +210,39 @@ class Next_Protocol_Notification : public Extension
std::vector<std::string> m_protocols;
};
+class Session_Ticket : public Extension
+ {
+ public:
+ static Handshake_Extension_Type static_type()
+ { return TLSEXT_SESSION_TICKET; }
+
+ Handshake_Extension_Type type() const { return static_type(); }
+
+ const MemoryVector<byte>& contents() const { return m_ticket; }
+
+ /**
+ * Create empty extension, used by both client and server
+ */
+ Session_Ticket() {}
+
+ /**
+ * Extension with ticket, used by client
+ */
+ Session_Ticket(const MemoryRegion<byte>& session_ticket) :
+ m_ticket(session_ticket) {}
+
+ /**
+ * Deserialize a session ticket
+ */
+ Session_Ticket(TLS_Data_Reader& reader, u16bit extension_size);
+
+ MemoryVector<byte> serialize() const { return m_ticket; }
+
+ bool empty() const { return false; }
+ private:
+ MemoryVector<byte> m_ticket;
+ };
+
/**
* Supported Elliptic Curves Extension (RFC 4492)
*/
diff --git a/src/tls/tls_handshake_reader.cpp b/src/tls/tls_handshake_reader.cpp
new file mode 100644
index 000000000..8278a2296
--- /dev/null
+++ b/src/tls/tls_handshake_reader.cpp
@@ -0,0 +1,66 @@
+/*
+* TLS Handshake Reader
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#include <botan/internal/tls_handshake_reader.h>
+#include <botan/exceptn.h>
+
+namespace Botan {
+
+namespace TLS {
+
+void Stream_Handshake_Reader::add_input(const byte record[],
+ size_t record_size)
+ {
+ m_queue.write(record, record_size);
+ }
+
+bool Stream_Handshake_Reader::empty() const
+ {
+ return m_queue.empty();
+ }
+
+bool Stream_Handshake_Reader::have_full_record() const
+ {
+ if(m_queue.size() >= 4)
+ {
+ byte head[4] = { 0 };
+ m_queue.peek(head, 4);
+
+ const size_t length = make_u32bit(0, head[1], head[2], head[3]);
+
+ return (m_queue.size() >= length + 4);
+ }
+
+ return false;
+ }
+
+std::pair<Handshake_Type, MemoryVector<byte> > Stream_Handshake_Reader::get_next_record()
+ {
+ if(m_queue.size() >= 4)
+ {
+ byte head[4] = { 0 };
+ m_queue.peek(head, 4);
+
+ const size_t length = make_u32bit(0, head[1], head[2], head[3]);
+
+ if(m_queue.size() >= length + 4)
+ {
+ Handshake_Type type = static_cast<Handshake_Type>(head[0]);
+ MemoryVector<byte> contents(length);
+ m_queue.read(head, 4); // discard
+ m_queue.read(&contents[0], contents.size());
+
+ return std::make_pair(type, contents);
+ }
+ }
+
+ throw Internal_Error("Stream_Handshake_Reader::get_next_record called without a full record");
+ }
+
+}
+
+}
diff --git a/src/tls/tls_handshake_reader.h b/src/tls/tls_handshake_reader.h
new file mode 100644
index 000000000..06a273ced
--- /dev/null
+++ b/src/tls/tls_handshake_reader.h
@@ -0,0 +1,58 @@
+/*
+* TLS Handshake Reader
+* (C) 2012 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef BOTAN_TLS_HANDSHAKE_READER_H__
+#define BOTAN_TLS_HANDSHAKE_READER_H__
+
+#include <botan/tls_magic.h>
+#include <botan/secqueue.h>
+#include <botan/loadstor.h>
+#include <utility>
+
+namespace Botan {
+
+namespace TLS {
+
+/**
+* Handshake Reader Interface
+*/
+class Handshake_Reader
+ {
+ public:
+ virtual void add_input(const byte record[], size_t record_size) = 0;
+
+ virtual bool empty() const = 0;
+
+ virtual bool have_full_record() const = 0;
+
+ virtual std::pair<Handshake_Type, MemoryVector<byte> > get_next_record() = 0;
+
+ virtual ~Handshake_Reader() {}
+ };
+
+/**
+* Reader of TLS handshake messages
+*/
+class Stream_Handshake_Reader : public Handshake_Reader
+ {
+ public:
+ void add_input(const byte record[], size_t record_size);
+
+ bool empty() const;
+
+ bool have_full_record() const;
+
+ std::pair<Handshake_Type, MemoryVector<byte> > get_next_record();
+ private:
+ SecureQueue m_queue;
+ };
+
+}
+
+}
+
+#endif
diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp
index 86e6e0b55..b34d8616d 100644
--- a/src/tls/tls_handshake_state.cpp
+++ b/src/tls/tls_handshake_state.cpp
@@ -54,12 +54,15 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
case NEXT_PROTOCOL:
return (1 << 9);
- case HANDSHAKE_CCS:
+ case NEW_SESSION_TICKET:
return (1 << 10);
- case FINISHED:
+ case HANDSHAKE_CCS:
return (1 << 11);
+ case FINISHED:
+ return (1 << 12);
+
// allow explicitly disabling new handshakes
case HANDSHAKE_NONE:
return 0;
@@ -76,7 +79,7 @@ u32bit bitmask_for_handshake_type(Handshake_Type type)
/*
* Initialize the SSL/TLS Handshake State
*/
-Handshake_State::Handshake_State()
+Handshake_State::Handshake_State(Handshake_Reader* reader)
{
client_hello = 0;
server_hello = 0;
@@ -85,6 +88,7 @@ Handshake_State::Handshake_State()
cert_req = 0;
server_hello_done = 0;
next_protocol = 0;
+ new_session_ticket = 0;
client_certs = 0;
client_kex = 0;
@@ -92,14 +96,21 @@ Handshake_State::Handshake_State()
client_finished = 0;
server_finished = 0;
+ m_handshake_reader = reader;
+
server_rsa_kex_key = 0;
- version = Protocol_Version::SSL_V3;
+ m_version = Protocol_Version::SSL_V3;
hand_expecting_mask = 0;
hand_received_mask = 0;
}
+void Handshake_State::set_version(const Protocol_Version& version)
+ {
+ m_version = version;
+ }
+
void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg)
{
const u32bit mask = bitmask_for_handshake_type(handshake_msg);
@@ -132,17 +143,25 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const
return (hand_received_mask & mask);
}
+const MemoryRegion<byte>& Handshake_State::session_ticket() const
+ {
+ if(new_session_ticket && !new_session_ticket->ticket().empty())
+ return new_session_ticket->ticket();
+
+ return client_hello->session_ticket();
+ }
+
KDF* Handshake_State::protocol_specific_prf()
{
- if(version == Protocol_Version::SSL_V3)
+ if(version() == Protocol_Version::SSL_V3)
{
return get_kdf("SSL3-PRF");
}
- else if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11)
+ else if(version() == Protocol_Version::TLS_V10 || version() == Protocol_Version::TLS_V11)
{
return get_kdf("TLS-PRF");
}
- else if(version == Protocol_Version::TLS_V12)
+ else if(version() == Protocol_Version::TLS_V12)
{
if(suite.mac_algo() == "SHA-1" || suite.mac_algo() == "SHA-256")
return get_kdf("TLS-12-PRF(SHA-256)");
@@ -150,7 +169,7 @@ KDF* Handshake_State::protocol_specific_prf()
return get_kdf("TLS-12-PRF(" + suite.mac_algo() + ")");
}
- throw Internal_Error("Unknown version code " + version.to_string());
+ throw Internal_Error("Unknown version code " + version().to_string());
}
std::pair<std::string, Signature_Format>
@@ -175,15 +194,15 @@ Handshake_State::choose_sig_format(const Private_Key* key,
}
}
- if(for_client_auth && this->version == Protocol_Version::SSL_V3)
+ if(for_client_auth && this->version() == Protocol_Version::SSL_V3)
hash_algo = "Raw";
- if(hash_algo == "" && this->version == Protocol_Version::TLS_V12)
+ if(hash_algo == "" && this->version() == Protocol_Version::TLS_V12)
hash_algo = "SHA-1"; // TLS 1.2 but no compatible hashes set (?)
BOTAN_ASSERT(hash_algo != "", "Couldn't figure out hash to use");
- if(this->version >= Protocol_Version::TLS_V12)
+ if(this->version() >= Protocol_Version::TLS_V12)
{
hash_algo_out = hash_algo;
sig_algo_out = sig_algo;
@@ -221,7 +240,7 @@ Handshake_State::understand_sig_format(const Public_Key* key,
Or not?
*/
- if(this->version < Protocol_Version::TLS_V12)
+ if(this->version() < Protocol_Version::TLS_V12)
{
if(hash_algo != "" || sig_algo != "")
throw Decoding_Error("Counterparty sent hash/sig IDs with old version");
@@ -237,11 +256,11 @@ Handshake_State::understand_sig_format(const Public_Key* key,
if(algo_name == "RSA")
{
- if(for_client_auth && this->version == Protocol_Version::SSL_V3)
+ if(for_client_auth && this->version() == Protocol_Version::SSL_V3)
{
hash_algo = "Raw";
}
- else if(this->version < Protocol_Version::TLS_V12)
+ else if(this->version() < Protocol_Version::TLS_V12)
{
hash_algo = "TLS.Digest.0";
}
@@ -251,11 +270,11 @@ Handshake_State::understand_sig_format(const Public_Key* key,
}
else if(algo_name == "DSA" || algo_name == "ECDSA")
{
- if(algo_name == "DSA" && for_client_auth && this->version == Protocol_Version::SSL_V3)
+ if(algo_name == "DSA" && for_client_auth && this->version() == Protocol_Version::SSL_V3)
{
hash_algo = "Raw";
}
- else if(this->version < Protocol_Version::TLS_V12)
+ else if(this->version() < Protocol_Version::TLS_V12)
{
hash_algo = "SHA-1";
}
@@ -280,12 +299,15 @@ Handshake_State::~Handshake_State()
delete cert_req;
delete server_hello_done;
delete next_protocol;
+ delete new_session_ticket;
delete client_certs;
delete client_kex;
delete client_verify;
delete client_finished;
delete server_finished;
+
+ delete m_handshake_reader;
}
}
diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h
index 4e1cb2f25..d2f64c43b 100644
--- a/src/tls/tls_handshake_state.h
+++ b/src/tls/tls_handshake_state.h
@@ -9,8 +9,8 @@
#define BOTAN_TLS_HANDSHAKE_STATE_H__
#include <botan/internal/tls_handshake_hash.h>
+#include <botan/internal/tls_handshake_reader.h>
#include <botan/internal/tls_session_key.h>
-#include <botan/secqueue.h>
#include <botan/pk_keys.h>
#include <botan/pubkey.h>
@@ -29,7 +29,7 @@ namespace TLS {
class Handshake_State
{
public:
- Handshake_State();
+ Handshake_State(Handshake_Reader* reader);
~Handshake_State();
bool received_handshake_msg(Handshake_Type handshake_msg) const;
@@ -37,6 +37,8 @@ class Handshake_State
void confirm_transition_to(Handshake_Type handshake_msg);
void set_expected_next(Handshake_Type handshake_msg);
+ const MemoryRegion<byte>& session_ticket() const;
+
std::pair<std::string, Signature_Format>
understand_sig_format(const Public_Key* key,
std::string hash_algo,
@@ -51,7 +53,9 @@ class Handshake_State
KDF* protocol_specific_prf();
- Protocol_Version version;
+ Protocol_Version version() const { return m_version; }
+
+ void set_version(const Protocol_Version& version);
class Client_Hello* client_hello;
class Server_Hello* server_hello;
@@ -65,6 +69,7 @@ class Handshake_State
class Certificate_Verify* client_verify;
class Next_Protocol* next_protocol;
+ class New_Session_Ticket* new_session_ticket;
class Finished* client_finished;
class Finished* server_finished;
@@ -76,8 +81,6 @@ class Handshake_State
Session_Keys keys;
Handshake_Hash hash;
- SecureQueue queue;
-
/*
* Only used by clients for session resumption
*/
@@ -88,8 +91,11 @@ class Handshake_State
*/
std::function<std::string (std::vector<std::string>)> client_npn_cb;
+ Handshake_Reader* handshake_reader() { return m_handshake_reader; }
private:
+ Handshake_Reader* m_handshake_reader;
u32bit hand_expecting_mask, hand_received_mask;
+ Protocol_Version m_version;
};
}
diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h
index 72a430bf2..0e45407d3 100644
--- a/src/tls/tls_magic.h
+++ b/src/tls/tls_magic.h
@@ -36,23 +36,24 @@ enum Record_Type {
};
enum Handshake_Type {
- HELLO_REQUEST = 0,
- CLIENT_HELLO = 1,
- CLIENT_HELLO_SSLV2 = 200, // Not a wire value
- SERVER_HELLO = 2,
- NEW_SESSION_TICKET = 4, // RFC 5077
- CERTIFICATE = 11,
- SERVER_KEX = 12,
- CERTIFICATE_REQUEST = 13,
- SERVER_HELLO_DONE = 14,
- CERTIFICATE_VERIFY = 15,
- CLIENT_KEX = 16,
- FINISHED = 20,
-
- NEXT_PROTOCOL = 67,
-
- HANDSHAKE_CCS = 100, // Not a wire value
- HANDSHAKE_NONE = 255 // Null value
+ HELLO_REQUEST = 0,
+ CLIENT_HELLO = 1,
+ CLIENT_HELLO_SSLV2 = 253, // Not a wire value
+ SERVER_HELLO = 2,
+ HELLO_VERIFY_REQUEST = 3,
+ NEW_SESSION_TICKET = 4, // RFC 5077
+ CERTIFICATE = 11,
+ SERVER_KEX = 12,
+ CERTIFICATE_REQUEST = 13,
+ SERVER_HELLO_DONE = 14,
+ CERTIFICATE_VERIFY = 15,
+ CLIENT_KEX = 16,
+ FINISHED = 20,
+
+ NEXT_PROTOCOL = 67,
+
+ HANDSHAKE_CCS = 254, // Not a wire value
+ HANDSHAKE_NONE = 255 // Null value
};
enum Ciphersuite_Code {
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 7162ece1a..2f8af5fd2 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -33,19 +33,39 @@ class Record_Reader;
class Handshake_Message
{
public:
- void send(Record_Writer& writer, Handshake_Hash& hash) const;
-
+ virtual MemoryVector<byte> serialize() const = 0;
virtual Handshake_Type type() const = 0;
+ Handshake_Message() {}
virtual ~Handshake_Message() {}
private:
+ Handshake_Message(const Handshake_Message&) {}
Handshake_Message& operator=(const Handshake_Message&) { return (*this); }
- virtual MemoryVector<byte> serialize() const = 0;
};
MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng);
/**
+* DTLS Hello Verify Request
+*/
+class Hello_Verify_Request : public Handshake_Message
+ {
+ public:
+ MemoryVector<byte> serialize() const;
+ Handshake_Type type() const { return HELLO_VERIFY_REQUEST; }
+
+ MemoryVector<byte> cookie() const { return m_cookie; }
+
+ Hello_Verify_Request(const MemoryRegion<byte>& buf);
+
+ Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits,
+ const std::string& client_identity,
+ const SymmetricKey& secret_key);
+ private:
+ MemoryVector<byte> m_cookie;
+ };
+
+/**
* Client Hello Message
*/
class Client_Hello : public Handshake_Message
@@ -88,6 +108,11 @@ class Client_Hello : public Handshake_Message
size_t fragment_size() const { return m_fragment_size; }
+ bool supports_session_ticket() const { return m_supports_session_ticket; }
+
+ const MemoryRegion<byte>& session_ticket() const
+ { return m_session_ticket; }
+
Client_Hello(Record_Writer& writer,
Handshake_Hash& hash,
const Policy& policy,
@@ -99,6 +124,7 @@ class Client_Hello : public Handshake_Message
Client_Hello(Record_Writer& writer,
Handshake_Hash& hash,
+ const Policy& policy,
RandomNumberGenerator& rng,
const Session& resumed_session,
bool next_protocol = false);
@@ -125,6 +151,9 @@ class Client_Hello : public Handshake_Message
std::vector<std::pair<std::string, std::string> > m_supported_algos;
std::vector<std::string> m_supported_curves;
+
+ bool m_supports_session_ticket;
+ MemoryVector<byte> m_session_ticket;
};
/**
@@ -150,6 +179,8 @@ class Server_Hello : public Handshake_Message
bool next_protocol_notification() const { return m_next_protocol; }
+ bool supports_session_ticket() const { return m_supports_session_ticket; }
+
const std::vector<std::string>& next_protocols() const
{ return m_next_protocols; }
@@ -166,6 +197,7 @@ class Server_Hello : public Handshake_Message
const Client_Hello& other,
const std::vector<std::string>& available_cert_types,
const Policy& policies,
+ bool have_session_ticket_key,
bool client_has_secure_renegotiation,
const MemoryRegion<byte>& reneg_info,
bool client_has_npn,
@@ -181,6 +213,7 @@ class Server_Hello : public Handshake_Message
size_t max_fragment_size,
bool client_has_secure_renegotiation,
const MemoryRegion<byte>& reneg_info,
+ bool client_supports_session_tickets,
bool client_has_npn,
const std::vector<std::string>& next_protocols,
RandomNumberGenerator& rng);
@@ -200,6 +233,7 @@ class Server_Hello : public Handshake_Message
bool m_next_protocol;
std::vector<std::string> m_next_protocols;
+ bool m_supports_session_ticket;
};
/**
@@ -238,10 +272,10 @@ class Certificate : public Handshake_Message
{
public:
Handshake_Type type() const { return CERTIFICATE; }
- const std::vector<X509_Certificate>& cert_chain() const { return certs; }
+ const std::vector<X509_Certificate>& cert_chain() const { return m_certs; }
- size_t count() const { return certs.size(); }
- bool empty() const { return certs.empty(); }
+ size_t count() const { return m_certs.size(); }
+ bool empty() const { return m_certs.empty(); }
Certificate(Record_Writer& writer,
Handshake_Hash& hash,
@@ -251,7 +285,7 @@ class Certificate : public Handshake_Message
private:
MemoryVector<byte> serialize() const;
- std::vector<X509_Certificate> certs;
+ std::vector<X509_Certificate> m_certs;
};
/**
@@ -434,6 +468,30 @@ class Next_Protocol : public Handshake_Message
std::string m_protocol;
};
+class New_Session_Ticket : public Handshake_Message
+ {
+ public:
+ Handshake_Type type() const { return NEW_SESSION_TICKET; }
+
+ u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; }
+ const MemoryVector<byte>& ticket() const { return m_ticket; }
+
+ New_Session_Ticket(Record_Writer& writer,
+ Handshake_Hash& hash,
+ const MemoryRegion<byte>& ticket,
+ u32bit lifetime = 0);
+
+ New_Session_Ticket(Record_Writer& writer,
+ Handshake_Hash& hash);
+
+ New_Session_Ticket(const MemoryRegion<byte>& buf);
+ private:
+ MemoryVector<byte> serialize() const;
+
+ u32bit m_ticket_lifetime_hint;
+ MemoryVector<byte> m_ticket;
+ };
+
}
}
diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp
index 49f74975b..1ab55f7c6 100644
--- a/src/tls/tls_policy.cpp
+++ b/src/tls/tls_policy.cpp
@@ -89,6 +89,16 @@ std::vector<std::string> Policy::allowed_ecc_curves() const
return curves;
}
+Protocol_Version Policy::min_version() const
+ {
+ return Protocol_Version::SSL_V3;
+ }
+
+Protocol_Version Policy::pref_version() const
+ {
+ return Protocol_Version::TLS_V12;
+ }
+
namespace {
class Ciphersuite_Preference_Ordering
diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h
index cd00331a5..f53b9bab6 100644
--- a/src/tls/tls_policy.h
+++ b/src/tls/tls_policy.h
@@ -97,14 +97,12 @@ class BOTAN_DLL Policy
/**
* @return the minimum version that we are willing to negotiate
*/
- virtual Protocol_Version min_version() const
- { return Protocol_Version::SSL_V3; }
+ virtual Protocol_Version min_version() const;
/**
* @return the version we would prefer to negotiate
*/
- virtual Protocol_Version pref_version() const
- { return Protocol_Version::TLS_V12; }
+ virtual Protocol_Version pref_version() const;
/**
* Return allowed ciphersuites, in order of preference
diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h
index 8c2e9efe2..bf8098bed 100644
--- a/src/tls/tls_reader.h
+++ b/src/tls/tls_reader.h
@@ -50,6 +50,15 @@ class TLS_Data_Reader
offset += bytes;
}
+ u16bit get_u32bit()
+ {
+ assert_at_least(4);
+ u16bit result = make_u32bit(buf[offset ], buf[offset+1],
+ buf[offset+2], buf[offset+3]);
+ offset += 4;
+ return result;
+ }
+
u16bit get_u16bit()
{
assert_at_least(2);
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 6634810df..b966e3c72 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -33,6 +33,8 @@ class BOTAN_DLL Record_Writer
void send(byte type, const byte input[], size_t length);
void send(byte type, byte val) { send(type, &val, 1); }
+ MemoryVector<byte> send(class Handshake_Message& msg);
+
void send_alert(const Alert& alert);
void activate(Connection_Side side,
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 129d5346d..3b4a526cc 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -20,16 +20,35 @@ namespace {
bool check_for_resume(Session& session_info,
Session_Manager& session_manager,
+ Credentials_Manager& credentials,
Client_Hello* client_hello)
{
- MemoryVector<byte> client_session_id = client_hello->session_id();
+ const MemoryVector<byte>& client_session_id = client_hello->session_id();
+ const MemoryVector<byte>& session_ticket = client_hello->session_ticket();
- if(client_session_id.empty()) // not resuming
- return false;
+ if(session_ticket.empty())
+ {
+ if(client_session_id.empty()) // not resuming
+ return false;
- // not found
- if(!session_manager.load_from_session_id(client_session_id, session_info))
- return false;
+ // not found
+ if(!session_manager.load_from_session_id(client_session_id, session_info))
+ return false;
+ }
+ else
+ {
+ // If a session ticket was sent, ignore client session ID
+ try
+ {
+ session_info = Session::decrypt(
+ session_ticket,
+ credentials.psk("tls-server", "session-ticket", ""));
+ }
+ catch(...)
+ {
+ return false;
+ }
+ }
// wrong version
if(client_hello->version() != session_info.version())
@@ -45,14 +64,14 @@ bool check_for_resume(Session& session_info,
session_info.compression_method()))
return false;
- // client sent a different SRP identity (!!!)
+ // client sent a different SRP identity
if(client_hello->srp_identifier() != "")
{
if(client_hello->srp_identifier() != session_info.srp_identifier())
return false;
}
- // client sent a different SNI hostname (!!!)
+ // client sent a different SNI hostname
if(client_hello->sni_hostname() != "")
{
if(client_hello->sni_hostname() != session_info.sni_hostname())
@@ -112,7 +131,7 @@ void Server::renegotiate()
if(state)
return; // currently in handshake
- state = new Handshake_State;
+ state = new Handshake_State(new Stream_Handshake_Reader);
state->set_expected_next(CLIENT_HELLO);
Hello_Request hello_req(writer);
}
@@ -137,7 +156,7 @@ void Server::read_handshake(byte rec_type,
{
if(rec_type == HANDSHAKE && !state)
{
- state = new Handshake_State;
+ state = new Handshake_State(new Stream_Handshake_Reader);
state->set_expected_next(CLIENT_HELLO);
}
@@ -183,20 +202,30 @@ void Server::process_handshake_msg(Handshake_Type type,
"Client version is unacceptable by policy");
if(client_version <= policy.pref_version())
- state->version = client_version;
+ state->set_version(client_version);
else
- state->version = policy.pref_version();
+ state->set_version(policy.pref_version());
secure_renegotiation.update(state->client_hello);
- writer.set_version(state->version);
- reader.set_version(state->version);
+ writer.set_version(state->version());
+ reader.set_version(state->version());
Session session_info;
const bool resuming = check_for_resume(session_info,
session_manager,
+ creds,
state->client_hello);
+ bool have_session_ticket_key = false;
+
+ try
+ {
+ have_session_ticket_key =
+ creds.psk("tls-server", "session-ticket", "").length() > 0;
+ }
+ catch(...) {}
+
if(resuming)
{
// resume session
@@ -204,13 +233,14 @@ void Server::process_handshake_msg(Handshake_Type type,
state->server_hello = new Server_Hello(
writer,
state->hash,
- session_info.session_id(),
+ state->client_hello->session_id(),
Protocol_Version(session_info.version()),
session_info.ciphersuite_code(),
session_info.compression_method(),
session_info.fragment_size(),
secure_renegotiation.supported(),
secure_renegotiation.for_server_hello(),
+ state->client_hello->supports_session_ticket() && have_session_ticket_key,
state->client_hello->next_protocol_notification(),
m_possible_protocols,
rng);
@@ -225,6 +255,31 @@ void Server::process_handshake_msg(Handshake_Type type,
state->keys = Session_Keys(state, session_info.master_secret(), true);
+ if(!handshake_fn(session_info))
+ {
+ if(state->server_hello->supports_session_ticket())
+ state->new_session_ticket = new New_Session_Ticket(writer, state->hash);
+ else
+ session_manager.remove_entry(session_info.session_id());
+ }
+
+ // FIXME: should only send a new ticket if we need too (eg old session)
+ if(state->server_hello->supports_session_ticket() && !state->new_session_ticket)
+ {
+ try
+ {
+ const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", "");
+
+ state->new_session_ticket =
+ new New_Session_Ticket(writer, state->hash,
+ session_info.encrypt(ticket_key, rng));
+ }
+ catch(...) {}
+
+ if(!state->new_session_ticket)
+ state->new_session_ticket = new New_Session_Ticket(writer, state->hash);
+ }
+
writer.send(CHANGE_CIPHER_SPEC, 1);
writer.activate(SERVER, state->suite, state->keys,
@@ -232,9 +287,6 @@ void Server::process_handshake_msg(Handshake_Type type,
state->server_finished = new Finished(writer, state, SERVER);
- if(!handshake_fn(session_info))
- session_manager.remove_entry(session_info.session_id());
-
state->set_expected_next(HANDSHAKE_CCS);
}
else // new session
@@ -261,10 +313,11 @@ void Server::process_handshake_msg(Handshake_Type type,
state->server_hello = new Server_Hello(
writer,
state->hash,
- state->version,
+ state->version(),
*(state->client_hello),
available_cert_types,
policy,
+ have_session_ticket_key,
secure_renegotiation.supported(),
secure_renegotiation.for_server_hello(),
state->client_hello->next_protocol_notification(),
@@ -323,7 +376,7 @@ void Server::process_handshake_msg(Handshake_Type type,
state->hash,
policy,
client_auth_CAs,
- state->version);
+ state->version());
state->set_expected_next(CERTIFICATE);
}
@@ -364,7 +417,7 @@ void Server::process_handshake_msg(Handshake_Type type,
}
else if(type == CERTIFICATE_VERIFY)
{
- state->client_verify = new Certificate_Verify(contents, state->version);
+ state->client_verify = new Certificate_Verify(contents, state->version());
const std::vector<X509_Certificate>& client_certs =
state->client_certs->cert_chain();
@@ -421,11 +474,48 @@ void Server::process_handshake_msg(Handshake_Type type,
throw TLS_Exception(Alert::DECRYPT_ERROR,
"Finished message didn't verify");
- // already sent it if resuming
if(!state->server_finished)
{
+ // already sent finished if resuming, so this is a new session
+
state->hash.update(type, contents);
+ Session session_info(
+ state->server_hello->session_id(),
+ state->keys.master_secret(),
+ state->server_hello->version(),
+ state->server_hello->ciphersuite(),
+ state->server_hello->compression_method(),
+ SERVER,
+ secure_renegotiation.supported(),
+ state->server_hello->fragment_size(),
+ peer_certs,
+ MemoryVector<byte>(),
+ m_hostname,
+ ""
+ );
+
+ if(handshake_fn(session_info))
+ {
+ if(state->server_hello->supports_session_ticket())
+ {
+ try
+ {
+ const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", "");
+
+ state->new_session_ticket =
+ new New_Session_Ticket(writer, state->hash,
+ session_info.encrypt(ticket_key, rng));
+ }
+ catch(...) {}
+ }
+ else
+ session_manager.save(session_info);
+ }
+
+ if(state->server_hello->supports_session_ticket() && !state->new_session_ticket)
+ state->new_session_ticket = new New_Session_Ticket(writer, state->hash);
+
writer.send(CHANGE_CIPHER_SPEC, 1);
writer.activate(SERVER, state->suite, state->keys,
@@ -437,25 +527,6 @@ void Server::process_handshake_msg(Handshake_Type type,
peer_certs = state->client_certs->cert_chain();
}
- Session session_info(
- state->server_hello->session_id(),
- state->keys.master_secret(),
- state->server_hello->version(),
- state->server_hello->ciphersuite(),
- state->server_hello->compression_method(),
- SERVER,
- secure_renegotiation.supported(),
- state->server_hello->fragment_size(),
- peer_certs,
- m_hostname,
- ""
- );
-
- if(handshake_fn(session_info))
- session_manager.save(session_info);
- else
- session_manager.remove_entry(session_info.session_id());
-
secure_renegotiation.update(state->client_finished,
state->server_finished);
diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp
index b27409dfa..44689b510 100644
--- a/src/tls/tls_session.cpp
+++ b/src/tls/tls_session.cpp
@@ -1,6 +1,6 @@
/*
* TLS Session State
-* (C) 2011 Jack Lloyd
+* (C) 2011-2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -10,6 +10,9 @@
#include <botan/ber_dec.h>
#include <botan/asn1_str.h>
#include <botan/pem.h>
+#include <botan/lookup.h>
+#include <botan/loadstor.h>
+#include <memory>
namespace Botan {
@@ -28,6 +31,7 @@ Session::Session(const MemoryRegion<byte>& session_identifier,
const std::string& srp_identifier) :
m_start_time(std::chrono::system_clock::now()),
m_identifier(session_identifier),
+ m_session_ticket(ticket),
m_master_secret(master_secret),
m_version(version),
m_ciphersuite(ciphersuite),
@@ -41,10 +45,15 @@ Session::Session(const MemoryRegion<byte>& session_identifier,
{
}
-Session::Session(const byte ber[], size_t ber_len)
+Session::Session(const std::string& pem)
{
- BER_Decoder decoder(ber, ber_len);
+ SecureVector<byte> der = PEM_Code::decode_check_label(pem, "SSL SESSION");
+ *this = Session(&der[0], der.size());
+ }
+
+Session::Session(const byte ber[], size_t ber_len)
+ {
byte side_code = 0;
ASN1_String sni_hostname_str;
ASN1_String srp_identifier_str;
@@ -59,10 +68,11 @@ Session::Session(const byte ber[], size_t ber_len)
.start_cons(SEQUENCE)
.decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION),
"Unknown version in session structure")
- .decode(m_identifier, OCTET_STRING)
- .decode_integer_type(start_time)
+ .decode_integer_type(m_start_time)
.decode_integer_type(major_version)
.decode_integer_type(minor_version)
+ .decode(m_identifier, OCTET_STRING)
+ .decode(m_session_ticket, OCTET_STRING)
.decode_integer_type(m_ciphersuite)
.decode_integer_type(m_compression_method)
.decode_integer_type(side_code)
@@ -90,13 +100,6 @@ Session::Session(const byte ber[], size_t ber_len)
}
}
-Session::Session(const std::string& pem)
- {
- SecureVector<byte> der = PEM_Code::decode_check_label(pem, "SSL SESSION");
-
- *this = Session(&der[0], der.size());
- }
-
SecureVector<byte> Session::DER_encode() const
{
MemoryVector<byte> peer_cert_bits;
@@ -106,10 +109,11 @@ SecureVector<byte> Session::DER_encode() const
return DER_Encoder()
.start_cons(SEQUENCE)
.encode(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION))
- .encode(m_identifier, OCTET_STRING)
.encode(static_cast<size_t>(std::chrono::system_clock::to_time_t(m_start_time)))
.encode(static_cast<size_t>(m_version.major_version()))
.encode(static_cast<size_t>(m_version.minor_version()))
+ .encode(m_identifier, OCTET_STRING)
+ .encode(m_session_ticket, OCTET_STRING)
.encode(static_cast<size_t>(m_ciphersuite))
.encode(static_cast<size_t>(m_compression_method))
.encode(static_cast<size_t>(m_connection_side))
@@ -128,6 +132,113 @@ std::string Session::PEM_encode() const
return PEM_Code::encode(this->DER_encode(), "SSL SESSION");
}
+namespace {
+
+const u32bit SESSION_CRYPTO_MAGIC = 0x571B0E4E;
+const std::string SESSION_CRYPTO_CIPHER = "AES-256/CBC";
+const std::string SESSION_CRYPTO_MAC = "HMAC(SHA-256)";
+const std::string SESSION_CRYPTO_KDF = "KDF2(SHA-256)";
+
+const size_t MAGIC_LENGTH = 4;
+const size_t MAC_KEY_LENGTH = 32;
+const size_t CIPHER_KEY_LENGTH = 32;
+const size_t CIPHER_IV_LENGTH = 16;
+const size_t MAC_OUTPUT_LENGTH = 32;
+
}
+MemoryVector<byte>
+Session::encrypt(const SymmetricKey& master_key,
+ RandomNumberGenerator& rng) const
+ {
+ std::auto_ptr<KDF> kdf(get_kdf(SESSION_CRYPTO_KDF));
+
+ SymmetricKey cipher_key =
+ kdf->derive_key(CIPHER_KEY_LENGTH,
+ master_key.bits_of(),
+ "tls.session.cipher-key");
+
+ SymmetricKey mac_key =
+ kdf->derive_key(MAC_KEY_LENGTH,
+ master_key.bits_of(),
+ "tls.session.mac-key");
+
+ InitializationVector cipher_iv(rng, 16);
+
+ std::auto_ptr<MessageAuthenticationCode> mac(get_mac(SESSION_CRYPTO_MAC));
+ mac->set_key(mac_key);
+
+ Pipe pipe(get_cipher(SESSION_CRYPTO_CIPHER, cipher_key, cipher_iv, ENCRYPTION));
+ pipe.process_msg(this->DER_encode());
+ MemoryVector<byte> ctext = pipe.read_all(0);
+
+ MemoryVector<byte> out(MAGIC_LENGTH);
+ store_be(SESSION_CRYPTO_MAGIC, &out[0]);
+ out += cipher_iv.bits_of();
+ out += ctext;
+
+ mac->update(out);
+
+ out += mac->final();
+ return out;
+ }
+
+Session Session::decrypt(const byte buf[], size_t buf_len,
+ const SymmetricKey& master_key)
+ {
+ try
+ {
+ const size_t MIN_CTEXT_SIZE = 4 * 16; // due to 48 byte master secret
+
+ if(buf_len < (MAGIC_LENGTH +
+ CIPHER_IV_LENGTH +
+ MIN_CTEXT_SIZE +
+ MAC_OUTPUT_LENGTH))
+ throw Decoding_Error("Encrypted TLS session too short to be valid");
+
+ if(load_be<u32bit>(buf, 0) != SESSION_CRYPTO_MAGIC)
+ throw Decoding_Error("Unknown header value in encrypted session");
+
+ std::auto_ptr<KDF> kdf(get_kdf(SESSION_CRYPTO_KDF));
+
+ SymmetricKey mac_key =
+ kdf->derive_key(MAC_KEY_LENGTH,
+ master_key.bits_of(),
+ "tls.session.mac-key");
+
+ std::auto_ptr<MessageAuthenticationCode> mac(get_mac(SESSION_CRYPTO_MAC));
+ mac->set_key(mac_key);
+
+ mac->update(&buf[0], buf_len - MAC_OUTPUT_LENGTH);
+ MemoryVector<byte> computed_mac = mac->final();
+
+ if(!same_mem(&buf[buf_len - MAC_OUTPUT_LENGTH], &computed_mac[0], computed_mac.size()))
+ throw Decoding_Error("MAC verification failed for encrypted session");
+
+ SymmetricKey cipher_key =
+ kdf->derive_key(CIPHER_KEY_LENGTH,
+ master_key.bits_of(),
+ "tls.session.cipher-key");
+
+ InitializationVector cipher_iv(&buf[MAGIC_LENGTH], CIPHER_IV_LENGTH);
+
+ const size_t CTEXT_OFFSET = MAGIC_LENGTH + CIPHER_IV_LENGTH;
+
+ Pipe pipe(get_cipher(SESSION_CRYPTO_CIPHER, cipher_key, cipher_iv, DECRYPTION));
+ pipe.process_msg(&buf[CTEXT_OFFSET],
+ buf_len - (MAC_OUTPUT_LENGTH + CTEXT_OFFSET));
+ SecureVector<byte> ber = pipe.read_all();
+
+ return Session(&ber[0], ber.size());
+ }
+ catch(std::exception& e)
+ {
+ throw Decoding_Error("Failed to decrypt encrypted session -" +
+ std::string(e.what()));
+ }
+ }
+
}
+
+}
+
diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h
index 82c202ebe..8fc048c75 100644
--- a/src/tls/tls_session.h
+++ b/src/tls/tls_session.h
@@ -1,6 +1,6 @@
/*
* TLS Session
-* (C) 2011 Jack Lloyd
+* (C) 2011-2012 Jack Lloyd
*
* Released under the terms of the Botan license
*/
@@ -13,6 +13,7 @@
#include <botan/tls_ciphersuite.h>
#include <botan/tls_magic.h>
#include <botan/secmem.h>
+#include <botan/symkey.h>
#include <chrono>
namespace Botan {
@@ -51,6 +52,7 @@ class BOTAN_DLL Session
bool secure_renegotiation_supported,
size_t fragment_size,
const std::vector<X509_Certificate>& peer_certs,
+ const MemoryRegion<byte>& session_ticket,
const std::string& sni_hostname = "",
const std::string& srp_identifier = "");
@@ -72,6 +74,34 @@ class BOTAN_DLL Session
SecureVector<byte> DER_encode() const;
/**
+ * Encrypt a session (useful for serialization or session tickets)
+ */
+ MemoryVector<byte> encrypt(const SymmetricKey& key,
+ RandomNumberGenerator& rng) const;
+
+
+ /**
+ * Decrypt a session created by encrypt
+ * @param ctext the ciphertext returned by encrypt
+ * @param ctext_size the size of ctext in bytes
+ * @param key the same key used by the encrypting side
+ */
+ static Session decrypt(const byte ctext[],
+ size_t ctext_size,
+ const SymmetricKey& key);
+
+ /**
+ * Decrypt a session created by encrypt
+ * @param ctext the ciphertext returned by encrypt
+ * @param key the same key used by the encrypting side
+ */
+ static inline Session decrypt(const MemoryRegion<byte>& ctext,
+ const SymmetricKey& key)
+ {
+ return Session::decrypt(&ctext[0], ctext.size(), key);
+ }
+
+ /**
* Encode this session data for storage
* @warning if the master secret is compromised so is the
* session traffic
@@ -148,12 +178,18 @@ class BOTAN_DLL Session
std::chrono::system_clock::time_point start_time() const
{ return m_start_time; }
+ /**
+ * Return the session ticket the server gave us
+ */
+ const MemoryVector<byte>& session_ticket() const { return m_session_ticket; }
+
private:
- enum { TLS_SESSION_PARAM_STRUCT_VERSION = 1 };
+ enum { TLS_SESSION_PARAM_STRUCT_VERSION = 0x2994e300 };
std::chrono::system_clock::time_point m_start_time;
MemoryVector<byte> m_identifier;
+ MemoryVector<byte> m_session_ticket; // only used by client side
SecureVector<byte> m_master_secret;
Protocol_Version m_version;
diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp
index 0f520d140..4d7603ce1 100644
--- a/src/tls/tls_session_key.cpp
+++ b/src/tls/tls_session_key.cpp
@@ -47,7 +47,7 @@ Session_Keys::Session_Keys(Handshake_State* state,
{
SecureVector<byte> salt;
- if(state->version != Protocol_Version::SSL_V3)
+ if(state->version() != Protocol_Version::SSL_V3)
salt += std::make_pair(MASTER_SECRET_MAGIC, sizeof(MASTER_SECRET_MAGIC));
salt += state->client_hello->random();
@@ -57,7 +57,7 @@ Session_Keys::Session_Keys(Handshake_State* state,
}
SecureVector<byte> salt;
- if(state->version != Protocol_Version::SSL_V3)
+ if(state->version() != Protocol_Version::SSL_V3)
salt += std::make_pair(KEY_GEN_MAGIC, sizeof(KEY_GEN_MAGIC));
salt += state->server_hello->random();
salt += state->client_hello->random();