diff options
author | lloyd <[email protected]> | 2012-03-30 01:41:04 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-03-30 01:41:04 +0000 |
commit | 4ff6063a9605c71cc734a594ddecbdb0d17541bf (patch) | |
tree | 845fa57d5f9e24fc39f5fa847296e4b778e6a894 /src/tls | |
parent | 8a31da4d60490753031267b18957c0c599bbee3b (diff) | |
parent | 4c12fa5de1b59f2c58f974412231a19c4dc7c10f (diff) |
propagate from branch 'net.randombit.botan.tls-state-machine' (head 63b88a65b699c95ef839bc18336bceccfbfabd2e)
to branch 'net.randombit.botan.cxx11' (head 1adcc46808b403b8f6bf1669f022e65f9c30e8ea)
Diffstat (limited to 'src/tls')
35 files changed, 1358 insertions, 247 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 6743254a5..b08e1abe2 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -30,34 +30,11 @@ MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng) } /* -* Encode and send a Handshake message -*/ -void Handshake_Message::send(Record_Writer& writer, Handshake_Hash& hash) const - { - MemoryVector<byte> buf = serialize(); - MemoryVector<byte> send_buf(4); - - const size_t buf_size = buf.size(); - - send_buf[0] = type(); - - for(size_t i = 1; i != 4; ++i) - send_buf[i] = get_byte<u32bit>(i, buf_size); - - send_buf += buf; - - hash.update(send_buf); - - writer.send(HANDSHAKE, &send_buf[0], send_buf.size()); - } - -/* * Create a new Hello Request message */ Hello_Request::Hello_Request(Record_Writer& writer) { - Handshake_Hash dummy; // FIXME: *UGLY* - send(writer, dummy); + writer.send(*this); } /* @@ -66,7 +43,7 @@ Hello_Request::Hello_Request(Record_Writer& writer) Hello_Request::Hello_Request(const MemoryRegion<byte>& buf) { if(buf.size()) - throw Decoding_Error("Hello_Request: Must be empty, and is not"); + throw Decoding_Error("Bad Hello_Request, has non-zero size"); } /* @@ -97,49 +74,67 @@ Client_Hello::Client_Hello(Record_Writer& writer, m_next_protocol(next_protocol), m_fragment_size(0), m_secure_renegotiation(true), - m_renegotiation_info(reneg_info) + m_renegotiation_info(reneg_info), + m_supported_curves(policy.allowed_ecc_curves()), + m_supports_session_ticket(true) { std::vector<std::string> hashes = policy.allowed_hashes(); std::vector<std::string> sigs = policy.allowed_signature_methods(); - m_supported_curves = policy.allowed_ecc_curves(); - for(size_t i = 0; i != hashes.size(); ++i) for(size_t j = 0; j != sigs.size(); ++j) m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); - send(writer, hash); + hash.update(writer.send(*this)); } /* -* Create a new Client Hello message +* Create a new Client Hello message (session resumption case) */ Client_Hello::Client_Hello(Record_Writer& writer, Handshake_Hash& hash, + const Policy& policy, RandomNumberGenerator& rng, const Session& session, bool next_protocol) : m_version(session.version()), m_session_id(session.session_id()), m_random(make_hello_random(rng)), + m_suites(policy.ciphersuite_list(session.srp_identifier() != "")), + m_comp_methods(policy.compression()), m_hostname(session.sni_hostname()), m_srp_identifier(session.srp_identifier()), m_next_protocol(next_protocol), m_fragment_size(session.fragment_size()), - m_secure_renegotiation(session.secure_renegotiation()) + m_secure_renegotiation(session.secure_renegotiation()), + m_supported_curves(policy.allowed_ecc_curves()), + m_supports_session_ticket(true), + m_session_ticket(session.session_ticket()) { - m_suites.push_back(session.ciphersuite_code()); - m_comp_methods.push_back(session.compression_method()); + if(!value_exists(m_suites, session.ciphersuite_code())) + m_suites.push_back(session.ciphersuite_code()); - // set m_supported_algos + m_supported_curves here? + if(!value_exists(m_comp_methods, session.compression_method())) + m_comp_methods.push_back(session.compression_method()); - send(writer, hash); + std::vector<std::string> hashes = policy.allowed_hashes(); + std::vector<std::string> sigs = policy.allowed_signature_methods(); + + for(size_t i = 0; i != hashes.size(); ++i) + for(size_t j = 0; j != sigs.size(); ++j) + m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); + + hash.update(writer.send(*this)); } +/* +* Read a counterparty client hello +*/ Client_Hello::Client_Hello(const MemoryRegion<byte>& buf, Handshake_Type type) { m_next_protocol = false; m_secure_renegotiation = false; + m_supports_session_ticket = false; m_fragment_size = 0; if(type == CLIENT_HELLO) @@ -185,11 +180,14 @@ MemoryVector<byte> Client_Hello::serialize() const if(m_next_protocol) extensions.add(new Next_Protocol_Notification()); + + extensions.add(new Session_Ticket(m_session_ticket)); } else { // renegotiation extensions.add(new Renegotation_Extension(m_renegotiation_info)); + extensions.add(new Session_Ticket(m_session_ticket)); } buf += extensions.serialize(); @@ -328,6 +326,24 @@ void Client_Hello::deserialize(const MemoryRegion<byte>& buf) } } + if(Maximum_Fragment_Length* frag = extensions.get<Maximum_Fragment_Length>()) + { + m_fragment_size = frag->fragment_size(); + } + + if(Session_Ticket* ticket = extensions.get<Session_Ticket>()) + { + m_supports_session_ticket = true; + m_session_ticket = ticket->contents(); + } + + if(Renegotation_Extension* reneg = extensions.get<Renegotation_Extension>()) + { + // checked by TLS_Client / TLS_Server as they know the handshake state + m_secure_renegotiation = true; + m_renegotiation_info = reneg->renegotiation_info(); + } + if(value_exists(m_suites, static_cast<u16bit>(TLS_EMPTY_RENEGOTIATION_INFO_SCSV))) { /* diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index 6728b4877..0e7eba23a 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -8,6 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/pubkey.h> @@ -201,7 +202,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, MemoryVector<byte> encrypted_key = encryptor.encrypt(pre_master, rng); - if(state->version == Protocol_Version::SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) key_material = encrypted_key; // no length field else append_tls_length_value(key_material, encrypted_key, 2); @@ -212,7 +213,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, pub_key->algo_name()); } - send(writer, state->hash); + state->hash.update(writer.send(*this)); } /* @@ -245,7 +246,7 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, try { - if(state->version == Protocol_Version::SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { pre_master = decryptor.decrypt(contents); } diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index f400a36d2..df70cc43d 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -8,10 +8,10 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/der_enc.h> #include <botan/ber_dec.h> #include <botan/loadstor.h> -#include <botan/secqueue.h> namespace Botan { @@ -74,7 +74,7 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, m_supported_algos.push_back(std::make_pair(hashes[i], sigs[j])); } - send(writer, hash); + hash.update(writer.send(*this)); } /** @@ -173,10 +173,10 @@ MemoryVector<byte> Certificate_Req::serialize() const */ Certificate::Certificate(Record_Writer& writer, Handshake_Hash& hash, - const std::vector<X509_Certificate>& cert_list) + const std::vector<X509_Certificate>& cert_list) : + m_certs(cert_list) { - certs = cert_list; - send(writer, hash); + hash.update(writer.send(*this)); } /** @@ -189,27 +189,25 @@ Certificate::Certificate(const MemoryRegion<byte>& buf) const size_t total_size = make_u32bit(0, buf[0], buf[1], buf[2]); - SecureQueue queue; - queue.write(&buf[3], buf.size() - 3); - - if(queue.size() != total_size) + if(total_size != buf.size() - 3) throw Decoding_Error("Certificate: Message malformed"); - while(queue.size()) + const byte* certs = &buf[3]; + + while(certs != buf.end()) { - if(queue.size() < 3) + if(buf.end() - certs < 3) throw Decoding_Error("Certificate: Message malformed"); - byte len[3]; - queue.read(len, 3); - - const size_t cert_size = make_u32bit(0, len[0], len[1], len[2]); - const size_t original_size = queue.size(); + const size_t cert_size = make_u32bit(0, certs[0], certs[1], certs[2]); - X509_Certificate cert(queue); - if(queue.size() + cert_size != original_size) + if(buf.end() - certs < (3 + cert_size)) throw Decoding_Error("Certificate: Message malformed"); - certs.push_back(cert); + + DataSource_Memory cert_buf(&certs[3], cert_size); + m_certs.push_back(X509_Certificate(cert_buf)); + + certs += cert_size + 3; } } @@ -220,9 +218,9 @@ MemoryVector<byte> Certificate::serialize() const { MemoryVector<byte> buf(3); - for(size_t i = 0; i != certs.size(); ++i) + for(size_t i = 0; i != m_certs.size(); ++i) { - MemoryVector<byte> raw_cert = certs[i].BER_encode(); + MemoryVector<byte> raw_cert = m_certs[i].BER_encode(); const size_t cert_size = raw_cert.size(); for(size_t i = 0; i != 3; ++i) buf.push_back(get_byte<u32bit>(i+1, cert_size)); diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index ffcba84d8..0a377b35f 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -8,6 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/internal/assert.h> #include <memory> @@ -30,7 +31,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, PK_Signer signer(*priv_key, format.first, format.second); - if(state->version == Protocol_Version::SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { SecureVector<byte> md5_sha = state->hash.final_ssl3( state->keys.master_secret()); @@ -45,7 +46,7 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, signature = signer.sign_message(state->hash.get_contents(), rng); } - send(writer, state->hash); + state->hash.update(writer.send(*this)); } /* @@ -99,7 +100,7 @@ bool Certificate_Verify::verify(const X509_Certificate& cert, PK_Verifier verifier(*key, format.first, format.second); - if(state->version == Protocol_Version::SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { SecureVector<byte> md5_sha = state->hash.final_ssl3( state->keys.master_secret()); diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index b4c9bdc3d..bb6e8d20e 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -6,6 +6,7 @@ */ #include <botan/internal/tls_messages.h> +#include <botan/tls_record.h> #include <memory> namespace Botan { @@ -20,7 +21,7 @@ namespace { MemoryVector<byte> finished_compute_verify(Handshake_State* state, Connection_Side side) { - if(state->version == Protocol_Version::SSL_V3) + if(state->version() == Protocol_Version::SSL_V3) { const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 }; const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 }; @@ -54,7 +55,7 @@ MemoryVector<byte> finished_compute_verify(Handshake_State* state, else input += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL)); - input += state->hash.final(state->version, state->suite.mac_algo()); + input += state->hash.final(state->version(), state->suite.mac_algo()); return prf->derive_key(12, state->keys.master_secret(), input); } @@ -70,7 +71,7 @@ Finished::Finished(Record_Writer& writer, Connection_Side side) { verification_data = finished_compute_verify(state, side); - send(writer, state->hash); + state->hash.update(writer.send(*this)); } /* diff --git a/src/tls/hello_verify.cpp b/src/tls/hello_verify.cpp new file mode 100644 index 000000000..c7aae94a1 --- /dev/null +++ b/src/tls/hello_verify.cpp @@ -0,0 +1,61 @@ +/* +* DTLS Hello Verify Request +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_messages.h> +#include <botan/lookup.h> +#include <memory> + +namespace Botan { + +namespace TLS { + +Hello_Verify_Request::Hello_Verify_Request(const MemoryRegion<byte>& buf) + { + if(buf.size() < 3) + throw Decoding_Error("Hello verify request too small"); + + if(buf[0] != 254 || (buf[1] != 255 && buf[1] != 253)) + throw Decoding_Error("Unknown version from server in hello verify request"); + + m_cookie.resize(buf.size() - 2); + copy_mem(&m_cookie[0], &buf[2], buf.size() - 2); + } + +Hello_Verify_Request::Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits, + const std::string& client_identity, + const SymmetricKey& secret_key) + { + std::auto_ptr<MessageAuthenticationCode> hmac(get_mac("HMAC(SHA-256)")); + hmac->set_key(secret_key); + + hmac->update_be(client_hello_bits.size()); + hmac->update(client_hello_bits); + hmac->update_be(client_identity.size()); + hmac->update(client_identity); + + m_cookie = hmac->final(); + } + +MemoryVector<byte> Hello_Verify_Request::serialize() const + { + /* DTLS 1.2 server implementations SHOULD use DTLS version 1.0 + regardless of the version of TLS that is expected to be + negotiated (RFC 6347, section 4.2.1) + */ + + Protocol_Version format_version(Protocol_Version::TLS_V11); + + MemoryVector<byte> bits; + bits.push_back(format_version.major_version()); + bits.push_back(format_version.minor_version()); + bits += m_cookie; + return bits; + } + +} + +} diff --git a/src/tls/info.txt b/src/tls/info.txt index 68ca026d5..b19eedb20 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -23,6 +23,7 @@ tls_version.h <header:internal> tls_extensions.h tls_handshake_hash.h +tls_handshake_reader.h tls_handshake_state.h tls_messages.h tls_reader.h @@ -36,15 +37,18 @@ c_kex.cpp cert_req.cpp cert_ver.cpp finished.cpp +hello_verify.cpp next_protocol.cpp rec_read.cpp rec_wri.cpp s_hello.cpp s_kex.cpp +session_ticket.cpp tls_channel.cpp tls_client.cpp tls_extensions.cpp tls_handshake_hash.cpp +tls_handshake_reader.cpp tls_handshake_state.cpp tls_policy.cpp tls_server.cpp @@ -59,6 +63,7 @@ tls_version.cpp aes arc4 asn1 +credentials des dh dsa @@ -68,6 +73,7 @@ eme_pkcs emsa3 filters hmac +kdf2 md5 prf_ssl3 prf_tls diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index 97b072440..17b77fb6e 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -8,6 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_extensions.h> #include <botan/internal/tls_reader.h> +#include <botan/tls_record.h> namespace Botan { @@ -18,7 +19,7 @@ Next_Protocol::Next_Protocol(Record_Writer& writer, const std::string& protocol) : m_protocol(protocol) { - send(writer, hash); + hash.update(writer.send(*this)); } Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf) diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp index 633b63720..3a54d7931 100644 --- a/src/tls/rec_wri.cpp +++ b/src/tls/rec_wri.cpp @@ -6,6 +6,7 @@ */ #include <botan/tls_record.h> +#include <botan/internal/tls_messages.h> #include <botan/internal/tls_session_key.h> #include <botan/internal/tls_handshake_hash.h> #include <botan/lookup.h> @@ -22,9 +23,10 @@ namespace TLS { * Record_Writer Constructor */ Record_Writer::Record_Writer(std::function<void (const byte[], size_t)> out) : - m_output_fn(out), m_writebuf(TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE) + m_output_fn(out), + m_writebuf(TLS_HEADER_SIZE + MAX_CIPHERTEXT_SIZE), + m_mac(0) { - m_mac = 0; reset(); set_maximum_fragment_size(0); } @@ -144,6 +146,25 @@ void Record_Writer::activate(Connection_Side side, throw Invalid_Argument("Record_Writer: Unknown hash " + mac_algo); } +MemoryVector<byte> Record_Writer::send(Handshake_Message& msg) + { + const MemoryVector<byte> buf = msg.serialize(); + MemoryVector<byte> send_buf(4); + + const size_t buf_size = buf.size(); + + send_buf[0] = msg.type(); + + for(size_t i = 1; i != 4; ++i) + send_buf[i] = get_byte<u32bit>(i, buf_size); + + send_buf += buf; + + send(HANDSHAKE, &send_buf[0], send_buf.size()); + + return send_buf; + } + /* * Send one or more records to the other side */ @@ -190,16 +211,15 @@ void Record_Writer::send_record(byte type, const byte input[], size_t length) if(m_mac_size == 0) // initial unencrypted handshake records { - const byte header[TLS_HEADER_SIZE] = { - type, - m_version.major_version(), - m_version.minor_version(), - get_byte<u16bit>(0, length), - get_byte<u16bit>(1, length) - }; - - m_output_fn(header, TLS_HEADER_SIZE); - m_output_fn(input, length); + m_writebuf[0] = type; + m_writebuf[1] = m_version.major_version(); + m_writebuf[2] = m_version.minor_version(); + m_writebuf[3] = get_byte<u16bit>(0, length); + m_writebuf[4] = get_byte<u16bit>(1, length); + + copy_mem(&m_writebuf[TLS_HEADER_SIZE], input, length); + + m_output_fn(&m_writebuf[0], TLS_HEADER_SIZE + length); return; } diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 0ad78fc5b..7da9fdc57 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -25,6 +25,7 @@ Server_Hello::Server_Hello(Record_Writer& writer, const Client_Hello& c_hello, const std::vector<std::string>& available_cert_types, const Policy& policy, + bool have_session_ticket_key, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, bool client_has_npn, @@ -37,7 +38,9 @@ Server_Hello::Server_Hello(Record_Writer& writer, m_secure_renegotiation(client_has_secure_renegotiation), m_renegotiation_info(reneg_info), m_next_protocol(client_has_npn), - m_next_protocols(next_protocols) + m_next_protocols(next_protocols), + m_supports_session_ticket(have_session_ticket_key && + c_hello.supports_session_ticket()) { suite = policy.choose_suite( c_hello.ciphersuites(), @@ -51,7 +54,7 @@ Server_Hello::Server_Hello(Record_Writer& writer, comp_method = policy.choose_compression(c_hello.compression_methods()); - send(writer, hash); + hash.update(writer.send(*this)); } /* @@ -66,6 +69,7 @@ Server_Hello::Server_Hello(Record_Writer& writer, size_t max_fragment_size, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, + bool client_supports_session_tickets, bool client_has_npn, const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng) : @@ -78,9 +82,10 @@ Server_Hello::Server_Hello(Record_Writer& writer, m_secure_renegotiation(client_has_secure_renegotiation), m_renegotiation_info(reneg_info), m_next_protocol(client_has_npn), - m_next_protocols(next_protocols) + m_next_protocols(next_protocols), + m_supports_session_ticket(client_supports_session_tickets) { - send(writer, hash); + hash.update(writer.send(*this)); } /* @@ -89,6 +94,7 @@ Server_Hello::Server_Hello(Record_Writer& writer, Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) { m_secure_renegotiation = false; + m_supports_session_ticket = false; m_next_protocol = false; if(buf.size() < 38) @@ -132,6 +138,13 @@ Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) m_next_protocols = npn->protocols(); m_next_protocol = true; } + + if(Session_Ticket* ticket = extensions.get<Session_Ticket>()) + { + if(!ticket->contents().empty()) + throw Decoding_Error("TLS server sent non-empty session ticket extension"); + m_supports_session_ticket = true; + } } /* @@ -163,6 +176,9 @@ MemoryVector<byte> Server_Hello::serialize() const if(m_next_protocol) extensions.add(new Next_Protocol_Notification(m_next_protocols)); + if(m_supports_session_ticket) + extensions.add(new Session_Ticket()); + buf += extensions.serialize(); return buf; @@ -174,7 +190,7 @@ MemoryVector<byte> Server_Hello::serialize() const Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, Handshake_Hash& hash) { - send(writer, hash); + hash.update(writer.send(*this)); } /* diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 945c574b9..0890cac49 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -8,6 +8,7 @@ #include <botan/internal/tls_messages.h> #include <botan/internal/tls_reader.h> #include <botan/internal/tls_extensions.h> +#include <botan/tls_record.h> #include <botan/internal/assert.h> #include <botan/credentials_manager.h> #include <botan/loadstor.h> @@ -105,7 +106,7 @@ Server_Key_Exchange::Server_Key_Exchange(Record_Writer& writer, m_signature = signer.signature(rng); } - send(writer, state->hash); + state->hash.update(writer.send(*this)); } /** diff --git a/src/tls/session_ticket.cpp b/src/tls/session_ticket.cpp new file mode 100644 index 000000000..273996a16 --- /dev/null +++ b/src/tls/session_ticket.cpp @@ -0,0 +1,57 @@ +/* +* Session Tickets +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_messages.h> +#include <botan/internal/tls_extensions.h> +#include <botan/internal/tls_reader.h> +#include <botan/tls_record.h> +#include <botan/loadstor.h> + +namespace Botan { + +namespace TLS { + +New_Session_Ticket::New_Session_Ticket(Record_Writer& writer, + Handshake_Hash& hash, + const MemoryRegion<byte>& ticket, + u32bit lifetime) : + m_ticket_lifetime_hint(lifetime), + m_ticket(ticket) + { + hash.update(writer.send(*this)); + } + +New_Session_Ticket::New_Session_Ticket(Record_Writer& writer, + Handshake_Hash& hash) : + m_ticket_lifetime_hint(0) + { + hash.update(writer.send(*this)); + } + +New_Session_Ticket::New_Session_Ticket(const MemoryRegion<byte>& buf) : + m_ticket_lifetime_hint(0) + { + if(buf.size() < 6) + throw Decoding_Error("Session ticket message too short to be valid"); + + TLS_Data_Reader reader(buf); + + m_ticket_lifetime_hint = reader.get_u32bit(); + m_ticket = reader.get_range<byte>(2, 0, 65535); + } + +MemoryVector<byte> New_Session_Ticket::serialize() const + { + MemoryVector<byte> buf(4); + store_be(m_ticket_lifetime_hint, &buf[0]); + append_tls_length_value(buf, m_ticket, 2); + return buf; + } + +} + +} diff --git a/src/tls/sessions_sqlite/info.txt b/src/tls/sessions_sqlite/info.txt new file mode 100644 index 000000000..c5fc35952 --- /dev/null +++ b/src/tls/sessions_sqlite/info.txt @@ -0,0 +1,11 @@ +define TLS_SQLITE_SESSION_MANAGER + +load_on request + +<libs> +all -> sqlite3 +</libs> + +<requires> +pbkdf2 +</requires> diff --git a/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp new file mode 100644 index 000000000..4d78a5365 --- /dev/null +++ b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.cpp @@ -0,0 +1,343 @@ +/* +* SQLite TLS Session Manager +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/tls_sqlite_sess_mgr.h> +#include <botan/internal/assert.h> +#include <botan/lookup.h> +#include <botan/hex.h> +#include <botan/time.h> +#include <botan/loadstor.h> +#include <memory> + +#include <sqlite3.h> + +namespace Botan { + +namespace TLS { + +namespace { + +class sqlite3_statement + { + public: + sqlite3_statement(sqlite3* db, const std::string& base_sql) + { + int rc = sqlite3_prepare_v2(db, base_sql.c_str(), -1, &m_stmt, 0); + + if(rc != SQLITE_OK) + throw std::runtime_error("sqlite3_prepare failed " + base_sql + ", code " + to_string(rc)); + } + + void bind(int column, const std::string& val) + { + int rc = sqlite3_bind_text(m_stmt, column, val.c_str(), -1, SQLITE_TRANSIENT); + if(rc != SQLITE_OK) + throw std::runtime_error("sqlite3_bind_text failed, code " + to_string(rc)); + } + + void bind(int column, int val) + { + int rc = sqlite3_bind_int(m_stmt, column, val); + if(rc != SQLITE_OK) + throw std::runtime_error("sqlite3_bind_int failed, code " + to_string(rc)); + } + + void bind(int column, const MemoryRegion<byte>& val) + { + int rc = sqlite3_bind_blob(m_stmt, column, &val[0], val.size(), SQLITE_TRANSIENT); + if(rc != SQLITE_OK) + throw std::runtime_error("sqlite3_bind_text failed, code " + to_string(rc)); + } + + std::pair<const byte*, size_t> get_blob(int column) + { + BOTAN_ASSERT(sqlite3_column_type(m_stmt, 0) == SQLITE_BLOB, + "Return value is a blob"); + + const void* session_blob = sqlite3_column_blob(m_stmt, column); + const int session_blob_size = sqlite3_column_bytes(m_stmt, column); + + BOTAN_ASSERT(session_blob_size >= 0, "Blob size is non-negative"); + + return std::make_pair(static_cast<const byte*>(session_blob), + static_cast<size_t>(session_blob_size)); + } + + size_t get_size_t(int column) + { + BOTAN_ASSERT(sqlite3_column_type(m_stmt, column) == SQLITE_INTEGER, + "Return count is an integer"); + + const int sessions_int = sqlite3_column_int(m_stmt, column); + + BOTAN_ASSERT(sessions_int >= 0, "Expected size_t is non-negative"); + + return static_cast<size_t>(sessions_int); + } + + void spin() + { + while(sqlite3_step(m_stmt) == SQLITE_ROW) + {} + } + + int step() + { + return sqlite3_step(m_stmt); + } + + sqlite3_stmt* stmt() { return m_stmt; } + + ~sqlite3_statement() { sqlite3_finalize(m_stmt); } + private: + sqlite3_stmt* m_stmt; + }; + +size_t row_count(sqlite3* db, const std::string& table_name) + { + sqlite3_statement stmt(db, "select count(*) from " + table_name); + + if(stmt.step() == SQLITE_ROW) + return stmt.get_size_t(0); + else + throw std::runtime_error("Querying size of table " + table_name + " failed"); + } + +void create_table(sqlite3* db, const char* table_schema) + { + char* errmsg = 0; + int rc = sqlite3_exec(db, table_schema, 0, 0, &errmsg); + + if(rc != SQLITE_OK) + { + const std::string err_msg = errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + throw std::runtime_error("sqlite3_exec for table failed - " + err_msg); + } + } + + +SymmetricKey derive_key(const std::string& passphrase, + const byte salt[], + size_t salt_len, + size_t iterations, + size_t& check_val) + { + std::auto_ptr<PBKDF> pbkdf(get_pbkdf("PBKDF2(SHA-512)")); + + SecureVector<byte> x = pbkdf->derive_key(32 + 3, + passphrase, + salt, salt_len, + iterations).bits_of(); + + check_val = make_u32bit(0, x[0], x[1], x[2]); + return SymmetricKey(&x[3], x.size() - 3); + } + +} + +Session_Manager_SQLite::Session_Manager_SQLite(const std::string& passphrase, + RandomNumberGenerator& rng, + const std::string& db_filename, + size_t max_sessions, + size_t session_lifetime) : + m_rng(rng), + m_max_sessions(max_sessions), + m_session_lifetime(session_lifetime) + { + int rc = sqlite3_open(db_filename.c_str(), &m_db); + + if(rc) + { + const std::string err_msg = sqlite3_errmsg(m_db); + sqlite3_close(m_db); + throw std::runtime_error("sqlite3_open failed - " + err_msg); + } + + create_table(m_db, + "create table if not exists tls_sessions " + "(" + "session_id TEXT PRIMARY KEY, " + "session_start INTEGER, " + "hostname TEXT, " + "hostport INTEGER, " + "session BLOB" + ")"); + + create_table(m_db, + "create table if not exists tls_sessions_metadata " + "(" + "passphrase_salt BLOB, " + "passphrase_iterations INTEGER, " + "passphrase_check INTEGER " + ")"); + + const size_t salts = row_count(m_db, "tls_sessions_metadata"); + + if(salts == 1) + { + // existing db + sqlite3_statement stmt(m_db, "select * from tls_sessions_metadata"); + + int rc = stmt.step(); + if(rc == SQLITE_ROW) + { + std::pair<const byte*, size_t> salt = stmt.get_blob(0); + const size_t iterations = stmt.get_size_t(1); + const size_t check_val_db = stmt.get_size_t(2); + + size_t check_val_created; + m_session_key = derive_key(passphrase, + salt.first, + salt.second, + iterations, + check_val_created); + + if(check_val_created != check_val_db) + throw std::runtime_error("Session database password not valid"); + } + } + else + { + // maybe just zap the salts + sessions tables in this case? + if(salts != 0) + throw std::runtime_error("Seemingly corrupted database, multiple salts found"); + + // new database case + + MemoryVector<byte> salt = rng.random_vec(16); + const size_t iterations = 64 * 1024; + size_t check_val = 0; + + m_session_key = derive_key(passphrase, &salt[0], salt.size(), + iterations, check_val); + + sqlite3_statement stmt(m_db, "insert into tls_sessions_metadata" + " values(?1, ?2, ?3)"); + + stmt.bind(1, salt); + stmt.bind(2, iterations); + stmt.bind(3, check_val); + + stmt.spin(); + } + } + +Session_Manager_SQLite::~Session_Manager_SQLite() + { + sqlite3_close(m_db); + } + +bool Session_Manager_SQLite::load_from_session_id(const MemoryRegion<byte>& session_id, + Session& session) + { + sqlite3_statement stmt(m_db, "select session from tls_sessions where session_id = ?1"); + + stmt.bind(1, hex_encode(session_id)); + + int rc = stmt.step(); + + while(rc == SQLITE_ROW) + { + std::pair<const byte*, size_t> blob = stmt.get_blob(0); + + try + { + session = Session::decrypt(blob.first, blob.second, m_session_key); + return true; + } + catch(...) + { + } + + rc = stmt.step(); + } + + return false; + } + +bool Session_Manager_SQLite::load_from_host_info(const std::string& hostname, + u16bit port, + Session& session) + { + sqlite3_statement stmt(m_db, "select session from tls_sessions" + " where hostname = ?1 and hostport = ?2" + " order by session_start desc"); + + stmt.bind(1, hostname); + stmt.bind(2, port); + + int rc = stmt.step(); + + while(rc == SQLITE_ROW) + { + std::pair<const byte*, size_t> blob = stmt.get_blob(0); + + try + { + session = Session::decrypt(blob.first, blob.second, m_session_key); + return true; + } + catch(...) + { + } + + rc = stmt.step(); + } + + return false; + } + +void Session_Manager_SQLite::remove_entry(const MemoryRegion<byte>& session_id) + { + sqlite3_statement stmt(m_db, "delete from tls_sessions where session_id = ?1"); + + stmt.bind(1, hex_encode(session_id)); + + stmt.spin(); + } + +void Session_Manager_SQLite::save(const Session& session) + { + sqlite3_statement stmt(m_db, "insert or replace into tls_sessions" + " values(?1, ?2, ?3, ?4, ?5)"); + + stmt.bind(1, hex_encode(session.session_id())); + stmt.bind(2, session.start_time()); + stmt.bind(3, session.sni_hostname()); + stmt.bind(4, 0); + stmt.bind(5, session.encrypt(m_session_key, m_rng)); + + stmt.spin(); + + prune_session_cache(); + } + +void Session_Manager_SQLite::prune_session_cache() + { + sqlite3_statement remove_expired(m_db, "delete from tls_sessions where session_start <= ?1"); + + remove_expired.bind(1, system_time() - m_session_lifetime); + + remove_expired.spin(); + + const size_t sessions = row_count(m_db, "tls_sessions"); + + if(sessions > m_max_sessions) + { + sqlite3_statement remove_some(m_db, "delete from tls_sessions where session_id in " + "(select session_id from tls_sessions limit ?1)"); + + remove_some.bind(1, sessions - m_max_sessions); + remove_some.spin(); + } + } + +} + +} diff --git a/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h new file mode 100644 index 000000000..424db24e5 --- /dev/null +++ b/src/tls/sessions_sqlite/tls_sqlite_sess_mgr.h @@ -0,0 +1,68 @@ +/* +* SQLite TLS Session Manager +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef TLS_SQLITE_SESSION_MANAGER_H__ +#define TLS_SQLITE_SESSION_MANAGER_H__ + +#include <botan/tls_session_manager.h> +#include <botan/rng.h> + +class sqlite3; + +namespace Botan { + +namespace TLS { + +/** +*/ +class BOTAN_DLL Session_Manager_SQLite : public Session_Manager + { + public: + /** + * @param passphrase used to encrypt the session data + * @param db_filename filename of the SQLite database file. + The table names tls_sessions and tls_sessions_metadata + will be used + * @param max_sessions a hint on the maximum number of sessions + * to keep in memory at any one time. (If zero, don't cap) + * @param session_lifetime sessions are expired after this many + * seconds have elapsed from initial handshake. + */ + Session_Manager_SQLite(const std::string& passphrase, + RandomNumberGenerator& rng, + const std::string& db_filename, + size_t max_sessions = 1000, + size_t session_lifetime = 7200); + + ~Session_Manager_SQLite(); + + bool load_from_session_id(const MemoryRegion<byte>& session_id, + Session& session); + + bool load_from_host_info(const std::string& hostname, u16bit port, + Session& session); + + void remove_entry(const MemoryRegion<byte>& session_id); + + void save(const Session& session_data); + private: + Session_Manager_SQLite(const Session_Manager_SQLite&); + Session_Manager_SQLite& operator=(const Session_Manager_SQLite&); + + void prune_session_cache(); + + SymmetricKey m_session_key; + RandomNumberGenerator& m_rng; + size_t m_max_sessions, m_session_lifetime; + class sqlite3* m_db; + }; + +} + +} + +#endif diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index fa240cc23..736b37654 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -142,49 +142,38 @@ void Channel::read_handshake(byte rec_type, if(rec_type == HANDSHAKE) { if(!state) - state = new Handshake_State; - state->queue.write(&rec_buf[0], rec_buf.size()); + state = new Handshake_State(new Stream_Handshake_Reader); + state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size()); } + BOTAN_ASSERT(state, "Handshake message recieved without state in place"); + while(true) { Handshake_Type type = HANDSHAKE_NONE; - MemoryVector<byte> contents; if(rec_type == HANDSHAKE) { - if(state->queue.size() >= 4) + if(state->handshake_reader()->have_full_record()) { - byte head[4] = { 0 }; - state->queue.peek(head, 4); - - const size_t length = make_u32bit(0, head[1], head[2], head[3]); - - if(state->queue.size() >= length + 4) - { - type = static_cast<Handshake_Type>(head[0]); - contents.resize(length); - state->queue.read(head, 4); - state->queue.read(&contents[0], contents.size()); - } + std::pair<Handshake_Type, MemoryVector<byte> > msg = + state->handshake_reader()->get_next_record(); + process_handshake_msg(msg.first, msg.second); } + else + break; } else if(rec_type == CHANGE_CIPHER_SPEC) { - if(state->queue.size() == 0 && rec_buf.size() == 1 && rec_buf[0] == 1) - type = HANDSHAKE_CCS; + if(state->handshake_reader()->empty() && rec_buf.size() == 1 && rec_buf[0] == 1) + process_handshake_msg(HANDSHAKE_CCS, MemoryVector<byte>()); else throw Decoding_Error("Malformed ChangeCipherSpec message"); } else throw Decoding_Error("Unknown message type in handshake processing"); - if(type == HANDSHAKE_NONE) - break; - - process_handshake_msg(type, contents); - - if(type == HANDSHAKE_CCS || !state) + if(type == HANDSHAKE_CCS || !state || !state->handshake_reader()->have_full_record()) break; } } diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index c85b32ba0..9ab89e5f7 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -63,8 +63,8 @@ class BOTAN_DLL Channel std::vector<X509_Certificate> peer_cert_chain() const { return peer_certs; } Channel(std::function<void (const byte[], size_t)> socket_output_fn, - std::function<void (const byte[], size_t, Alert)> proc_fn, - std::function<bool (const Session&)> handshake_complete); + std::function<void (const byte[], size_t, Alert)> proc_fn, + std::function<bool (const Session&)> handshake_complete); virtual ~Channel(); protected: diff --git a/src/tls/tls_ciphersuite.cpp b/src/tls/tls_ciphersuite.cpp index 22815a048..89daaf679 100644 --- a/src/tls/tls_ciphersuite.cpp +++ b/src/tls/tls_ciphersuite.cpp @@ -309,7 +309,7 @@ std::string Ciphersuite::to_string() const { if(cipher_algo() == "3DES") out << "3DES_EDE"; - if(cipher_algo() == "Camellia") + else if(cipher_algo() == "Camellia") out << "CAMELLIA_" << std::to_string(8*cipher_keylen()); else out << replace_char(cipher_algo(), '-', '_'); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 2dd30819f..77e87e9bc 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -19,14 +19,14 @@ namespace TLS { * TLS Client Constructor */ Client::Client(std::function<void (const byte[], size_t)> output_fn, - std::function<void (const byte[], size_t, Alert)> proc_fn, - std::function<bool (const Session&)> handshake_fn, - Session_Manager& session_manager, - Credentials_Manager& creds, - const Policy& policy, - RandomNumberGenerator& rng, - const std::string& hostname, - std::function<std::string (std::vector<std::string>)> next_protocol) : + std::function<void (const byte[], size_t, Alert)> proc_fn, + std::function<bool (const Session&)> handshake_fn, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const std::string& hostname, + std::function<std::string (std::vector<std::string>)> next_protocol) : Channel(output_fn, proc_fn, handshake_fn), policy(policy), rng(rng), @@ -35,7 +35,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, { writer.set_version(Protocol_Version::SSL_V3); - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(SERVER_HELLO); state->client_npn_cb = next_protocol; @@ -54,6 +54,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, state->client_hello = new Client_Hello( writer, state->hash, + policy, rng, session_info, send_npn_request); @@ -87,7 +88,7 @@ void Client::renegotiate() if(state) return; // currently in handshake - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(SERVER_HELLO); state->client_hello = new Client_Hello(writer, state->hash, policy, rng, @@ -112,7 +113,7 @@ void Client::alert_notify(const Alert& alert) * Process a handshake message */ void Client::process_handshake_msg(Handshake_Type type, - const MemoryRegion<byte>& contents) + const MemoryRegion<byte>& contents) { if(state == 0) throw Unexpected_Message("Unexpected handshake message from server"); @@ -135,12 +136,12 @@ void Client::process_handshake_msg(Handshake_Type type, return; } - state->set_expected_next(SERVER_HELLO); state->client_hello = new Client_Hello(writer, state->hash, policy, rng, secure_renegotiation.for_client_hello()); - secure_renegotiation.update(state->client_hello); + state->set_expected_next(SERVER_HELLO); + return; } @@ -173,17 +174,27 @@ void Client::process_handshake_msg(Handshake_Type type, "Server sent next protocol but we didn't request it"); } - state->version = state->server_hello->version(); + if(state->server_hello->supports_session_ticket()) + { + if(!state->client_hello->supports_session_ticket()) + throw TLS_Exception(Alert::HANDSHAKE_FAILURE, + "Server sent session ticket extension but we did not"); + } + + state->set_version(state->server_hello->version()); - writer.set_version(state->version); - reader.set_version(state->version); + writer.set_version(state->version()); + reader.set_version(state->version()); secure_renegotiation.update(state->server_hello); state->suite = Ciphersuite::by_id(state->server_hello->ciphersuite()); - if(!state->server_hello->session_id().empty() && - (state->server_hello->session_id() == state->client_hello->session_id())) + const bool server_returned_same_session_id = + !state->server_hello->session_id().empty() && + (state->server_hello->session_id() == state->client_hello->session_id()); + + if(server_returned_same_session_id) { // successful resumption @@ -199,19 +210,22 @@ void Client::process_handshake_msg(Handshake_Type type, state->resume_master_secret, true); - state->set_expected_next(HANDSHAKE_CCS); + if(state->server_hello->supports_session_ticket()) + state->set_expected_next(NEW_SESSION_TICKET); + else + state->set_expected_next(HANDSHAKE_CCS); } else { // new session - if(state->version > state->client_hello->version()) + if(state->version() > state->client_hello->version()) { throw TLS_Exception(Alert::HANDSHAKE_FAILURE, "Client: Server replied with bad version"); } - if(state->version < policy.min_version()) + if(state->version() < policy.min_version()) { throw TLS_Exception(Alert::PROTOCOL_VERSION, "Client: Server is too old for specified policy"); @@ -288,7 +302,7 @@ void Client::process_handshake_msg(Handshake_Type type, state->server_kex = new Server_Key_Exchange(contents, state->suite.kex_algo(), state->suite.sig_algo(), - state->version); + state->version()); if(state->suite.sig_algo() != "") { @@ -302,12 +316,10 @@ void Client::process_handshake_msg(Handshake_Type type, else if(type == CERTIFICATE_REQUEST) { state->set_expected_next(SERVER_HELLO_DONE); - state->cert_req = new Certificate_Req(contents, state->version); + state->cert_req = new Certificate_Req(contents, state->version()); } else if(type == SERVER_HELLO_DONE) { - state->set_expected_next(HANDSHAKE_CCS); - state->server_hello_done = new Server_Hello_Done(contents); if(state->received_handshake_msg(CERTIFICATE_REQUEST)) @@ -364,6 +376,17 @@ void Client::process_handshake_msg(Handshake_Type type, } state->client_finished = new Finished(writer, state, CLIENT); + + if(state->server_hello->supports_session_ticket()) + state->set_expected_next(NEW_SESSION_TICKET); + else + state->set_expected_next(HANDSHAKE_CCS); + } + else if(type == NEW_SESSION_TICKET) + { + state->new_session_ticket = new New_Session_Ticket(contents); + + state->set_expected_next(HANDSHAKE_CCS); } else if(type == HANDSHAKE_CCS) { @@ -394,8 +417,17 @@ void Client::process_handshake_msg(Handshake_Type type, state->client_finished = new Finished(writer, state, CLIENT); } + secure_renegotiation.update(state->client_finished, state->server_finished); + + MemoryVector<byte> session_id = state->server_hello->session_id(); + + const MemoryRegion<byte>& session_ticket = state->session_ticket(); + + if(session_id.empty() && !session_ticket.empty()) + session_id = make_hello_random(rng); + Session session_info( - state->server_hello->session_id(), + session_id, state->keys.master_secret(), state->server_hello->version(), state->server_hello->ciphersuite(), @@ -404,6 +436,7 @@ void Client::process_handshake_msg(Handshake_Type type, secure_renegotiation.supported(), state->server_hello->fragment_size(), peer_certs, + session_ticket, state->client_hello->sni_hostname(), "" ); @@ -413,8 +446,6 @@ void Client::process_handshake_msg(Handshake_Type type, else session_manager.remove_entry(session_info.session_id()); - secure_renegotiation.update(state->client_finished, state->server_finished); - delete state; state = 0; handshake_completed = true; diff --git a/src/tls/tls_extensions.cpp b/src/tls/tls_extensions.cpp index 7162dcf40..c0de24bfe 100644 --- a/src/tls/tls_extensions.cpp +++ b/src/tls/tls_extensions.cpp @@ -42,6 +42,9 @@ Extension* make_extension(TLS_Data_Reader& reader, case TLSEXT_NEXT_PROTOCOL: return new Next_Protocol_Notification(reader, size); + case TLSEXT_SESSION_TICKET: + return new Session_Ticket(reader, size); + default: return 0; // not known } @@ -501,6 +504,12 @@ Signature_Algorithms::Signature_Algorithms(TLS_Data_Reader& reader, } } +Session_Ticket::Session_Ticket(TLS_Data_Reader& reader, + u16bit extension_size) + { + m_ticket = reader.get_elem<byte, MemoryVector<byte> >(extension_size); + } + } } diff --git a/src/tls/tls_extensions.h b/src/tls/tls_extensions.h index 180216b8b..6a97d2560 100644 --- a/src/tls/tls_extensions.h +++ b/src/tls/tls_extensions.h @@ -1,6 +1,6 @@ /* * TLS Extensions -* (C) 2011 Jack Lloyd +* (C) 2011-2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -210,6 +210,39 @@ class Next_Protocol_Notification : public Extension std::vector<std::string> m_protocols; }; +class Session_Ticket : public Extension + { + public: + static Handshake_Extension_Type static_type() + { return TLSEXT_SESSION_TICKET; } + + Handshake_Extension_Type type() const { return static_type(); } + + const MemoryVector<byte>& contents() const { return m_ticket; } + + /** + * Create empty extension, used by both client and server + */ + Session_Ticket() {} + + /** + * Extension with ticket, used by client + */ + Session_Ticket(const MemoryRegion<byte>& session_ticket) : + m_ticket(session_ticket) {} + + /** + * Deserialize a session ticket + */ + Session_Ticket(TLS_Data_Reader& reader, u16bit extension_size); + + MemoryVector<byte> serialize() const { return m_ticket; } + + bool empty() const { return false; } + private: + MemoryVector<byte> m_ticket; + }; + /** * Supported Elliptic Curves Extension (RFC 4492) */ diff --git a/src/tls/tls_handshake_reader.cpp b/src/tls/tls_handshake_reader.cpp new file mode 100644 index 000000000..8278a2296 --- /dev/null +++ b/src/tls/tls_handshake_reader.cpp @@ -0,0 +1,66 @@ +/* +* TLS Handshake Reader +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#include <botan/internal/tls_handshake_reader.h> +#include <botan/exceptn.h> + +namespace Botan { + +namespace TLS { + +void Stream_Handshake_Reader::add_input(const byte record[], + size_t record_size) + { + m_queue.write(record, record_size); + } + +bool Stream_Handshake_Reader::empty() const + { + return m_queue.empty(); + } + +bool Stream_Handshake_Reader::have_full_record() const + { + if(m_queue.size() >= 4) + { + byte head[4] = { 0 }; + m_queue.peek(head, 4); + + const size_t length = make_u32bit(0, head[1], head[2], head[3]); + + return (m_queue.size() >= length + 4); + } + + return false; + } + +std::pair<Handshake_Type, MemoryVector<byte> > Stream_Handshake_Reader::get_next_record() + { + if(m_queue.size() >= 4) + { + byte head[4] = { 0 }; + m_queue.peek(head, 4); + + const size_t length = make_u32bit(0, head[1], head[2], head[3]); + + if(m_queue.size() >= length + 4) + { + Handshake_Type type = static_cast<Handshake_Type>(head[0]); + MemoryVector<byte> contents(length); + m_queue.read(head, 4); // discard + m_queue.read(&contents[0], contents.size()); + + return std::make_pair(type, contents); + } + } + + throw Internal_Error("Stream_Handshake_Reader::get_next_record called without a full record"); + } + +} + +} diff --git a/src/tls/tls_handshake_reader.h b/src/tls/tls_handshake_reader.h new file mode 100644 index 000000000..06a273ced --- /dev/null +++ b/src/tls/tls_handshake_reader.h @@ -0,0 +1,58 @@ +/* +* TLS Handshake Reader +* (C) 2012 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef BOTAN_TLS_HANDSHAKE_READER_H__ +#define BOTAN_TLS_HANDSHAKE_READER_H__ + +#include <botan/tls_magic.h> +#include <botan/secqueue.h> +#include <botan/loadstor.h> +#include <utility> + +namespace Botan { + +namespace TLS { + +/** +* Handshake Reader Interface +*/ +class Handshake_Reader + { + public: + virtual void add_input(const byte record[], size_t record_size) = 0; + + virtual bool empty() const = 0; + + virtual bool have_full_record() const = 0; + + virtual std::pair<Handshake_Type, MemoryVector<byte> > get_next_record() = 0; + + virtual ~Handshake_Reader() {} + }; + +/** +* Reader of TLS handshake messages +*/ +class Stream_Handshake_Reader : public Handshake_Reader + { + public: + void add_input(const byte record[], size_t record_size); + + bool empty() const; + + bool have_full_record() const; + + std::pair<Handshake_Type, MemoryVector<byte> > get_next_record(); + private: + SecureQueue m_queue; + }; + +} + +} + +#endif diff --git a/src/tls/tls_handshake_state.cpp b/src/tls/tls_handshake_state.cpp index 86e6e0b55..b34d8616d 100644 --- a/src/tls/tls_handshake_state.cpp +++ b/src/tls/tls_handshake_state.cpp @@ -54,12 +54,15 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) case NEXT_PROTOCOL: return (1 << 9); - case HANDSHAKE_CCS: + case NEW_SESSION_TICKET: return (1 << 10); - case FINISHED: + case HANDSHAKE_CCS: return (1 << 11); + case FINISHED: + return (1 << 12); + // allow explicitly disabling new handshakes case HANDSHAKE_NONE: return 0; @@ -76,7 +79,7 @@ u32bit bitmask_for_handshake_type(Handshake_Type type) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State() +Handshake_State::Handshake_State(Handshake_Reader* reader) { client_hello = 0; server_hello = 0; @@ -85,6 +88,7 @@ Handshake_State::Handshake_State() cert_req = 0; server_hello_done = 0; next_protocol = 0; + new_session_ticket = 0; client_certs = 0; client_kex = 0; @@ -92,14 +96,21 @@ Handshake_State::Handshake_State() client_finished = 0; server_finished = 0; + m_handshake_reader = reader; + server_rsa_kex_key = 0; - version = Protocol_Version::SSL_V3; + m_version = Protocol_Version::SSL_V3; hand_expecting_mask = 0; hand_received_mask = 0; } +void Handshake_State::set_version(const Protocol_Version& version) + { + m_version = version; + } + void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) { const u32bit mask = bitmask_for_handshake_type(handshake_msg); @@ -132,17 +143,25 @@ bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const return (hand_received_mask & mask); } +const MemoryRegion<byte>& Handshake_State::session_ticket() const + { + if(new_session_ticket && !new_session_ticket->ticket().empty()) + return new_session_ticket->ticket(); + + return client_hello->session_ticket(); + } + KDF* Handshake_State::protocol_specific_prf() { - if(version == Protocol_Version::SSL_V3) + if(version() == Protocol_Version::SSL_V3) { return get_kdf("SSL3-PRF"); } - else if(version == Protocol_Version::TLS_V10 || version == Protocol_Version::TLS_V11) + else if(version() == Protocol_Version::TLS_V10 || version() == Protocol_Version::TLS_V11) { return get_kdf("TLS-PRF"); } - else if(version == Protocol_Version::TLS_V12) + else if(version() == Protocol_Version::TLS_V12) { if(suite.mac_algo() == "SHA-1" || suite.mac_algo() == "SHA-256") return get_kdf("TLS-12-PRF(SHA-256)"); @@ -150,7 +169,7 @@ KDF* Handshake_State::protocol_specific_prf() return get_kdf("TLS-12-PRF(" + suite.mac_algo() + ")"); } - throw Internal_Error("Unknown version code " + version.to_string()); + throw Internal_Error("Unknown version code " + version().to_string()); } std::pair<std::string, Signature_Format> @@ -175,15 +194,15 @@ Handshake_State::choose_sig_format(const Private_Key* key, } } - if(for_client_auth && this->version == Protocol_Version::SSL_V3) + if(for_client_auth && this->version() == Protocol_Version::SSL_V3) hash_algo = "Raw"; - if(hash_algo == "" && this->version == Protocol_Version::TLS_V12) + if(hash_algo == "" && this->version() == Protocol_Version::TLS_V12) hash_algo = "SHA-1"; // TLS 1.2 but no compatible hashes set (?) BOTAN_ASSERT(hash_algo != "", "Couldn't figure out hash to use"); - if(this->version >= Protocol_Version::TLS_V12) + if(this->version() >= Protocol_Version::TLS_V12) { hash_algo_out = hash_algo; sig_algo_out = sig_algo; @@ -221,7 +240,7 @@ Handshake_State::understand_sig_format(const Public_Key* key, Or not? */ - if(this->version < Protocol_Version::TLS_V12) + if(this->version() < Protocol_Version::TLS_V12) { if(hash_algo != "" || sig_algo != "") throw Decoding_Error("Counterparty sent hash/sig IDs with old version"); @@ -237,11 +256,11 @@ Handshake_State::understand_sig_format(const Public_Key* key, if(algo_name == "RSA") { - if(for_client_auth && this->version == Protocol_Version::SSL_V3) + if(for_client_auth && this->version() == Protocol_Version::SSL_V3) { hash_algo = "Raw"; } - else if(this->version < Protocol_Version::TLS_V12) + else if(this->version() < Protocol_Version::TLS_V12) { hash_algo = "TLS.Digest.0"; } @@ -251,11 +270,11 @@ Handshake_State::understand_sig_format(const Public_Key* key, } else if(algo_name == "DSA" || algo_name == "ECDSA") { - if(algo_name == "DSA" && for_client_auth && this->version == Protocol_Version::SSL_V3) + if(algo_name == "DSA" && for_client_auth && this->version() == Protocol_Version::SSL_V3) { hash_algo = "Raw"; } - else if(this->version < Protocol_Version::TLS_V12) + else if(this->version() < Protocol_Version::TLS_V12) { hash_algo = "SHA-1"; } @@ -280,12 +299,15 @@ Handshake_State::~Handshake_State() delete cert_req; delete server_hello_done; delete next_protocol; + delete new_session_ticket; delete client_certs; delete client_kex; delete client_verify; delete client_finished; delete server_finished; + + delete m_handshake_reader; } } diff --git a/src/tls/tls_handshake_state.h b/src/tls/tls_handshake_state.h index 4e1cb2f25..d2f64c43b 100644 --- a/src/tls/tls_handshake_state.h +++ b/src/tls/tls_handshake_state.h @@ -9,8 +9,8 @@ #define BOTAN_TLS_HANDSHAKE_STATE_H__ #include <botan/internal/tls_handshake_hash.h> +#include <botan/internal/tls_handshake_reader.h> #include <botan/internal/tls_session_key.h> -#include <botan/secqueue.h> #include <botan/pk_keys.h> #include <botan/pubkey.h> @@ -29,7 +29,7 @@ namespace TLS { class Handshake_State { public: - Handshake_State(); + Handshake_State(Handshake_Reader* reader); ~Handshake_State(); bool received_handshake_msg(Handshake_Type handshake_msg) const; @@ -37,6 +37,8 @@ class Handshake_State void confirm_transition_to(Handshake_Type handshake_msg); void set_expected_next(Handshake_Type handshake_msg); + const MemoryRegion<byte>& session_ticket() const; + std::pair<std::string, Signature_Format> understand_sig_format(const Public_Key* key, std::string hash_algo, @@ -51,7 +53,9 @@ class Handshake_State KDF* protocol_specific_prf(); - Protocol_Version version; + Protocol_Version version() const { return m_version; } + + void set_version(const Protocol_Version& version); class Client_Hello* client_hello; class Server_Hello* server_hello; @@ -65,6 +69,7 @@ class Handshake_State class Certificate_Verify* client_verify; class Next_Protocol* next_protocol; + class New_Session_Ticket* new_session_ticket; class Finished* client_finished; class Finished* server_finished; @@ -76,8 +81,6 @@ class Handshake_State Session_Keys keys; Handshake_Hash hash; - SecureQueue queue; - /* * Only used by clients for session resumption */ @@ -88,8 +91,11 @@ class Handshake_State */ std::function<std::string (std::vector<std::string>)> client_npn_cb; + Handshake_Reader* handshake_reader() { return m_handshake_reader; } private: + Handshake_Reader* m_handshake_reader; u32bit hand_expecting_mask, hand_received_mask; + Protocol_Version m_version; }; } diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h index 72a430bf2..0e45407d3 100644 --- a/src/tls/tls_magic.h +++ b/src/tls/tls_magic.h @@ -36,23 +36,24 @@ enum Record_Type { }; enum Handshake_Type { - HELLO_REQUEST = 0, - CLIENT_HELLO = 1, - CLIENT_HELLO_SSLV2 = 200, // Not a wire value - SERVER_HELLO = 2, - NEW_SESSION_TICKET = 4, // RFC 5077 - CERTIFICATE = 11, - SERVER_KEX = 12, - CERTIFICATE_REQUEST = 13, - SERVER_HELLO_DONE = 14, - CERTIFICATE_VERIFY = 15, - CLIENT_KEX = 16, - FINISHED = 20, - - NEXT_PROTOCOL = 67, - - HANDSHAKE_CCS = 100, // Not a wire value - HANDSHAKE_NONE = 255 // Null value + HELLO_REQUEST = 0, + CLIENT_HELLO = 1, + CLIENT_HELLO_SSLV2 = 253, // Not a wire value + SERVER_HELLO = 2, + HELLO_VERIFY_REQUEST = 3, + NEW_SESSION_TICKET = 4, // RFC 5077 + CERTIFICATE = 11, + SERVER_KEX = 12, + CERTIFICATE_REQUEST = 13, + SERVER_HELLO_DONE = 14, + CERTIFICATE_VERIFY = 15, + CLIENT_KEX = 16, + FINISHED = 20, + + NEXT_PROTOCOL = 67, + + HANDSHAKE_CCS = 254, // Not a wire value + HANDSHAKE_NONE = 255 // Null value }; enum Ciphersuite_Code { diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 7162ece1a..2f8af5fd2 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -33,19 +33,39 @@ class Record_Reader; class Handshake_Message { public: - void send(Record_Writer& writer, Handshake_Hash& hash) const; - + virtual MemoryVector<byte> serialize() const = 0; virtual Handshake_Type type() const = 0; + Handshake_Message() {} virtual ~Handshake_Message() {} private: + Handshake_Message(const Handshake_Message&) {} Handshake_Message& operator=(const Handshake_Message&) { return (*this); } - virtual MemoryVector<byte> serialize() const = 0; }; MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng); /** +* DTLS Hello Verify Request +*/ +class Hello_Verify_Request : public Handshake_Message + { + public: + MemoryVector<byte> serialize() const; + Handshake_Type type() const { return HELLO_VERIFY_REQUEST; } + + MemoryVector<byte> cookie() const { return m_cookie; } + + Hello_Verify_Request(const MemoryRegion<byte>& buf); + + Hello_Verify_Request(const MemoryVector<byte>& client_hello_bits, + const std::string& client_identity, + const SymmetricKey& secret_key); + private: + MemoryVector<byte> m_cookie; + }; + +/** * Client Hello Message */ class Client_Hello : public Handshake_Message @@ -88,6 +108,11 @@ class Client_Hello : public Handshake_Message size_t fragment_size() const { return m_fragment_size; } + bool supports_session_ticket() const { return m_supports_session_ticket; } + + const MemoryRegion<byte>& session_ticket() const + { return m_session_ticket; } + Client_Hello(Record_Writer& writer, Handshake_Hash& hash, const Policy& policy, @@ -99,6 +124,7 @@ class Client_Hello : public Handshake_Message Client_Hello(Record_Writer& writer, Handshake_Hash& hash, + const Policy& policy, RandomNumberGenerator& rng, const Session& resumed_session, bool next_protocol = false); @@ -125,6 +151,9 @@ class Client_Hello : public Handshake_Message std::vector<std::pair<std::string, std::string> > m_supported_algos; std::vector<std::string> m_supported_curves; + + bool m_supports_session_ticket; + MemoryVector<byte> m_session_ticket; }; /** @@ -150,6 +179,8 @@ class Server_Hello : public Handshake_Message bool next_protocol_notification() const { return m_next_protocol; } + bool supports_session_ticket() const { return m_supports_session_ticket; } + const std::vector<std::string>& next_protocols() const { return m_next_protocols; } @@ -166,6 +197,7 @@ class Server_Hello : public Handshake_Message const Client_Hello& other, const std::vector<std::string>& available_cert_types, const Policy& policies, + bool have_session_ticket_key, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, bool client_has_npn, @@ -181,6 +213,7 @@ class Server_Hello : public Handshake_Message size_t max_fragment_size, bool client_has_secure_renegotiation, const MemoryRegion<byte>& reneg_info, + bool client_supports_session_tickets, bool client_has_npn, const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng); @@ -200,6 +233,7 @@ class Server_Hello : public Handshake_Message bool m_next_protocol; std::vector<std::string> m_next_protocols; + bool m_supports_session_ticket; }; /** @@ -238,10 +272,10 @@ class Certificate : public Handshake_Message { public: Handshake_Type type() const { return CERTIFICATE; } - const std::vector<X509_Certificate>& cert_chain() const { return certs; } + const std::vector<X509_Certificate>& cert_chain() const { return m_certs; } - size_t count() const { return certs.size(); } - bool empty() const { return certs.empty(); } + size_t count() const { return m_certs.size(); } + bool empty() const { return m_certs.empty(); } Certificate(Record_Writer& writer, Handshake_Hash& hash, @@ -251,7 +285,7 @@ class Certificate : public Handshake_Message private: MemoryVector<byte> serialize() const; - std::vector<X509_Certificate> certs; + std::vector<X509_Certificate> m_certs; }; /** @@ -434,6 +468,30 @@ class Next_Protocol : public Handshake_Message std::string m_protocol; }; +class New_Session_Ticket : public Handshake_Message + { + public: + Handshake_Type type() const { return NEW_SESSION_TICKET; } + + u32bit ticket_lifetime_hint() const { return m_ticket_lifetime_hint; } + const MemoryVector<byte>& ticket() const { return m_ticket; } + + New_Session_Ticket(Record_Writer& writer, + Handshake_Hash& hash, + const MemoryRegion<byte>& ticket, + u32bit lifetime = 0); + + New_Session_Ticket(Record_Writer& writer, + Handshake_Hash& hash); + + New_Session_Ticket(const MemoryRegion<byte>& buf); + private: + MemoryVector<byte> serialize() const; + + u32bit m_ticket_lifetime_hint; + MemoryVector<byte> m_ticket; + }; + } } diff --git a/src/tls/tls_policy.cpp b/src/tls/tls_policy.cpp index 49f74975b..1ab55f7c6 100644 --- a/src/tls/tls_policy.cpp +++ b/src/tls/tls_policy.cpp @@ -89,6 +89,16 @@ std::vector<std::string> Policy::allowed_ecc_curves() const return curves; } +Protocol_Version Policy::min_version() const + { + return Protocol_Version::SSL_V3; + } + +Protocol_Version Policy::pref_version() const + { + return Protocol_Version::TLS_V12; + } + namespace { class Ciphersuite_Preference_Ordering diff --git a/src/tls/tls_policy.h b/src/tls/tls_policy.h index cd00331a5..f53b9bab6 100644 --- a/src/tls/tls_policy.h +++ b/src/tls/tls_policy.h @@ -97,14 +97,12 @@ class BOTAN_DLL Policy /** * @return the minimum version that we are willing to negotiate */ - virtual Protocol_Version min_version() const - { return Protocol_Version::SSL_V3; } + virtual Protocol_Version min_version() const; /** * @return the version we would prefer to negotiate */ - virtual Protocol_Version pref_version() const - { return Protocol_Version::TLS_V12; } + virtual Protocol_Version pref_version() const; /** * Return allowed ciphersuites, in order of preference diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index 8c2e9efe2..bf8098bed 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -50,6 +50,15 @@ class TLS_Data_Reader offset += bytes; } + u16bit get_u32bit() + { + assert_at_least(4); + u16bit result = make_u32bit(buf[offset ], buf[offset+1], + buf[offset+2], buf[offset+3]); + offset += 4; + return result; + } + u16bit get_u16bit() { assert_at_least(2); diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 6634810df..b966e3c72 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -33,6 +33,8 @@ class BOTAN_DLL Record_Writer void send(byte type, const byte input[], size_t length); void send(byte type, byte val) { send(type, &val, 1); } + MemoryVector<byte> send(class Handshake_Message& msg); + void send_alert(const Alert& alert); void activate(Connection_Side side, diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 129d5346d..3b4a526cc 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -20,16 +20,35 @@ namespace { bool check_for_resume(Session& session_info, Session_Manager& session_manager, + Credentials_Manager& credentials, Client_Hello* client_hello) { - MemoryVector<byte> client_session_id = client_hello->session_id(); + const MemoryVector<byte>& client_session_id = client_hello->session_id(); + const MemoryVector<byte>& session_ticket = client_hello->session_ticket(); - if(client_session_id.empty()) // not resuming - return false; + if(session_ticket.empty()) + { + if(client_session_id.empty()) // not resuming + return false; - // not found - if(!session_manager.load_from_session_id(client_session_id, session_info)) - return false; + // not found + if(!session_manager.load_from_session_id(client_session_id, session_info)) + return false; + } + else + { + // If a session ticket was sent, ignore client session ID + try + { + session_info = Session::decrypt( + session_ticket, + credentials.psk("tls-server", "session-ticket", "")); + } + catch(...) + { + return false; + } + } // wrong version if(client_hello->version() != session_info.version()) @@ -45,14 +64,14 @@ bool check_for_resume(Session& session_info, session_info.compression_method())) return false; - // client sent a different SRP identity (!!!) + // client sent a different SRP identity if(client_hello->srp_identifier() != "") { if(client_hello->srp_identifier() != session_info.srp_identifier()) return false; } - // client sent a different SNI hostname (!!!) + // client sent a different SNI hostname if(client_hello->sni_hostname() != "") { if(client_hello->sni_hostname() != session_info.sni_hostname()) @@ -112,7 +131,7 @@ void Server::renegotiate() if(state) return; // currently in handshake - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(CLIENT_HELLO); Hello_Request hello_req(writer); } @@ -137,7 +156,7 @@ void Server::read_handshake(byte rec_type, { if(rec_type == HANDSHAKE && !state) { - state = new Handshake_State; + state = new Handshake_State(new Stream_Handshake_Reader); state->set_expected_next(CLIENT_HELLO); } @@ -183,20 +202,30 @@ void Server::process_handshake_msg(Handshake_Type type, "Client version is unacceptable by policy"); if(client_version <= policy.pref_version()) - state->version = client_version; + state->set_version(client_version); else - state->version = policy.pref_version(); + state->set_version(policy.pref_version()); secure_renegotiation.update(state->client_hello); - writer.set_version(state->version); - reader.set_version(state->version); + writer.set_version(state->version()); + reader.set_version(state->version()); Session session_info; const bool resuming = check_for_resume(session_info, session_manager, + creds, state->client_hello); + bool have_session_ticket_key = false; + + try + { + have_session_ticket_key = + creds.psk("tls-server", "session-ticket", "").length() > 0; + } + catch(...) {} + if(resuming) { // resume session @@ -204,13 +233,14 @@ void Server::process_handshake_msg(Handshake_Type type, state->server_hello = new Server_Hello( writer, state->hash, - session_info.session_id(), + state->client_hello->session_id(), Protocol_Version(session_info.version()), session_info.ciphersuite_code(), session_info.compression_method(), session_info.fragment_size(), secure_renegotiation.supported(), secure_renegotiation.for_server_hello(), + state->client_hello->supports_session_ticket() && have_session_ticket_key, state->client_hello->next_protocol_notification(), m_possible_protocols, rng); @@ -225,6 +255,31 @@ void Server::process_handshake_msg(Handshake_Type type, state->keys = Session_Keys(state, session_info.master_secret(), true); + if(!handshake_fn(session_info)) + { + if(state->server_hello->supports_session_ticket()) + state->new_session_ticket = new New_Session_Ticket(writer, state->hash); + else + session_manager.remove_entry(session_info.session_id()); + } + + // FIXME: should only send a new ticket if we need too (eg old session) + if(state->server_hello->supports_session_ticket() && !state->new_session_ticket) + { + try + { + const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", ""); + + state->new_session_ticket = + new New_Session_Ticket(writer, state->hash, + session_info.encrypt(ticket_key, rng)); + } + catch(...) {} + + if(!state->new_session_ticket) + state->new_session_ticket = new New_Session_Ticket(writer, state->hash); + } + writer.send(CHANGE_CIPHER_SPEC, 1); writer.activate(SERVER, state->suite, state->keys, @@ -232,9 +287,6 @@ void Server::process_handshake_msg(Handshake_Type type, state->server_finished = new Finished(writer, state, SERVER); - if(!handshake_fn(session_info)) - session_manager.remove_entry(session_info.session_id()); - state->set_expected_next(HANDSHAKE_CCS); } else // new session @@ -261,10 +313,11 @@ void Server::process_handshake_msg(Handshake_Type type, state->server_hello = new Server_Hello( writer, state->hash, - state->version, + state->version(), *(state->client_hello), available_cert_types, policy, + have_session_ticket_key, secure_renegotiation.supported(), secure_renegotiation.for_server_hello(), state->client_hello->next_protocol_notification(), @@ -323,7 +376,7 @@ void Server::process_handshake_msg(Handshake_Type type, state->hash, policy, client_auth_CAs, - state->version); + state->version()); state->set_expected_next(CERTIFICATE); } @@ -364,7 +417,7 @@ void Server::process_handshake_msg(Handshake_Type type, } else if(type == CERTIFICATE_VERIFY) { - state->client_verify = new Certificate_Verify(contents, state->version); + state->client_verify = new Certificate_Verify(contents, state->version()); const std::vector<X509_Certificate>& client_certs = state->client_certs->cert_chain(); @@ -421,11 +474,48 @@ void Server::process_handshake_msg(Handshake_Type type, throw TLS_Exception(Alert::DECRYPT_ERROR, "Finished message didn't verify"); - // already sent it if resuming if(!state->server_finished) { + // already sent finished if resuming, so this is a new session + state->hash.update(type, contents); + Session session_info( + state->server_hello->session_id(), + state->keys.master_secret(), + state->server_hello->version(), + state->server_hello->ciphersuite(), + state->server_hello->compression_method(), + SERVER, + secure_renegotiation.supported(), + state->server_hello->fragment_size(), + peer_certs, + MemoryVector<byte>(), + m_hostname, + "" + ); + + if(handshake_fn(session_info)) + { + if(state->server_hello->supports_session_ticket()) + { + try + { + const SymmetricKey ticket_key = creds.psk("tls-server", "session-ticket", ""); + + state->new_session_ticket = + new New_Session_Ticket(writer, state->hash, + session_info.encrypt(ticket_key, rng)); + } + catch(...) {} + } + else + session_manager.save(session_info); + } + + if(state->server_hello->supports_session_ticket() && !state->new_session_ticket) + state->new_session_ticket = new New_Session_Ticket(writer, state->hash); + writer.send(CHANGE_CIPHER_SPEC, 1); writer.activate(SERVER, state->suite, state->keys, @@ -437,25 +527,6 @@ void Server::process_handshake_msg(Handshake_Type type, peer_certs = state->client_certs->cert_chain(); } - Session session_info( - state->server_hello->session_id(), - state->keys.master_secret(), - state->server_hello->version(), - state->server_hello->ciphersuite(), - state->server_hello->compression_method(), - SERVER, - secure_renegotiation.supported(), - state->server_hello->fragment_size(), - peer_certs, - m_hostname, - "" - ); - - if(handshake_fn(session_info)) - session_manager.save(session_info); - else - session_manager.remove_entry(session_info.session_id()); - secure_renegotiation.update(state->client_finished, state->server_finished); diff --git a/src/tls/tls_session.cpp b/src/tls/tls_session.cpp index b27409dfa..44689b510 100644 --- a/src/tls/tls_session.cpp +++ b/src/tls/tls_session.cpp @@ -1,6 +1,6 @@ /* * TLS Session State -* (C) 2011 Jack Lloyd +* (C) 2011-2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -10,6 +10,9 @@ #include <botan/ber_dec.h> #include <botan/asn1_str.h> #include <botan/pem.h> +#include <botan/lookup.h> +#include <botan/loadstor.h> +#include <memory> namespace Botan { @@ -28,6 +31,7 @@ Session::Session(const MemoryRegion<byte>& session_identifier, const std::string& srp_identifier) : m_start_time(std::chrono::system_clock::now()), m_identifier(session_identifier), + m_session_ticket(ticket), m_master_secret(master_secret), m_version(version), m_ciphersuite(ciphersuite), @@ -41,10 +45,15 @@ Session::Session(const MemoryRegion<byte>& session_identifier, { } -Session::Session(const byte ber[], size_t ber_len) +Session::Session(const std::string& pem) { - BER_Decoder decoder(ber, ber_len); + SecureVector<byte> der = PEM_Code::decode_check_label(pem, "SSL SESSION"); + *this = Session(&der[0], der.size()); + } + +Session::Session(const byte ber[], size_t ber_len) + { byte side_code = 0; ASN1_String sni_hostname_str; ASN1_String srp_identifier_str; @@ -59,10 +68,11 @@ Session::Session(const byte ber[], size_t ber_len) .start_cons(SEQUENCE) .decode_and_check(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION), "Unknown version in session structure") - .decode(m_identifier, OCTET_STRING) - .decode_integer_type(start_time) + .decode_integer_type(m_start_time) .decode_integer_type(major_version) .decode_integer_type(minor_version) + .decode(m_identifier, OCTET_STRING) + .decode(m_session_ticket, OCTET_STRING) .decode_integer_type(m_ciphersuite) .decode_integer_type(m_compression_method) .decode_integer_type(side_code) @@ -90,13 +100,6 @@ Session::Session(const byte ber[], size_t ber_len) } } -Session::Session(const std::string& pem) - { - SecureVector<byte> der = PEM_Code::decode_check_label(pem, "SSL SESSION"); - - *this = Session(&der[0], der.size()); - } - SecureVector<byte> Session::DER_encode() const { MemoryVector<byte> peer_cert_bits; @@ -106,10 +109,11 @@ SecureVector<byte> Session::DER_encode() const return DER_Encoder() .start_cons(SEQUENCE) .encode(static_cast<size_t>(TLS_SESSION_PARAM_STRUCT_VERSION)) - .encode(m_identifier, OCTET_STRING) .encode(static_cast<size_t>(std::chrono::system_clock::to_time_t(m_start_time))) .encode(static_cast<size_t>(m_version.major_version())) .encode(static_cast<size_t>(m_version.minor_version())) + .encode(m_identifier, OCTET_STRING) + .encode(m_session_ticket, OCTET_STRING) .encode(static_cast<size_t>(m_ciphersuite)) .encode(static_cast<size_t>(m_compression_method)) .encode(static_cast<size_t>(m_connection_side)) @@ -128,6 +132,113 @@ std::string Session::PEM_encode() const return PEM_Code::encode(this->DER_encode(), "SSL SESSION"); } +namespace { + +const u32bit SESSION_CRYPTO_MAGIC = 0x571B0E4E; +const std::string SESSION_CRYPTO_CIPHER = "AES-256/CBC"; +const std::string SESSION_CRYPTO_MAC = "HMAC(SHA-256)"; +const std::string SESSION_CRYPTO_KDF = "KDF2(SHA-256)"; + +const size_t MAGIC_LENGTH = 4; +const size_t MAC_KEY_LENGTH = 32; +const size_t CIPHER_KEY_LENGTH = 32; +const size_t CIPHER_IV_LENGTH = 16; +const size_t MAC_OUTPUT_LENGTH = 32; + } +MemoryVector<byte> +Session::encrypt(const SymmetricKey& master_key, + RandomNumberGenerator& rng) const + { + std::auto_ptr<KDF> kdf(get_kdf(SESSION_CRYPTO_KDF)); + + SymmetricKey cipher_key = + kdf->derive_key(CIPHER_KEY_LENGTH, + master_key.bits_of(), + "tls.session.cipher-key"); + + SymmetricKey mac_key = + kdf->derive_key(MAC_KEY_LENGTH, + master_key.bits_of(), + "tls.session.mac-key"); + + InitializationVector cipher_iv(rng, 16); + + std::auto_ptr<MessageAuthenticationCode> mac(get_mac(SESSION_CRYPTO_MAC)); + mac->set_key(mac_key); + + Pipe pipe(get_cipher(SESSION_CRYPTO_CIPHER, cipher_key, cipher_iv, ENCRYPTION)); + pipe.process_msg(this->DER_encode()); + MemoryVector<byte> ctext = pipe.read_all(0); + + MemoryVector<byte> out(MAGIC_LENGTH); + store_be(SESSION_CRYPTO_MAGIC, &out[0]); + out += cipher_iv.bits_of(); + out += ctext; + + mac->update(out); + + out += mac->final(); + return out; + } + +Session Session::decrypt(const byte buf[], size_t buf_len, + const SymmetricKey& master_key) + { + try + { + const size_t MIN_CTEXT_SIZE = 4 * 16; // due to 48 byte master secret + + if(buf_len < (MAGIC_LENGTH + + CIPHER_IV_LENGTH + + MIN_CTEXT_SIZE + + MAC_OUTPUT_LENGTH)) + throw Decoding_Error("Encrypted TLS session too short to be valid"); + + if(load_be<u32bit>(buf, 0) != SESSION_CRYPTO_MAGIC) + throw Decoding_Error("Unknown header value in encrypted session"); + + std::auto_ptr<KDF> kdf(get_kdf(SESSION_CRYPTO_KDF)); + + SymmetricKey mac_key = + kdf->derive_key(MAC_KEY_LENGTH, + master_key.bits_of(), + "tls.session.mac-key"); + + std::auto_ptr<MessageAuthenticationCode> mac(get_mac(SESSION_CRYPTO_MAC)); + mac->set_key(mac_key); + + mac->update(&buf[0], buf_len - MAC_OUTPUT_LENGTH); + MemoryVector<byte> computed_mac = mac->final(); + + if(!same_mem(&buf[buf_len - MAC_OUTPUT_LENGTH], &computed_mac[0], computed_mac.size())) + throw Decoding_Error("MAC verification failed for encrypted session"); + + SymmetricKey cipher_key = + kdf->derive_key(CIPHER_KEY_LENGTH, + master_key.bits_of(), + "tls.session.cipher-key"); + + InitializationVector cipher_iv(&buf[MAGIC_LENGTH], CIPHER_IV_LENGTH); + + const size_t CTEXT_OFFSET = MAGIC_LENGTH + CIPHER_IV_LENGTH; + + Pipe pipe(get_cipher(SESSION_CRYPTO_CIPHER, cipher_key, cipher_iv, DECRYPTION)); + pipe.process_msg(&buf[CTEXT_OFFSET], + buf_len - (MAC_OUTPUT_LENGTH + CTEXT_OFFSET)); + SecureVector<byte> ber = pipe.read_all(); + + return Session(&ber[0], ber.size()); + } + catch(std::exception& e) + { + throw Decoding_Error("Failed to decrypt encrypted session -" + + std::string(e.what())); + } + } + } + +} + diff --git a/src/tls/tls_session.h b/src/tls/tls_session.h index 82c202ebe..8fc048c75 100644 --- a/src/tls/tls_session.h +++ b/src/tls/tls_session.h @@ -1,6 +1,6 @@ /* * TLS Session -* (C) 2011 Jack Lloyd +* (C) 2011-2012 Jack Lloyd * * Released under the terms of the Botan license */ @@ -13,6 +13,7 @@ #include <botan/tls_ciphersuite.h> #include <botan/tls_magic.h> #include <botan/secmem.h> +#include <botan/symkey.h> #include <chrono> namespace Botan { @@ -51,6 +52,7 @@ class BOTAN_DLL Session bool secure_renegotiation_supported, size_t fragment_size, const std::vector<X509_Certificate>& peer_certs, + const MemoryRegion<byte>& session_ticket, const std::string& sni_hostname = "", const std::string& srp_identifier = ""); @@ -72,6 +74,34 @@ class BOTAN_DLL Session SecureVector<byte> DER_encode() const; /** + * Encrypt a session (useful for serialization or session tickets) + */ + MemoryVector<byte> encrypt(const SymmetricKey& key, + RandomNumberGenerator& rng) const; + + + /** + * Decrypt a session created by encrypt + * @param ctext the ciphertext returned by encrypt + * @param ctext_size the size of ctext in bytes + * @param key the same key used by the encrypting side + */ + static Session decrypt(const byte ctext[], + size_t ctext_size, + const SymmetricKey& key); + + /** + * Decrypt a session created by encrypt + * @param ctext the ciphertext returned by encrypt + * @param key the same key used by the encrypting side + */ + static inline Session decrypt(const MemoryRegion<byte>& ctext, + const SymmetricKey& key) + { + return Session::decrypt(&ctext[0], ctext.size(), key); + } + + /** * Encode this session data for storage * @warning if the master secret is compromised so is the * session traffic @@ -148,12 +178,18 @@ class BOTAN_DLL Session std::chrono::system_clock::time_point start_time() const { return m_start_time; } + /** + * Return the session ticket the server gave us + */ + const MemoryVector<byte>& session_ticket() const { return m_session_ticket; } + private: - enum { TLS_SESSION_PARAM_STRUCT_VERSION = 1 }; + enum { TLS_SESSION_PARAM_STRUCT_VERSION = 0x2994e300 }; std::chrono::system_clock::time_point m_start_time; MemoryVector<byte> m_identifier; + MemoryVector<byte> m_session_ticket; // only used by client side SecureVector<byte> m_master_secret; Protocol_Version m_version; diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp index 0f520d140..4d7603ce1 100644 --- a/src/tls/tls_session_key.cpp +++ b/src/tls/tls_session_key.cpp @@ -47,7 +47,7 @@ Session_Keys::Session_Keys(Handshake_State* state, { SecureVector<byte> salt; - if(state->version != Protocol_Version::SSL_V3) + if(state->version() != Protocol_Version::SSL_V3) salt += std::make_pair(MASTER_SECRET_MAGIC, sizeof(MASTER_SECRET_MAGIC)); salt += state->client_hello->random(); @@ -57,7 +57,7 @@ Session_Keys::Session_Keys(Handshake_State* state, } SecureVector<byte> salt; - if(state->version != Protocol_Version::SSL_V3) + if(state->version() != Protocol_Version::SSL_V3) salt += std::make_pair(KEY_GEN_MAGIC, sizeof(KEY_GEN_MAGIC)); salt += state->server_hello->random(); salt += state->client_hello->random(); |