aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2011-12-23 20:23:03 +0000
committerlloyd <[email protected]>2011-12-23 20:23:03 +0000
commitf5c863cf97ea11876acad3c46fffca23685698aa (patch)
tree5537f5c843602f136f6eb1835d8679c9ae67009e
parent61d461d0a5fb63c3aee906c76b4aefe3335a7591 (diff)
Initial hooks for session resumption
-rw-r--r--doc/examples/tls_server.cpp3
-rw-r--r--src/tls/c_kex.cpp4
-rw-r--r--src/tls/cert_req.cpp10
-rw-r--r--src/tls/cert_ver.cpp4
-rw-r--r--src/tls/finished.cpp10
-rw-r--r--src/tls/hello.cpp22
-rw-r--r--src/tls/info.txt1
-rw-r--r--src/tls/s_kex.cpp12
-rw-r--r--src/tls/tls_channel.cpp8
-rw-r--r--src/tls/tls_client.cpp2
-rw-r--r--src/tls/tls_messages.h62
-rw-r--r--src/tls/tls_server.cpp92
-rw-r--r--src/tls/tls_server.h3
-rw-r--r--src/tls/tls_session_key.cpp59
-rw-r--r--src/tls/tls_session_key.h23
-rw-r--r--src/tls/tls_session_state.h128
16 files changed, 280 insertions, 163 deletions
diff --git a/doc/examples/tls_server.cpp b/doc/examples/tls_server.cpp
index 62bc8fadc..eff3a3c3c 100644
--- a/doc/examples/tls_server.cpp
+++ b/doc/examples/tls_server.cpp
@@ -64,6 +64,8 @@ int main(int argc, char* argv[])
Server_TLS_Policy policy;
+ TLS_Session_Manager_In_Memory sessions;
+
while(true)
{
try {
@@ -76,6 +78,7 @@ int main(int argc, char* argv[])
TLS_Server tls(
std::tr1::bind(&Socket::write, std::tr1::ref(sock), _1, _2),
proc_data,
+ sessions,
policy,
rng,
cert,
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index 0f20b819c..b55973ca3 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -75,11 +75,11 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
/**
* Serialize a Client Key Exchange message
*/
-SecureVector<byte> Client_Key_Exchange::serialize() const
+MemoryVector<byte> Client_Key_Exchange::serialize() const
{
if(include_length)
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
append_tls_length_value(buf, key_material, 2);
return buf;
}
diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp
index b8b2624bf..74398a59f 100644
--- a/src/tls/cert_req.cpp
+++ b/src/tls/cert_req.cpp
@@ -34,9 +34,9 @@ Certificate_Req::Certificate_Req(Record_Writer& writer,
/**
* Serialize a Certificate Request message
*/
-SecureVector<byte> Certificate_Req::serialize() const
+MemoryVector<byte> Certificate_Req::serialize() const
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
append_tls_length_value(buf, types, 1);
@@ -94,13 +94,13 @@ Certificate::Certificate(Record_Writer& writer,
/**
* Serialize a Certificate message
*/
-SecureVector<byte> Certificate::serialize() const
+MemoryVector<byte> Certificate::serialize() const
{
- SecureVector<byte> buf(3);
+ MemoryVector<byte> buf(3);
for(size_t i = 0; i != certs.size(); ++i)
{
- SecureVector<byte> raw_cert = certs[i].BER_encode();
+ MemoryVector<byte> raw_cert = certs[i].BER_encode();
const size_t cert_size = raw_cert.size();
for(size_t i = 0; i != 3; ++i)
buf.push_back(get_byte<u32bit>(i+1, cert_size));
diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp
index 3220a8c9e..0d8256e5e 100644
--- a/src/tls/cert_ver.cpp
+++ b/src/tls/cert_ver.cpp
@@ -46,9 +46,9 @@ Certificate_Verify::Certificate_Verify(RandomNumberGenerator& rng,
/**
* Serialize a Certificate Verify message
*/
-SecureVector<byte> Certificate_Verify::serialize() const
+MemoryVector<byte> Certificate_Verify::serialize() const
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
const u16bit sig_len = signature.size();
buf.push_back(get_byte(0, sig_len));
diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp
index d76fbd884..dff977d31 100644
--- a/src/tls/finished.cpp
+++ b/src/tls/finished.cpp
@@ -25,7 +25,7 @@ Finished::Finished(Record_Writer& writer,
/**
* Serialize a Finished message
*/
-SecureVector<byte> Finished::serialize() const
+MemoryVector<byte> Finished::serialize() const
{
return verification_data;
}
@@ -44,7 +44,7 @@ void Finished::deserialize(const MemoryRegion<byte>& buf)
bool Finished::verify(const MemoryRegion<byte>& secret, Version_Code version,
const HandshakeHash& hash, Connection_Side side)
{
- SecureVector<byte> computed = compute_verify(secret, hash, side, version);
+ MemoryVector<byte> computed = compute_verify(secret, hash, side, version);
if(computed == verification_data)
return true;
return false;
@@ -53,7 +53,7 @@ bool Finished::verify(const MemoryRegion<byte>& secret, Version_Code version,
/**
* Compute the verify_data
*/
-SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret,
+MemoryVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret,
HandshakeHash hash,
Connection_Side side,
Version_Code version)
@@ -63,7 +63,7 @@ SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret,
const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 };
const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 };
- SecureVector<byte> ssl3_finished;
+ MemoryVector<byte> ssl3_finished;
if(side == CLIENT)
hash.update(SSL_CLIENT_LABEL, sizeof(SSL_CLIENT_LABEL));
@@ -84,7 +84,7 @@ SecureVector<byte> Finished::compute_verify(const MemoryRegion<byte>& secret,
TLS_PRF prf;
- SecureVector<byte> input;
+ MemoryVector<byte> input;
if(side == CLIENT)
input += std::make_pair(TLS_CLIENT_LABEL, sizeof(TLS_CLIENT_LABEL));
else
diff --git a/src/tls/hello.cpp b/src/tls/hello.cpp
index ae0d9607b..a3a15f26f 100644
--- a/src/tls/hello.cpp
+++ b/src/tls/hello.cpp
@@ -15,8 +15,8 @@ namespace Botan {
*/
void HandshakeMessage::send(Record_Writer& writer, HandshakeHash& hash) const
{
- SecureVector<byte> buf = serialize();
- SecureVector<byte> send_buf(4);
+ MemoryVector<byte> buf = serialize();
+ MemoryVector<byte> send_buf(4);
const size_t buf_size = buf.size();
@@ -45,9 +45,9 @@ Hello_Request::Hello_Request(Record_Writer& writer)
/*
* Serialize a Hello Request message
*/
-SecureVector<byte> Hello_Request::serialize() const
+MemoryVector<byte> Hello_Request::serialize() const
{
- return SecureVector<byte>();
+ return MemoryVector<byte>();
}
/*
@@ -79,9 +79,9 @@ Client_Hello::Client_Hello(RandomNumberGenerator& rng,
/*
* Serialize a Client Hello message
*/
-SecureVector<byte> Client_Hello::serialize() const
+MemoryVector<byte> Client_Hello::serialize() const
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
buf.push_back(static_cast<byte>(c_version >> 8));
buf.push_back(static_cast<byte>(c_version ));
@@ -225,6 +225,7 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng,
const TLS_Policy& policy,
const std::vector<X509_Certificate>& certs,
const Client_Hello& c_hello,
+ const MemoryRegion<byte>& session_id,
Version_Code ver,
HandshakeHash& hash)
{
@@ -250,6 +251,7 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng,
s_version = ver;
s_random = rng.random_vec(32);
+ sess_id = session_id;
send(writer, hash);
}
@@ -257,9 +259,9 @@ Server_Hello::Server_Hello(RandomNumberGenerator& rng,
/*
* Serialize a Server Hello message
*/
-SecureVector<byte> Server_Hello::serialize() const
+MemoryVector<byte> Server_Hello::serialize() const
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
buf.push_back(static_cast<byte>(s_version >> 8));
buf.push_back(static_cast<byte>(s_version ));
@@ -314,9 +316,9 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer,
/*
* Serialize a Server Hello Done message
*/
-SecureVector<byte> Server_Hello_Done::serialize() const
+MemoryVector<byte> Server_Hello_Done::serialize() const
{
- return SecureVector<byte>();
+ return MemoryVector<byte>();
}
/*
diff --git a/src/tls/info.txt b/src/tls/info.txt
index f09309bd2..a088ed4fb 100644
--- a/src/tls/info.txt
+++ b/src/tls/info.txt
@@ -16,6 +16,7 @@ tls_policy.h
tls_record.h
tls_server.h
tls_session_key.h
+tls_session_state.h
tls_suites.h
</header:public>
diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp
index 1e7de31d0..b11892923 100644
--- a/src/tls/s_kex.cpp
+++ b/src/tls/s_kex.cpp
@@ -72,9 +72,9 @@ Server_Key_Exchange::Server_Key_Exchange(RandomNumberGenerator& rng,
/**
* Serialize a Server Key Exchange message
*/
-SecureVector<byte> Server_Key_Exchange::serialize() const
+MemoryVector<byte> Server_Key_Exchange::serialize() const
{
- SecureVector<byte> buf = serialize_params();
+ MemoryVector<byte> buf = serialize_params();
append_tls_length_value(buf, signature, 2);
return buf;
}
@@ -82,9 +82,9 @@ SecureVector<byte> Server_Key_Exchange::serialize() const
/**
* Serialize the ServerParams structure
*/
-SecureVector<byte> Server_Key_Exchange::serialize_params() const
+MemoryVector<byte> Server_Key_Exchange::serialize_params() const
{
- SecureVector<byte> buf;
+ MemoryVector<byte> buf;
for(size_t i = 0; i != params.size(); ++i)
append_tls_length_value(buf, BigInt::encode(params[i]), 2);
@@ -100,7 +100,7 @@ void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf)
if(buf.size() < 6)
throw Decoding_Error("Server_Key_Exchange: Packet corrupted");
- SecureVector<byte> values[4];
+ MemoryVector<byte> values[4];
size_t so_far = 0;
for(size_t i = 0; i != 4; ++i)
@@ -169,7 +169,7 @@ bool Server_Key_Exchange::verify(const X509_Certificate& cert,
PK_Verifier verifier(*key, padding, format);
- SecureVector<byte> params_got = serialize_params();
+ MemoryVector<byte> params_got = serialize_params();
verifier.update(c_random);
verifier.update(s_random);
verifier.update(params_got);
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 580c1e5e5..1121de1a1 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -68,11 +68,9 @@ size_t TLS_Channel::received_data(const byte buf[], size_t buf_size)
if(alert_msg.is_fatal() || alert_msg.type() == CLOSE_NOTIFY)
{
if(alert_msg.type() == CLOSE_NOTIFY)
- {
- writer.alert(WARNING, CLOSE_NOTIFY);
- }
-
- alert(FATAL, NO_ALERT_TYPE);
+ alert(FATAL, CLOSE_NOTIFY);
+ else
+ alert(FATAL, NO_ALERT_TYPE);
}
}
else
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index 30c440d29..ee9c397c1 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -312,7 +312,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
if(!state->server_finished->verify(state->keys.master_secret(),
state->version, state->hash, SERVER))
throw TLS_Exception(DECRYPT_ERROR,
- "Finished message didn't verify");
+ "Finished message didn't verify");
delete state;
state = 0;
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index e7eaa56e1..a7aa36366 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -31,7 +31,7 @@ class HandshakeMessage
virtual ~HandshakeMessage() {}
private:
HandshakeMessage& operator=(const HandshakeMessage&) { return (*this); }
- virtual SecureVector<byte> serialize() const = 0;
+ virtual MemoryVector<byte> serialize() const = 0;
virtual void deserialize(const MemoryRegion<byte>&) = 0;
};
@@ -43,11 +43,19 @@ class Client_Hello : public HandshakeMessage
public:
Handshake_Type type() const { return CLIENT_HELLO; }
Version_Code version() const { return c_version; }
- const SecureVector<byte>& session_id() const { return sess_id; }
+ const MemoryVector<byte>& session_id() const { return sess_id; }
+
+ std::vector<byte> session_id_vector() const
+ {
+ std::vector<byte> v;
+ v.insert(v.begin(), &sess_id[0], &sess_id[sess_id.size()]);
+ return v;
+ }
+
std::vector<u16bit> ciphersuites() const { return suites; }
std::vector<byte> compression_algos() const { return comp_algos; }
- const SecureVector<byte>& random() const { return c_random; }
+ const MemoryVector<byte>& random() const { return c_random; }
std::string hostname() const { return requested_hostname; }
@@ -68,12 +76,12 @@ class Client_Hello : public HandshakeMessage
}
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
void deserialize_sslv2(const MemoryRegion<byte>&);
Version_Code c_version;
- SecureVector<byte> sess_id, c_random;
+ MemoryVector<byte> sess_id, c_random;
std::vector<u16bit> suites;
std::vector<byte> comp_algos;
std::string requested_hostname;
@@ -105,7 +113,7 @@ class Client_Key_Exchange : public HandshakeMessage
const CipherSuite& suite,
Version_Code using_version);
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
SecureVector<byte> key_material, pre_master;
@@ -125,7 +133,7 @@ class Certificate : public HandshakeMessage
HandshakeHash&);
Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
std::vector<X509_Certificate> certs;
};
@@ -150,7 +158,7 @@ class Certificate_Req : public HandshakeMessage
Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
std::vector<X509_DN> names;
@@ -173,10 +181,10 @@ class Certificate_Verify : public HandshakeMessage
Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
- SecureVector<byte> signature;
+ MemoryVector<byte> signature;
};
/**
@@ -194,15 +202,15 @@ class Finished : public HandshakeMessage
const MemoryRegion<byte>&, HandshakeHash&);
Finished(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
- SecureVector<byte> compute_verify(const MemoryRegion<byte>&,
+ MemoryVector<byte> compute_verify(const MemoryRegion<byte>&,
HandshakeHash, Connection_Side,
Version_Code);
Connection_Side side;
- SecureVector<byte> verification_data;
+ MemoryVector<byte> verification_data;
};
/**
@@ -216,7 +224,7 @@ class Hello_Request : public HandshakeMessage
Hello_Request(Record_Writer&);
Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
};
@@ -228,24 +236,28 @@ class Server_Hello : public HandshakeMessage
public:
Handshake_Type type() const { return SERVER_HELLO; }
Version_Code version() { return s_version; }
- const SecureVector<byte>& session_id() const { return sess_id; }
+ const MemoryVector<byte>& session_id() const { return sess_id; }
u16bit ciphersuite() const { return suite; }
byte compression_algo() const { return comp_algo; }
- const SecureVector<byte>& random() const { return s_random; }
+ const MemoryVector<byte>& random() const { return s_random; }
Server_Hello(RandomNumberGenerator& rng,
- Record_Writer&, const TLS_Policy&,
- const std::vector<X509_Certificate>&,
- const Client_Hello&, Version_Code, HandshakeHash&);
+ Record_Writer& writer,
+ const TLS_Policy& policies,
+ const std::vector<X509_Certificate>& certs,
+ const Client_Hello& other,
+ const MemoryRegion<byte>& session_id,
+ Version_Code version,
+ HandshakeHash& hash);
Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
Version_Code s_version;
- SecureVector<byte> sess_id, s_random;
+ MemoryVector<byte> sess_id, s_random;
u16bit suite;
byte comp_algo;
};
@@ -269,12 +281,12 @@ class Server_Key_Exchange : public HandshakeMessage
Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
- SecureVector<byte> serialize_params() const;
+ MemoryVector<byte> serialize() const;
+ MemoryVector<byte> serialize_params() const;
void deserialize(const MemoryRegion<byte>&);
std::vector<BigInt> params;
- SecureVector<byte> signature;
+ MemoryVector<byte> signature;
};
/**
@@ -288,7 +300,7 @@ class Server_Hello_Done : public HandshakeMessage
Server_Hello_Done(Record_Writer&, HandshakeHash&);
Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); }
private:
- SecureVector<byte> serialize() const;
+ MemoryVector<byte> serialize() const;
void deserialize(const MemoryRegion<byte>&);
};
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 81ed2c48e..e2f994224 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -6,12 +6,12 @@
*/
#include <botan/tls_server.h>
-#include <botan/internal/tls_alerts.h>
#include <botan/internal/tls_state.h>
-#include <botan/loadstor.h>
#include <botan/rsa.h>
#include <botan/dh.h>
+#include <stdio.h>
+
namespace Botan {
namespace {
@@ -87,13 +87,15 @@ void server_check_state(Handshake_Type new_msg, Handshake_State* state)
*/
TLS_Server::TLS_Server(std::tr1::function<void (const byte[], size_t)> output_fn,
std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
+ TLS_Session_Manager& session_manager,
const TLS_Policy& policy,
RandomNumberGenerator& rng,
const X509_Certificate& cert,
const Private_Key& cert_key) :
TLS_Channel(output_fn, proc_fn),
policy(policy),
- rng(rng)
+ rng(rng),
+ session_manager(session_manager)
{
writer.set_version(TLS_V10);
@@ -160,48 +162,66 @@ void TLS_Server::process_handshake_msg(Handshake_Type type,
writer.set_version(state->version);
reader.set_version(state->version);
- state->server_hello = new Server_Hello(rng, writer,
- policy, cert_chain,
- *(state->client_hello),
- state->version, state->hash);
-
- state->suite = CipherSuite(state->server_hello->ciphersuite());
+ TLS_Session_Params params;
+ const bool found = session_manager.find(
+ state->client_hello->session_id_vector(),
+ params);
- if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON)
+ if(found && params.connection_side == SERVER)
{
- // FIXME: should choose certs based on sig type
- state->server_certs = new Certificate(writer, cert_chain,
- state->hash);
- }
- state->kex_priv = PKCS8::copy_key(*private_key, rng);
- if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX)
+
+
+
+ }
+ else // new session
{
- if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_RSA)
+ MemoryVector<byte> sess_id = rng.random_vec(32);
+
+ state->server_hello = new Server_Hello(rng, writer,
+ policy, cert_chain,
+ *(state->client_hello),
+ sess_id,
+ state->version, state->hash);
+
+ state->suite = CipherSuite(state->server_hello->ciphersuite());
+
+ if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON)
{
- state->kex_priv = new RSA_PrivateKey(rng,
- policy.rsa_export_keysize());
+ // FIXME: should choose certs based on sig type
+ state->server_certs = new Certificate(writer, cert_chain,
+ state->hash);
}
- else if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_DH)
+
+ state->kex_priv = PKCS8::copy_key(*private_key, rng);
+ if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX)
{
- state->kex_priv = new DH_PrivateKey(rng, policy.dh_group());
+ if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_RSA)
+ {
+ state->kex_priv = new RSA_PrivateKey(rng,
+ policy.rsa_export_keysize());
+ }
+ else if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_DH)
+ {
+ state->kex_priv = new DH_PrivateKey(rng, policy.dh_group());
+ }
+ else
+ throw Internal_Error("TLS_Server: Unknown ciphersuite kex type");
+
+ state->server_kex =
+ new Server_Key_Exchange(rng, writer,
+ state->kex_priv, private_key,
+ state->client_hello->random(),
+ state->server_hello->random(),
+ state->hash);
}
- else
- throw Internal_Error("TLS_Server: Unknown ciphersuite kex type");
-
- state->server_kex =
- new Server_Key_Exchange(rng, writer,
- state->kex_priv, private_key,
- state->client_hello->random(),
- state->server_hello->random(),
- state->hash);
- }
- if(policy.require_client_auth())
- {
- state->do_client_auth = true;
- throw Internal_Error("Client auth not implemented");
- // FIXME: send client auth request here
+ if(policy.require_client_auth())
+ {
+ state->do_client_auth = true;
+ throw Internal_Error("Client auth not implemented");
+ // FIXME: send client auth request here
+ }
}
state->server_hello_done = new Server_Hello_Done(writer, state->hash);
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index e975071d2..a1f99a0ff 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -9,6 +9,7 @@
#define BOTAN_TLS_SERVER_H__
#include <botan/tls_channel.h>
+#include <botan/tls_session_state.h>
#include <vector>
namespace Botan {
@@ -28,6 +29,7 @@ class BOTAN_DLL TLS_Server : public TLS_Channel
*/
TLS_Server(std::tr1::function<void (const byte[], size_t)> socket_output_fn,
std::tr1::function<void (const byte[], size_t, u16bit)> proc_fn,
+ TLS_Session_Manager& session_manager,
const TLS_Policy& policy,
RandomNumberGenerator& rng,
const X509_Certificate& cert,
@@ -47,6 +49,7 @@ class BOTAN_DLL TLS_Server : public TLS_Channel
const TLS_Policy& policy;
RandomNumberGenerator& rng;
+ TLS_Session_Manager& session_manager;
std::vector<X509_Certificate> cert_chain;
Private_Key* private_key;
diff --git a/src/tls/tls_session_key.cpp b/src/tls/tls_session_key.cpp
index 7c75d1758..865cc0b80 100644
--- a/src/tls/tls_session_key.cpp
+++ b/src/tls/tls_session_key.cpp
@@ -13,62 +13,6 @@
namespace Botan {
/**
-* Return the client cipher key
-*/
-SymmetricKey SessionKeys::client_cipher_key() const
- {
- return c_cipher;
- }
-
-/**
-* Return the server cipher key
-*/
-SymmetricKey SessionKeys::server_cipher_key() const
- {
- return s_cipher;
- }
-
-/**
-* Return the client MAC key
-*/
-SymmetricKey SessionKeys::client_mac_key() const
- {
- return c_mac;
- }
-
-/**
-* Return the server MAC key
-*/
-SymmetricKey SessionKeys::server_mac_key() const
- {
- return s_mac;
- }
-
-/**
-* Return the client cipher IV
-*/
-InitializationVector SessionKeys::client_iv() const
- {
- return c_iv;
- }
-
-/**
-* Return the server cipher IV
-*/
-InitializationVector SessionKeys::server_iv() const
- {
- return s_iv;
- }
-
-/**
-* Return the TLS master secret
-*/
-SecureVector<byte> SessionKeys::master_secret() const
- {
- return master_sec;
- }
-
-/**
* Generate SSLv3 session keys
*/
SymmetricKey SessionKeys::ssl3_keygen(size_t prf_gen,
@@ -126,7 +70,8 @@ SymmetricKey SessionKeys::tls1_keygen(size_t prf_gen,
/**
* SessionKeys Constructor
*/
-SessionKeys::SessionKeys(const CipherSuite& suite, Version_Code version,
+SessionKeys::SessionKeys(const CipherSuite& suite,
+ Version_Code version,
const MemoryRegion<byte>& pre_master_secret,
const MemoryRegion<byte>& c_random,
const MemoryRegion<byte>& s_random)
diff --git a/src/tls/tls_session_key.h b/src/tls/tls_session_key.h
index 51397984b..f0e185bd8 100644
--- a/src/tls/tls_session_key.h
+++ b/src/tls/tls_session_key.h
@@ -20,20 +20,25 @@ namespace Botan {
class BOTAN_DLL SessionKeys
{
public:
- SymmetricKey client_cipher_key() const;
- SymmetricKey server_cipher_key() const;
+ SymmetricKey client_cipher_key() const { return c_cipher; }
+ SymmetricKey server_cipher_key() const { return s_cipher; }
- SymmetricKey client_mac_key() const;
- SymmetricKey server_mac_key() const;
+ SymmetricKey client_mac_key() const { return c_mac; }
+ SymmetricKey server_mac_key() const { return s_mac; }
- InitializationVector client_iv() const;
- InitializationVector server_iv() const;
+ InitializationVector client_iv() const { return c_iv; }
+ InitializationVector server_iv() const { return s_iv; }
- SecureVector<byte> master_secret() const;
+ SecureVector<byte> master_secret() const { return master_sec; }
SessionKeys() {}
- SessionKeys(const CipherSuite&, Version_Code, const MemoryRegion<byte>&,
- const MemoryRegion<byte>&, const MemoryRegion<byte>&);
+
+ SessionKeys(const CipherSuite& suite,
+ Version_Code version,
+ const MemoryRegion<byte>& pre_master,
+ const MemoryRegion<byte>& client_random,
+ const MemoryRegion<byte>& server_random);
+
private:
SymmetricKey ssl3_keygen(size_t, const MemoryRegion<byte>&,
const MemoryRegion<byte>&,
diff --git a/src/tls/tls_session_state.h b/src/tls/tls_session_state.h
new file mode 100644
index 000000000..e6f25b34d
--- /dev/null
+++ b/src/tls/tls_session_state.h
@@ -0,0 +1,128 @@
+/*
+* TLS Session Management
+* (C) 2011 Jack Lloyd
+*
+* Released under the terms of the Botan license
+*/
+
+#ifndef TLS_SESSION_STATE_H_
+#define TLS_SESSION_STATE_H_
+
+#include <botan/tls_magic.h>
+#include <botan/secmem.h>
+#include <vector>
+#include <map>
+
+#include <iostream>
+
+namespace Botan {
+
+struct BOTAN_DLL TLS_Session_Params
+ {
+ SecureVector<byte> master_secret;
+ std::vector<byte> client_random;
+ std::vector<byte> server_random;
+
+ bool resumable;
+ Connection_Side connection_side;
+ Ciphersuite_Code ciphersuite;
+ Compression_Algo compression_method;
+ };
+
+/**
+* TLS_Session_Manager is an interface to systems which can save
+* session parameters for support session resumption.
+*/
+class BOTAN_DLL TLS_Session_Manager
+ {
+ public:
+ /**
+ * Try to load a saved session
+ * @param session_id the session identifier we are trying to resume
+ * @param params will be set to the saved session data (if found),
+ or not modified if not found
+ * @return true if params was modified
+ */
+ virtual bool find(const std::vector<byte>& session_id,
+ TLS_Session_Params& params) = 0;
+
+ /**
+ * Prohibit resumption of this session. Effectively an erase.
+ */
+ virtual void prohibit_resumption(const std::vector<byte>& session_id) = 0;
+
+ /**
+ * Save a session on a best effort basis; the manager may not in
+ * fact be able to save the session for whatever reason, this is
+ * not an error. Caller cannot assume that calling save followed
+ * immediately by find will result in a successful lookup.
+ *
+ * @param session_id the session identifier
+ * @param params to save
+ */
+ virtual void save(const std::vector<byte>& session_id,
+ const TLS_Session_Params& params) = 0;
+
+ virtual ~TLS_Session_Manager() {}
+ };
+
+/**
+* A simple implementation of TLS_Session_Manager that just saves
+* values in memory, with no persistance abilities
+*/
+class BOTAN_DLL TLS_Session_Manager_In_Memory : public TLS_Session_Manager
+ {
+ public:
+ /**
+ * @param max_sessions a hint on the maximum number of sessions
+ * to save at any one time.
+ */
+ TLS_Session_Manager_In_Memory(size_t max_sessions = 0) :
+ max_sessions(max_sessions) {}
+
+ bool find(const std::vector<byte>& session_id,
+ TLS_Session_Params& params)
+ {
+ std::map<std::vector<byte>, TLS_Session_Params>::const_iterator i =
+ sessions.find(session_id);
+
+ std::cout << "Know about " << sessions.size() << " sessions\n";
+
+ if(i != sessions.end())
+ {
+ params = i->second;
+ return true;
+ }
+
+ return false;
+ }
+
+ void prohibit_resumption(const std::vector<byte>& session_id)
+ {
+ std::map<std::vector<byte>, TLS_Session_Params>::const_iterator i =
+ sessions.find(session_id);
+
+ if(i != sessions.end())
+ sessions.erase(i);
+ }
+
+ void save(const std::vector<byte>& session_id,
+ const TLS_Session_Params& session_data)
+ {
+ if(max_sessions != 0)
+ {
+ while(sessions.size() >= max_sessions)
+ sessions.erase(sessions.begin());
+ }
+
+ sessions[session_id] = session_data;
+ }
+
+ private:
+ size_t max_sessions;
+ std::map<std::vector<byte>, TLS_Session_Params> sessions;
+ };
+
+}
+
+#endif