diff options
-rw-r--r-- | doc/examples/tls_server.cpp | 3 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 4 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 10 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 4 | ||||
-rw-r--r-- | src/tls/finished.cpp | 10 | ||||
-rw-r--r-- | src/tls/hello.cpp | 22 | ||||
-rw-r--r-- | src/tls/info.txt | 1 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 12 | ||||
-rw-r--r-- | src/tls/tls_channel.cpp | 8 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 62 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 92 | ||||
-rw-r--r-- | src/tls/tls_server.h | 3 | ||||
-rw-r--r-- | src/tls/tls_session_key.cpp | 59 | ||||
-rw-r--r-- | src/tls/tls_session_key.h | 23 | ||||
-rw-r--r-- | src/tls/tls_session_state.h | 128 |
16 files changed, 280 insertions, 163 deletions
diff --git a/doc/examples/tls_server.cpp b/doc/examples/tls_server.cpp index 62bc8fadc..eff3a3c3c 100644 --- a/doc/examples/tls_server.cpp +++ b/doc/examples/tls_server.cpp @@ -64,6 +64,8 @@ int main(int argc, char* argv[]) Server_TLS_Policy policy; + TLS_Session_Manager_In_Memory sessions; + while(true) { try { @@ -76,6 +78,7 @@ int main(int argc, char* argv[]) TLS_Server tls( std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2), proc_data, + sessions, policy, rng, cert, diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index 0f20b819c..b55973ca3 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -75,11 +75,11 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, /** * Serialize a Client Key Exchange message */ -SecureVector<byte> Client_Key_Exchange::serialize() const +MemoryVector<byte> Client_Key_Exchange::serialize() const { if(include_length) { - SecureVector<byte> buf; + MemoryVector<byte> buf; append_tls_length_value(buf, key_material, 2); return buf; } diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index b8b2624bf..74398a59f 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -34,9 +34,9 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, /** * Serialize a Certificate Request message */ -SecureVector<byte> Certificate_Req::serialize() const +MemoryVector<byte> Certificate_Req::serialize() const { - SecureVector<byte> buf; + MemoryVector<byte> buf; append_tls_length_value(buf, types, 1); @@ -94,13 +94,13 @@ Certificate::Certificate(Record_Writer& writer, /** * Serialize a Certificate message */ -SecureVector<byte> Certificate::serialize() const +MemoryVector<byte> Certificate::serialize() const { - SecureVector<byte> buf(3); + MemoryVector<byte> buf(3); for(size_t i = 0; i != certs.size(); ++i) { - SecureVector<byte> raw_cert = certs[i].BER_encode(); + MemoryVector<byte> raw_cert = certs[i].BER_encode(); const size_t cert_size = raw_cert.size(); for(size_t i = 0; i != 3; ++i) buf.push_back(get_byte<u32bit>(i+1, cert_size)); diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 3220a8c9e..0d8256e5e 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -46,9 +46,9 @@ Certificate_Verify::Certificate_Verify(RandomNumberGenerator& rng, /** * Serialize a Certificate Verify message */ -SecureVector<byte> Certificate_Verify::serialize() const +MemoryVector<byte> Certificate_Verify::serialize() const { - SecureVector<byte> buf; + MemoryVector<byte> buf; const u16bit sig_len = signature.size(); buf.push_back(get_byte(0, sig_len)); diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index d76fbd884..dff977d31 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -25,7 +25,7 @@ Finished::Finished(Record_Writer& writer, /** * Serialize a Finished message */ -SecureVector<byte> Finished::serialize() const +MemoryVector<byte> Finished::serialize() const { return verification_data; } @@ -44,7 +44,7 @@ void Finished::deserialize(const MemoryRegion<byte>& buf) bool Finished::verify(const MemoryRegion<byte>& secret, Version_Code version, const HandshakeHash& hash, Connection_Side side) { - SecureVector<byte> computed = compute_verify(secret, hash, side, version); + MemoryVector<byte> computed = compute_verify(secret, hash, side, version); if(computed == verification_data) return true; return false; @@ -53,7 +53,7 @@ bool Finished::verify(const MemoryRegion<byte>& secret, Version_Code version, /** * Compute the verify_data */ -SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, +MemoryVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, HandshakeHash hash, Connection_Side side, Version_Code version) @@ -63,7 +63,7 @@ SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 }; const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 }; - SecureVector<byte> ssl3_finished; + MemoryVector<byte> ssl3_finished; if(side == CLIENT) hash.update(SSL_CLIENT_LABEL, sizeof(SSL_CLIENT_LABEL)); @@ -84,7 +84,7 @@ SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret, TLS_PRF prf; - SecureVector<byte> input; + MemoryVector<byte> input; if(side == CLIENT) input += std::make_pair(TLS_CLIENT_LABEL, sizeof(TLS_CLIENT_LABEL)); else diff --git a/src/tls/hello.cpp b/src/tls/hello.cpp index ae0d9607b..a3a15f26f 100644 --- a/src/tls/hello.cpp +++ b/src/tls/hello.cpp @@ -15,8 +15,8 @@ namespace Botan { */ void HandshakeMessage::send(Record_Writer& writer, HandshakeHash& hash) const { - SecureVector<byte> buf = serialize(); - SecureVector<byte> send_buf(4); + MemoryVector<byte> buf = serialize(); + MemoryVector<byte> send_buf(4); const size_t buf_size = buf.size(); @@ -45,9 +45,9 @@ Hello_Request::Hello_Request(Record_Writer& writer) /* * Serialize a Hello Request message */ -SecureVector<byte> Hello_Request::serialize() const +MemoryVector<byte> Hello_Request::serialize() const { - return SecureVector<byte>(); + return MemoryVector<byte>(); } /* @@ -79,9 +79,9 @@ Client_Hello::Client_Hello(RandomNumberGenerator& rng, /* * Serialize a Client Hello message */ -SecureVector<byte> Client_Hello::serialize() const +MemoryVector<byte> Client_Hello::serialize() const { - SecureVector<byte> buf; + MemoryVector<byte> buf; buf.push_back(static_cast<byte>(c_version >> 8)); buf.push_back(static_cast<byte>(c_version )); @@ -225,6 +225,7 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng, const TLS_Policy& policy, const std::vector<X509_Certificate>& certs, const Client_Hello& c_hello, + const MemoryRegion<byte>& session_id, Version_Code ver, HandshakeHash& hash) { @@ -250,6 +251,7 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng, s_version = ver; s_random = rng.random_vec(32); + sess_id = session_id; send(writer, hash); } @@ -257,9 +259,9 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng, /* * Serialize a Server Hello message */ -SecureVector<byte> Server_Hello::serialize() const +MemoryVector<byte> Server_Hello::serialize() const { - SecureVector<byte> buf; + MemoryVector<byte> buf; buf.push_back(static_cast<byte>(s_version >> 8)); buf.push_back(static_cast<byte>(s_version )); @@ -314,9 +316,9 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, /* * Serialize a Server Hello Done message */ -SecureVector<byte> Server_Hello_Done::serialize() const +MemoryVector<byte> Server_Hello_Done::serialize() const { - return SecureVector<byte>(); + return MemoryVector<byte>(); } /* diff --git a/src/tls/info.txt b/src/tls/info.txt index f09309bd2..a088ed4fb 100644 --- a/src/tls/info.txt +++ b/src/tls/info.txt @@ -16,6 +16,7 @@ tls_policy.h tls_record.h tls_server.h tls_session_key.h +tls_session_state.h tls_suites.h </header:public> diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 1e7de31d0..b11892923 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -72,9 +72,9 @@ Server_Key_Exchange::Server_Key_Exchange(RandomNumberGenerator& rng, /** * Serialize a Server Key Exchange message */ -SecureVector<byte> Server_Key_Exchange::serialize() const +MemoryVector<byte> Server_Key_Exchange::serialize() const { - SecureVector<byte> buf = serialize_params(); + MemoryVector<byte> buf = serialize_params(); append_tls_length_value(buf, signature, 2); return buf; } @@ -82,9 +82,9 @@ SecureVector<byte> Server_Key_Exchange::serialize() const /** * Serialize the ServerParams structure */ -SecureVector<byte> Server_Key_Exchange::serialize_params() const +MemoryVector<byte> Server_Key_Exchange::serialize_params() const { - SecureVector<byte> buf; + MemoryVector<byte> buf; for(size_t i = 0; i != params.size(); ++i) append_tls_length_value(buf, BigInt::encode(params[i]), 2); @@ -100,7 +100,7 @@ void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); - SecureVector<byte> values[4]; + MemoryVector<byte> values[4]; size_t so_far = 0; for(size_t i = 0; i != 4; ++i) @@ -169,7 +169,7 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert, PK_Verifier verifier(*key, padding, format); - SecureVector<byte> params_got = serialize_params(); + MemoryVector<byte> params_got = serialize_params(); verifier.update(c_random); verifier.update(s_random); verifier.update(params_got); diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 580c1e5e5..1121de1a1 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -68,11 +68,9 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size) if(alert_msg.is_fatal() || alert_msg.type() == CLOSE_NOTIFY) { if(alert_msg.type() == CLOSE_NOTIFY) - { - writer.alert(WARNING, CLOSE_NOTIFY); - } - - alert(FATAL, NO_ALERT_TYPE); + alert(FATAL, CLOSE_NOTIFY); + else + alert(FATAL, NO_ALERT_TYPE); } } else diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 30c440d29..ee9c397c1 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -312,7 +312,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(!state->server_finished->verify(state->keys.master_secret(), state->version, state->hash, SERVER)) throw TLS_Exception(DECRYPT_ERROR, - "Finished message didn't verify"); + "Finished message didn't verify"); delete state; state = 0; diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index e7eaa56e1..a7aa36366 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -31,7 +31,7 @@ class HandshakeMessage virtual ~HandshakeMessage() {} private: HandshakeMessage& operator=(const HandshakeMessage&) { return (*this); } - virtual SecureVector<byte> serialize() const = 0; + virtual MemoryVector<byte> serialize() const = 0; virtual void deserialize(const MemoryRegion<byte>&) = 0; }; @@ -43,11 +43,19 @@ class Client_Hello : public HandshakeMessage public: Handshake_Type type() const { return CLIENT_HELLO; } Version_Code version() const { return c_version; } - const SecureVector<byte>& session_id() const { return sess_id; } + const MemoryVector<byte>& session_id() const { return sess_id; } + + std::vector<byte> session_id_vector() const + { + std::vector<byte> v; + v.insert(v.begin(), &sess_id[0], &sess_id[sess_id.size()]); + return v; + } + std::vector<u16bit> ciphersuites() const { return suites; } std::vector<byte> compression_algos() const { return comp_algos; } - const SecureVector<byte>& random() const { return c_random; } + const MemoryVector<byte>& random() const { return c_random; } std::string hostname() const { return requested_hostname; } @@ -68,12 +76,12 @@ class Client_Hello : public HandshakeMessage } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); void deserialize_sslv2(const MemoryRegion<byte>&); Version_Code c_version; - SecureVector<byte> sess_id, c_random; + MemoryVector<byte> sess_id, c_random; std::vector<u16bit> suites; std::vector<byte> comp_algos; std::string requested_hostname; @@ -105,7 +113,7 @@ class Client_Key_Exchange : public HandshakeMessage const CipherSuite& suite, Version_Code using_version); private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); SecureVector<byte> key_material, pre_master; @@ -125,7 +133,7 @@ class Certificate : public HandshakeMessage HandshakeHash&); Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); std::vector<X509_Certificate> certs; }; @@ -150,7 +158,7 @@ class Certificate_Req : public HandshakeMessage Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); std::vector<X509_DN> names; @@ -173,10 +181,10 @@ class Certificate_Verify : public HandshakeMessage Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); - SecureVector<byte> signature; + MemoryVector<byte> signature; }; /** @@ -194,15 +202,15 @@ class Finished : public HandshakeMessage const MemoryRegion<byte>&, HandshakeHash&); Finished(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); - SecureVector<byte> compute_verify(const MemoryRegion<byte>&, + MemoryVector<byte> compute_verify(const MemoryRegion<byte>&, HandshakeHash, Connection_Side, Version_Code); Connection_Side side; - SecureVector<byte> verification_data; + MemoryVector<byte> verification_data; }; /** @@ -216,7 +224,7 @@ class Hello_Request : public HandshakeMessage Hello_Request(Record_Writer&); Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); }; @@ -228,24 +236,28 @@ class Server_Hello : public HandshakeMessage public: Handshake_Type type() const { return SERVER_HELLO; } Version_Code version() { return s_version; } - const SecureVector<byte>& session_id() const { return sess_id; } + const MemoryVector<byte>& session_id() const { return sess_id; } u16bit ciphersuite() const { return suite; } byte compression_algo() const { return comp_algo; } - const SecureVector<byte>& random() const { return s_random; } + const MemoryVector<byte>& random() const { return s_random; } Server_Hello(RandomNumberGenerator& rng, - Record_Writer&, const TLS_Policy&, - const std::vector<X509_Certificate>&, - const Client_Hello&, Version_Code, HandshakeHash&); + Record_Writer& writer, + const TLS_Policy& policies, + const std::vector<X509_Certificate>& certs, + const Client_Hello& other, + const MemoryRegion<byte>& session_id, + Version_Code version, + HandshakeHash& hash); Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); Version_Code s_version; - SecureVector<byte> sess_id, s_random; + MemoryVector<byte> sess_id, s_random; u16bit suite; byte comp_algo; }; @@ -269,12 +281,12 @@ class Server_Key_Exchange : public HandshakeMessage Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; - SecureVector<byte> serialize_params() const; + MemoryVector<byte> serialize() const; + MemoryVector<byte> serialize_params() const; void deserialize(const MemoryRegion<byte>&); std::vector<BigInt> params; - SecureVector<byte> signature; + MemoryVector<byte> signature; }; /** @@ -288,7 +300,7 @@ class Server_Hello_Done : public HandshakeMessage Server_Hello_Done(Record_Writer&, HandshakeHash&); Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); } private: - SecureVector<byte> serialize() const; + MemoryVector<byte> serialize() const; void deserialize(const MemoryRegion<byte>&); }; diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 81ed2c48e..e2f994224 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -6,12 +6,12 @@ */ #include <botan/tls_server.h> -#include <botan/internal/tls_alerts.h> #include <botan/internal/tls_state.h> -#include <botan/loadstor.h> #include <botan/rsa.h> #include <botan/dh.h> +#include <stdio.h> + namespace Botan { namespace { @@ -87,13 +87,15 @@ void server_check_state(Handshake_Type new_msg, Handshake_State* state) */ TLS_Server::TLS_Server(std::tr1::function<void (const byte[], size_t)> output_fn, std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, + TLS_Session_Manager& session_manager, const TLS_Policy& policy, RandomNumberGenerator& rng, const X509_Certificate& cert, const Private_Key& cert_key) : TLS_Channel(output_fn, proc_fn), policy(policy), - rng(rng) + rng(rng), + session_manager(session_manager) { writer.set_version(TLS_V10); @@ -160,48 +162,66 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, writer.set_version(state->version); reader.set_version(state->version); - state->server_hello = new Server_Hello(rng, writer, - policy, cert_chain, - *(state->client_hello), - state->version, state->hash); - - state->suite = CipherSuite(state->server_hello->ciphersuite()); + TLS_Session_Params params; + const bool found = session_manager.find( + state->client_hello->session_id_vector(), + params); - if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) + if(found && params.connection_side == SERVER) { - // FIXME: should choose certs based on sig type - state->server_certs = new Certificate(writer, cert_chain, - state->hash); - } - state->kex_priv = PKCS8::copy_key(*private_key, rng); - if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + + + + } + else // new session { - if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_RSA) + MemoryVector<byte> sess_id = rng.random_vec(32); + + state->server_hello = new Server_Hello(rng, writer, + policy, cert_chain, + *(state->client_hello), + sess_id, + state->version, state->hash); + + state->suite = CipherSuite(state->server_hello->ciphersuite()); + + if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) { - state->kex_priv = new RSA_PrivateKey(rng, - policy.rsa_export_keysize()); + // FIXME: should choose certs based on sig type + state->server_certs = new Certificate(writer, cert_chain, + state->hash); } - else if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_DH) + + state->kex_priv = PKCS8::copy_key(*private_key, rng); + if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) { - state->kex_priv = new DH_PrivateKey(rng, policy.dh_group()); + if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_RSA) + { + state->kex_priv = new RSA_PrivateKey(rng, + policy.rsa_export_keysize()); + } + else if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_DH) + { + state->kex_priv = new DH_PrivateKey(rng, policy.dh_group()); + } + else + throw Internal_Error("TLS_Server: Unknown ciphersuite kex type"); + + state->server_kex = + new Server_Key_Exchange(rng, writer, + state->kex_priv, private_key, + state->client_hello->random(), + state->server_hello->random(), + state->hash); } - else - throw Internal_Error("TLS_Server: Unknown ciphersuite kex type"); - - state->server_kex = - new Server_Key_Exchange(rng, writer, - state->kex_priv, private_key, - state->client_hello->random(), - state->server_hello->random(), - state->hash); - } - if(policy.require_client_auth()) - { - state->do_client_auth = true; - throw Internal_Error("Client auth not implemented"); - // FIXME: send client auth request here + if(policy.require_client_auth()) + { + state->do_client_auth = true; + throw Internal_Error("Client auth not implemented"); + // FIXME: send client auth request here + } } state->server_hello_done = new Server_Hello_Done(writer, state->hash); diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index e975071d2..a1f99a0ff 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -9,6 +9,7 @@ #define BOTAN_TLS_SERVER_H__ #include <botan/tls_channel.h> +#include <botan/tls_session_state.h> #include <vector> namespace Botan { @@ -28,6 +29,7 @@ class BOTAN_DLL TLS_Server : public TLS_Channel */ TLS_Server(std::tr1::function<void (const byte[], size_t)> socket_output_fn, std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn, + TLS_Session_Manager& session_manager, const TLS_Policy& policy, RandomNumberGenerator& rng, const X509_Certificate& cert, @@ -47,6 +49,7 @@ class BOTAN_DLL TLS_Server : public TLS_Channel const TLS_Policy& policy; RandomNumberGenerator& rng; + TLS_Session_Manager& session_manager; std::vector<X509_Certificate> cert_chain; Private_Key* private_key; diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp index 7c75d1758..865cc0b80 100644 --- a/src/tls/tls_session_key.cpp +++ b/src/tls/tls_session_key.cpp @@ -13,62 +13,6 @@ namespace Botan { /** -* Return the client cipher key -*/ -SymmetricKey SessionKeys::client_cipher_key() const - { - return c_cipher; - } - -/** -* Return the server cipher key -*/ -SymmetricKey SessionKeys::server_cipher_key() const - { - return s_cipher; - } - -/** -* Return the client MAC key -*/ -SymmetricKey SessionKeys::client_mac_key() const - { - return c_mac; - } - -/** -* Return the server MAC key -*/ -SymmetricKey SessionKeys::server_mac_key() const - { - return s_mac; - } - -/** -* Return the client cipher IV -*/ -InitializationVector SessionKeys::client_iv() const - { - return c_iv; - } - -/** -* Return the server cipher IV -*/ -InitializationVector SessionKeys::server_iv() const - { - return s_iv; - } - -/** -* Return the TLS master secret -*/ -SecureVector<byte> SessionKeys::master_secret() const - { - return master_sec; - } - -/** * Generate SSLv3 session keys */ SymmetricKey SessionKeys::ssl3_keygen(size_t prf_gen, @@ -126,7 +70,8 @@ SymmetricKey SessionKeys::tls1_keygen(size_t prf_gen, /** * SessionKeys Constructor */ -SessionKeys::SessionKeys(const CipherSuite& suite, Version_Code version, +SessionKeys::SessionKeys(const CipherSuite& suite, + Version_Code version, const MemoryRegion<byte>& pre_master_secret, const MemoryRegion<byte>& c_random, const MemoryRegion<byte>& s_random) diff --git a/src/tls/tls_session_key.h b/src/tls/tls_session_key.h index 51397984b..f0e185bd8 100644 --- a/src/tls/tls_session_key.h +++ b/src/tls/tls_session_key.h @@ -20,20 +20,25 @@ namespace Botan { class BOTAN_DLL SessionKeys { public: - SymmetricKey client_cipher_key() const; - SymmetricKey server_cipher_key() const; + SymmetricKey client_cipher_key() const { return c_cipher; } + SymmetricKey server_cipher_key() const { return s_cipher; } - SymmetricKey client_mac_key() const; - SymmetricKey server_mac_key() const; + SymmetricKey client_mac_key() const { return c_mac; } + SymmetricKey server_mac_key() const { return s_mac; } - InitializationVector client_iv() const; - InitializationVector server_iv() const; + InitializationVector client_iv() const { return c_iv; } + InitializationVector server_iv() const { return s_iv; } - SecureVector<byte> master_secret() const; + SecureVector<byte> master_secret() const { return master_sec; } SessionKeys() {} - SessionKeys(const CipherSuite&, Version_Code, const MemoryRegion<byte>&, - const MemoryRegion<byte>&, const MemoryRegion<byte>&); + + SessionKeys(const CipherSuite& suite, + Version_Code version, + const MemoryRegion<byte>& pre_master, + const MemoryRegion<byte>& client_random, + const MemoryRegion<byte>& server_random); + private: SymmetricKey ssl3_keygen(size_t, const MemoryRegion<byte>&, const MemoryRegion<byte>&, diff --git a/src/tls/tls_session_state.h b/src/tls/tls_session_state.h new file mode 100644 index 000000000..e6f25b34d --- /dev/null +++ b/src/tls/tls_session_state.h @@ -0,0 +1,128 @@ +/* +* TLS Session Management +* (C) 2011 Jack Lloyd +* +* Released under the terms of the Botan license +*/ + +#ifndef TLS_SESSION_STATE_H_ +#define TLS_SESSION_STATE_H_ + +#include <botan/tls_magic.h> +#include <botan/secmem.h> +#include <vector> +#include <map> + +#include <iostream> + +namespace Botan { + +struct BOTAN_DLL TLS_Session_Params + { + SecureVector<byte> master_secret; + std::vector<byte> client_random; + std::vector<byte> server_random; + + bool resumable; + Connection_Side connection_side; + Ciphersuite_Code ciphersuite; + Compression_Algo compression_method; + }; + +/** +* TLS_Session_Manager is an interface to systems which can save +* session parameters for support session resumption. +*/ +class BOTAN_DLL TLS_Session_Manager + { + public: + /** + * Try to load a saved session + * @param session_id the session identifier we are trying to resume + * @param params will be set to the saved session data (if found), + or not modified if not found + * @return true if params was modified + */ + virtual bool find(const std::vector<byte>& session_id, + TLS_Session_Params& params) = 0; + + /** + * Prohibit resumption of this session. Effectively an erase. + */ + virtual void prohibit_resumption(const std::vector<byte>& session_id) = 0; + + /** + * Save a session on a best effort basis; the manager may not in + * fact be able to save the session for whatever reason, this is + * not an error. Caller cannot assume that calling save followed + * immediately by find will result in a successful lookup. + * + * @param session_id the session identifier + * @param params to save + */ + virtual void save(const std::vector<byte>& session_id, + const TLS_Session_Params& params) = 0; + + virtual ~TLS_Session_Manager() {} + }; + +/** +* A simple implementation of TLS_Session_Manager that just saves +* values in memory, with no persistance abilities +*/ +class BOTAN_DLL TLS_Session_Manager_In_Memory : public TLS_Session_Manager + { + public: + /** + * @param max_sessions a hint on the maximum number of sessions + * to save at any one time. + */ + TLS_Session_Manager_In_Memory(size_t max_sessions = 0) : + max_sessions(max_sessions) {} + + bool find(const std::vector<byte>& session_id, + TLS_Session_Params& params) + { + std::map<std::vector<byte>, TLS_Session_Params>::const_iterator i = + sessions.find(session_id); + + std::cout << "Know about " << sessions.size() << " sessions\n"; + + if(i != sessions.end()) + { + params = i->second; + return true; + } + + return false; + } + + void prohibit_resumption(const std::vector<byte>& session_id) + { + std::map<std::vector<byte>, TLS_Session_Params>::const_iterator i = + sessions.find(session_id); + + if(i != sessions.end()) + sessions.erase(i); + } + + void save(const std::vector<byte>& session_id, + const TLS_Session_Params& session_data) + { + if(max_sessions != 0) + { + while(sessions.size() >= max_sessions) + sessions.erase(sessions.begin()); + } + + sessions[session_id] = session_data; + } + + private: + size_t max_sessions; + std::map<std::vector<byte>, TLS_Session_Params> sessions; + }; + +} + +#endif |