aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-06 17:51:06 +0000
committerlloyd <[email protected]>2012-09-06 17:51:06 +0000
commite5ab9f9ecf9f76d132815e7e89814727c81d7294 (patch)
tree46b8091307cc51f1d7bbbe8d526d272c7b7f7d2d
parente0029df2f95364f21a9538b493e24661e54efa21 (diff)
Pass process_handshake_msg a reference to the Handshake_State
-rw-r--r--src/tls/msg_cert_verify.cpp26
-rw-r--r--src/tls/msg_client_kex.cpp38
-rw-r--r--src/tls/msg_finished.cpp20
-rw-r--r--src/tls/msg_server_kex.cpp28
-rw-r--r--src/tls/tls_channel.cpp5
-rw-r--r--src/tls/tls_channel.h11
-rw-r--r--src/tls/tls_client.cpp228
-rw-r--r--src/tls/tls_client.h3
-rw-r--r--src/tls/tls_messages.h16
-rw-r--r--src/tls/tls_server.cpp235
-rw-r--r--src/tls/tls_server.h4
11 files changed, 307 insertions, 307 deletions
diff --git a/src/tls/msg_cert_verify.cpp b/src/tls/msg_cert_verify.cpp
index 2d283edca..18d851e53 100644
--- a/src/tls/msg_cert_verify.cpp
+++ b/src/tls/msg_cert_verify.cpp
@@ -20,7 +20,7 @@ namespace TLS {
* Create a new Certificate Verify message
*/
Certificate_Verify::Certificate_Verify(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* priv_key)
@@ -28,14 +28,14 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io,
BOTAN_ASSERT_NONNULL(priv_key);
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(priv_key, m_hash_algo, m_sig_algo, true, policy);
+ state.choose_sig_format(priv_key, m_hash_algo, m_sig_algo, true, policy);
PK_Signer signer(*priv_key, format.first, format.second);
- if(state->version() == Protocol_Version::SSL_V3)
+ if(state.version() == Protocol_Version::SSL_V3)
{
- secure_vector<byte> md5_sha = state->hash().final_ssl3(
- state->session_keys().master_secret());
+ secure_vector<byte> md5_sha = state.hash().final_ssl3(
+ state.session_keys().master_secret());
if(priv_key->algo_name() == "DSA")
m_signature = signer.sign_message(&md5_sha[16], md5_sha.size()-16, rng);
@@ -44,10 +44,10 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io,
}
else
{
- m_signature = signer.sign_message(state->hash().get_contents(), rng);
+ m_signature = signer.sign_message(state.hash().get_contents(), rng);
}
- state->hash().update(io.send(*this));
+ state.hash().update(io.send(*this));
}
/*
@@ -92,25 +92,25 @@ std::vector<byte> Certificate_Verify::serialize() const
* Verify a Certificate Verify message
*/
bool Certificate_Verify::verify(const X509_Certificate& cert,
- const Handshake_State* state) const
+ const Handshake_State& state) const
{
std::unique_ptr<Public_Key> key(cert.subject_public_key());
std::pair<std::string, Signature_Format> format =
- state->understand_sig_format(key.get(), m_hash_algo, m_sig_algo, true);
+ state.understand_sig_format(key.get(), m_hash_algo, m_sig_algo, true);
PK_Verifier verifier(*key, format.first, format.second);
- if(state->version() == Protocol_Version::SSL_V3)
+ if(state.version() == Protocol_Version::SSL_V3)
{
- secure_vector<byte> md5_sha = state->hash().final_ssl3(
- state->session_keys().master_secret());
+ secure_vector<byte> md5_sha = state.hash().final_ssl3(
+ state.session_keys().master_secret());
return verifier.verify_message(&md5_sha[16], md5_sha.size()-16,
&m_signature[0], m_signature.size());
}
- return verifier.verify_message(state->hash().get_contents(), m_signature);
+ return verifier.verify_message(state.hash().get_contents(), m_signature);
}
}
diff --git a/src/tls/msg_client_kex.cpp b/src/tls/msg_client_kex.cpp
index 22cad4e5c..d129969a9 100644
--- a/src/tls/msg_client_kex.cpp
+++ b/src/tls/msg_client_kex.cpp
@@ -48,26 +48,26 @@ secure_vector<byte> strip_leading_zeros(const secure_vector<byte>& input)
* Create a new Client Key Exchange message
*/
Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
Credentials_Manager& creds,
const std::vector<X509_Certificate>& peer_certs,
const std::string& hostname,
RandomNumberGenerator& rng)
{
- const std::string kex_algo = state->ciphersuite().kex_algo();
+ const std::string kex_algo = state.ciphersuite().kex_algo();
if(kex_algo == "PSK")
{
std::string identity_hint = "";
- if(state->server_kex())
+ if(state.server_kex())
{
- TLS_Data_Reader reader(state->server_kex()->params());
+ TLS_Data_Reader reader(state.server_kex()->params());
identity_hint = reader.get_string(2, 0, 65535);
}
- const std::string hostname = state->client_hello()->sni_hostname();
+ const std::string hostname = state.client_hello()->sni_hostname();
const std::string psk_identity = creds.psk_identity("tls-client",
hostname,
@@ -82,9 +82,9 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
append_tls_length_value(m_pre_master, zeros, 2);
append_tls_length_value(m_pre_master, psk.bits_of(), 2);
}
- else if(state->server_kex())
+ else if(state.server_kex())
{
- TLS_Data_Reader reader(state->server_kex()->params());
+ TLS_Data_Reader reader(state.server_kex()->params());
SymmetricKey psk;
@@ -92,7 +92,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
{
std::string identity_hint = reader.get_string(2, 0, 65535);
- const std::string hostname = state->client_hello()->sni_hostname();
+ const std::string hostname = state.client_hello()->sni_hostname();
const std::string psk_identity = creds.psk_identity("tls-client",
hostname,
@@ -239,7 +239,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
if(const RSA_PublicKey* rsa_pub = dynamic_cast<const RSA_PublicKey*>(pub_key.get()))
{
- const Protocol_Version offered_version = state->client_hello()->version();
+ const Protocol_Version offered_version = state.client_hello()->version();
m_pre_master = rng.random_vec(48);
m_pre_master[0] = offered_version.major_version();
@@ -249,7 +249,7 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
std::vector<byte> encrypted_key = encryptor.encrypt(m_pre_master, rng);
- if(state->version() == Protocol_Version::SSL_V3)
+ if(state.version() == Protocol_Version::SSL_V3)
m_key_material = encrypted_key; // no length field
else
append_tls_length_value(m_key_material, encrypted_key, 2);
@@ -260,24 +260,24 @@ Client_Key_Exchange::Client_Key_Exchange(Handshake_IO& io,
pub_key->algo_name());
}
- state->hash().update(io.send(*this));
+ state.hash().update(io.send(*this));
}
/*
* Read a Client Key Exchange message
*/
Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents,
- const Handshake_State* state,
+ const Handshake_State& state,
const Private_Key* server_rsa_kex_key,
Credentials_Manager& creds,
const Policy& policy,
RandomNumberGenerator& rng)
{
- const std::string kex_algo = state->ciphersuite().kex_algo();
+ const std::string kex_algo = state.ciphersuite().kex_algo();
if(kex_algo == "RSA")
{
- BOTAN_ASSERT(state->server_certs() && !state->server_certs()->cert_chain().empty(),
+ BOTAN_ASSERT(state.server_certs() && !state.server_certs()->cert_chain().empty(),
"RSA key exchange negotiated so server sent a certificate");
if(!server_rsa_kex_key)
@@ -288,11 +288,11 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents,
PK_Decryptor_EME decryptor(*server_rsa_kex_key, "PKCS1v15");
- Protocol_Version client_version = state->client_hello()->version();
+ Protocol_Version client_version = state.client_hello()->version();
try
{
- if(state->version() == Protocol_Version::SSL_V3)
+ if(state.version() == Protocol_Version::SSL_V3)
{
m_pre_master = decryptor.decrypt(contents);
}
@@ -328,7 +328,7 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents,
const std::string psk_identity = reader.get_string(2, 0, 65535);
psk = creds.psk("tls-server",
- state->client_hello()->sni_hostname(),
+ state.client_hello()->sni_hostname(),
psk_identity);
if(psk.length() == 0)
@@ -349,14 +349,14 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<byte>& contents,
}
else if(kex_algo == "SRP_SHA")
{
- SRP6_Server_Session& srp = state->server_kex()->server_srp_params();
+ SRP6_Server_Session& srp = state.server_kex()->server_srp_params();
m_pre_master = srp.step2(BigInt::decode(reader.get_range<byte>(2, 0, 65535))).bits_of();
}
else if(kex_algo == "DH" || kex_algo == "DHE_PSK" ||
kex_algo == "ECDH" || kex_algo == "ECDHE_PSK")
{
- const Private_Key& private_key = state->server_kex()->server_kex_key();
+ const Private_Key& private_key = state.server_kex()->server_kex_key();
const PK_Key_Agreement_Key* ka_key =
dynamic_cast<const PK_Key_Agreement_Key*>(&private_key);
diff --git a/src/tls/msg_finished.cpp b/src/tls/msg_finished.cpp
index 390f05300..059ed8363 100644
--- a/src/tls/msg_finished.cpp
+++ b/src/tls/msg_finished.cpp
@@ -18,15 +18,15 @@ namespace {
/*
* Compute the verify_data
*/
-std::vector<byte> finished_compute_verify(const Handshake_State* state,
+std::vector<byte> finished_compute_verify(const Handshake_State& state,
Connection_Side side)
{
- if(state->version() == Protocol_Version::SSL_V3)
+ if(state.version() == Protocol_Version::SSL_V3)
{
const byte SSL_CLIENT_LABEL[] = { 0x43, 0x4C, 0x4E, 0x54 };
const byte SSL_SERVER_LABEL[] = { 0x53, 0x52, 0x56, 0x52 };
- Handshake_Hash hash = state->hash(); // don't modify state
+ Handshake_Hash hash = state.hash(); // don't modify state
std::vector<byte> ssl3_finished;
@@ -35,7 +35,7 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state,
else
hash.update(SSL_SERVER_LABEL, sizeof(SSL_SERVER_LABEL));
- return unlock(hash.final_ssl3(state->session_keys().master_secret()));
+ return unlock(hash.final_ssl3(state.session_keys().master_secret()));
}
else
{
@@ -47,7 +47,7 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x66, 0x69, 0x6E, 0x69,
0x73, 0x68, 0x65, 0x64 };
- std::unique_ptr<KDF> prf(state->protocol_specific_prf());
+ std::unique_ptr<KDF> prf(state.protocol_specific_prf());
std::vector<byte> input;
if(side == CLIENT)
@@ -55,9 +55,9 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state,
else
input += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL));
- input += state->hash().final(state->version(), state->ciphersuite().mac_algo());
+ input += state.hash().final(state.version(), state.ciphersuite().mac_algo());
- return unlock(prf->derive_key(12, state->session_keys().master_secret(), input));
+ return unlock(prf->derive_key(12, state.session_keys().master_secret(), input));
}
}
@@ -67,11 +67,11 @@ std::vector<byte> finished_compute_verify(const Handshake_State* state,
* Create a new Finished message
*/
Finished::Finished(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
Connection_Side side)
{
m_verification_data = finished_compute_verify(state, side);
- state->hash().update(io.send(*this));
+ state.hash().update(io.send(*this));
}
/*
@@ -93,7 +93,7 @@ Finished::Finished(const std::vector<byte>& buf)
/*
* Verify a Finished message
*/
-bool Finished::verify(const Handshake_State* state,
+bool Finished::verify(const Handshake_State& state,
Connection_Side side) const
{
return (m_verification_data == finished_compute_verify(state, side));
diff --git a/src/tls/msg_server_kex.cpp b/src/tls/msg_server_kex.cpp
index b3c4e9017..250a8c126 100644
--- a/src/tls/msg_server_kex.cpp
+++ b/src/tls/msg_server_kex.cpp
@@ -28,14 +28,14 @@ namespace TLS {
* Create a new Server Key Exchange message
*/
Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
Credentials_Manager& creds,
RandomNumberGenerator& rng,
const Private_Key* signing_key)
{
- const std::string hostname = state->client_hello()->sni_hostname();
- const std::string kex_algo = state->ciphersuite().kex_algo();
+ const std::string hostname = state.client_hello()->sni_hostname();
+ const std::string kex_algo = state.ciphersuite().kex_algo();
if(kex_algo == "PSK" || kex_algo == "DHE_PSK" || kex_algo == "ECDHE_PSK")
{
@@ -57,7 +57,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
else if(kex_algo == "ECDH" || kex_algo == "ECDHE_PSK")
{
const std::vector<std::string>& curves =
- state->client_hello()->supported_ecc_curves();
+ state.client_hello()->supported_ecc_curves();
if(curves.empty())
throw Internal_Error("Client sent no ECC extension but we negotiated ECDH");
@@ -90,7 +90,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
}
else if(kex_algo == "SRP_SHA")
{
- const std::string srp_identifier = state->client_hello()->srp_identifier();
+ const std::string srp_identifier = state.client_hello()->srp_identifier();
std::string group_id;
BigInt v;
@@ -120,22 +120,22 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
else if(kex_algo != "PSK")
throw Internal_Error("Server_Key_Exchange: Unknown kex type " + kex_algo);
- if(state->ciphersuite().sig_algo() != "")
+ if(state.ciphersuite().sig_algo() != "")
{
BOTAN_ASSERT(signing_key, "Signing key was set");
std::pair<std::string, Signature_Format> format =
- state->choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false, policy);
+ state.choose_sig_format(signing_key, m_hash_algo, m_sig_algo, false, policy);
PK_Signer signer(*signing_key, format.first, format.second);
- signer.update(state->client_hello()->random());
- signer.update(state->server_hello()->random());
+ signer.update(state.client_hello()->random());
+ signer.update(state.server_hello()->random());
signer.update(params());
m_signature = signer.signature(rng);
}
- state->hash().update(io.send(*this));
+ state.hash().update(io.send(*this));
}
/**
@@ -255,17 +255,17 @@ std::vector<byte> Server_Key_Exchange::serialize() const
* Verify a Server Key Exchange message
*/
bool Server_Key_Exchange::verify(const X509_Certificate& cert,
- const Handshake_State* state) const
+ const Handshake_State& state) const
{
std::unique_ptr<Public_Key> key(cert.subject_public_key());
std::pair<std::string, Signature_Format> format =
- state->understand_sig_format(key.get(), m_hash_algo, m_sig_algo, false);
+ state.understand_sig_format(key.get(), m_hash_algo, m_sig_algo, false);
PK_Verifier verifier(*key, format.first, format.second);
- verifier.update(state->client_hello()->random());
- verifier.update(state->server_hello()->random());
+ verifier.update(state.client_hello()->random());
+ verifier.update(state.server_hello()->random());
verifier.update(params());
return verifier.check_signature(m_signature);
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 024aec099..a2260e448 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -175,7 +175,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(msg.first == HANDSHAKE_NONE) // no full handshake yet
break;
- process_handshake_msg(msg.first, msg.second);
+ process_handshake_msg(*m_state.get(), msg.first, msg.second);
}
}
else if(rec_type == HEARTBEAT && m_peer_supports_heartbeats)
@@ -383,6 +383,9 @@ void Channel::send_alert(const Alert& alert)
catch(...) { /* swallow it */ }
}
+ if(alert.type() == Alert::NO_RENEGOTIATION)
+ m_state.reset();
+
if(alert.is_fatal() && !m_active_session.empty())
{
m_session_manager.remove_entry(m_active_session);
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index 1730be792..d1a6e8b00 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -96,6 +96,12 @@ class BOTAN_DLL Channel
virtual ~Channel();
protected:
+ virtual void process_handshake_msg(class Handshake_State& state,
+ Handshake_Type type,
+ const std::vector<byte>& contents) = 0;
+
+ virtual class Handshake_State* new_handshake_state() = 0;
+
/**
* Send a TLS alert message. If the alert is fatal, the internal
* state (keys, etc) will be reset.
@@ -105,11 +111,6 @@ class BOTAN_DLL Channel
void activate_session(const std::vector<byte>& session_id);
- virtual void process_handshake_msg(Handshake_Type type,
- const std::vector<byte>& contents) = 0;
-
- virtual class Handshake_State* new_handshake_state() = 0;
-
void heartbeat_support(bool peer_supports, bool allowed_to_send);
void set_protocol_version(Protocol_Version version);
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index d2508d579..8cb89ffb5 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -146,25 +146,21 @@ void Client::initiate_handshake(bool force_full_renegotiation,
/*
* Process a handshake message
*/
-void Client::process_handshake_msg(Handshake_Type type,
+void Client::process_handshake_msg(Handshake_State& state,
+ Handshake_Type type,
const std::vector<byte>& contents)
{
- if(!m_state)
- throw Unexpected_Message("Unexpected handshake message from server");
-
if(type == HELLO_REQUEST)
{
Hello_Request hello_request(contents);
// Ignore request entirely if we are currently negotiating a handshake
- if(m_state->client_hello())
+ if(state.client_hello())
return;
if(!m_policy.allow_server_initiated_renegotiation() ||
(!m_policy.allow_insecure_renegotiation() && !m_secure_renegotiation.supported()))
{
- m_state.reset();
-
// RFC 5746 section 4.2
send_alert(Alert(Alert::NO_RENEGOTIATION));
return;
@@ -175,69 +171,69 @@ void Client::process_handshake_msg(Handshake_Type type,
return;
}
- m_state->confirm_transition_to(type);
+ state.confirm_transition_to(type);
if(type != HANDSHAKE_CCS && type != FINISHED && type != HELLO_VERIFY_REQUEST)
- m_state->hash().update(m_state->handshake_io().format(contents, type));
+ state.hash().update(state.handshake_io().format(contents, type));
if(type == HELLO_VERIFY_REQUEST)
{
- m_state->set_expected_next(SERVER_HELLO);
- m_state->set_expected_next(HELLO_VERIFY_REQUEST); // might get it again
+ state.set_expected_next(SERVER_HELLO);
+ state.set_expected_next(HELLO_VERIFY_REQUEST); // might get it again
Hello_Verify_Request hello_verify_request(contents);
- m_state->note_message(hello_verify_request);
+ state.note_message(hello_verify_request);
std::unique_ptr<Client_Hello> client_hello_w_cookie(
- new Client_Hello(m_state->handshake_io(),
- m_state->hash(),
- *m_state->client_hello(),
+ new Client_Hello(state.handshake_io(),
+ state.hash(),
+ *state.client_hello(),
hello_verify_request));
- m_state->client_hello(client_hello_w_cookie.release());
+ state.client_hello(client_hello_w_cookie.release());
}
else if(type == SERVER_HELLO)
{
- m_state->server_hello(new Server_Hello(contents));
+ state.server_hello(new Server_Hello(contents));
- if(!m_state->client_hello()->offered_suite(m_state->server_hello()->ciphersuite()))
+ if(!state.client_hello()->offered_suite(state.server_hello()->ciphersuite()))
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server replied with ciphersuite we didn't send");
}
- if(!value_exists(m_state->client_hello()->compression_methods(),
- m_state->server_hello()->compression_method()))
+ if(!value_exists(state.client_hello()->compression_methods(),
+ state.server_hello()->compression_method()))
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server replied with compression method we didn't send");
}
- if(!m_state->client_hello()->next_protocol_notification() &&
- m_state->server_hello()->next_protocol_notification())
+ if(!state.client_hello()->next_protocol_notification() &&
+ state.server_hello()->next_protocol_notification())
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server sent next protocol but we didn't request it");
}
- if(m_state->server_hello()->supports_session_ticket())
+ if(state.server_hello()->supports_session_ticket())
{
- if(!m_state->client_hello()->supports_session_ticket())
+ if(!state.client_hello()->supports_session_ticket())
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server sent session ticket extension but we did not");
}
- set_protocol_version(m_state->server_hello()->version());
+ set_protocol_version(state.server_hello()->version());
- m_secure_renegotiation.update(m_state->server_hello());
+ m_secure_renegotiation.update(state.server_hello());
- heartbeat_support(m_state->server_hello()->supports_heartbeats(),
- m_state->server_hello()->peer_can_send_heartbeats());
+ heartbeat_support(state.server_hello()->supports_heartbeats(),
+ state.server_hello()->peer_can_send_heartbeats());
const bool server_returned_same_session_id =
- !m_state->server_hello()->session_id().empty() &&
- (m_state->server_hello()->session_id() == m_state->client_hello()->session_id());
+ !state.server_hello()->session_id().empty() &&
+ (state.server_hello()->session_id() == state.client_hello()->session_id());
if(server_returned_same_session_id)
{
@@ -247,40 +243,40 @@ void Client::process_handshake_msg(Handshake_Type type,
* In this case, we offered the version used in the original
* session, and the server must resume with the same version.
*/
- if(m_state->server_hello()->version() != m_state->client_hello()->version())
+ if(state.server_hello()->version() != state.client_hello()->version())
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server resumed session but with wrong version");
- m_state->compute_session_keys(
- dynamic_cast<Client_Handshake_State&>(*m_state).resume_master_secret
+ state.compute_session_keys(
+ dynamic_cast<Client_Handshake_State&>(state).resume_master_secret
);
- if(m_state->server_hello()->supports_session_ticket())
- m_state->set_expected_next(NEW_SESSION_TICKET);
+ if(state.server_hello()->supports_session_ticket())
+ state.set_expected_next(NEW_SESSION_TICKET);
else
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
}
else
{
// new session
- if(m_state->version() > m_state->client_hello()->version())
+ if(state.version() > state.client_hello()->version())
{
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Server replied with later version than in hello");
}
- if(!m_policy.acceptable_protocol_version(m_state->version()))
+ if(!m_policy.acceptable_protocol_version(state.version()))
{
throw TLS_Exception(Alert::PROTOCOL_VERSION,
"Server version is unacceptable by policy");
}
- if(m_state->ciphersuite().sig_algo() != "")
+ if(state.ciphersuite().sig_algo() != "")
{
- m_state->set_expected_next(CERTIFICATE);
+ state.set_expected_next(CERTIFICATE);
}
- else if(m_state->ciphersuite().kex_algo() == "PSK")
+ else if(state.ciphersuite().kex_algo() == "PSK")
{
/* PSK is anonymous so no certificate/cert req message is
ever sent. The server may or may not send a server kex,
@@ -290,35 +286,35 @@ void Client::process_handshake_msg(Handshake_Type type,
DH exchange portion.
*/
- m_state->set_expected_next(SERVER_KEX);
- m_state->set_expected_next(SERVER_HELLO_DONE);
+ state.set_expected_next(SERVER_KEX);
+ state.set_expected_next(SERVER_HELLO_DONE);
}
- else if(m_state->ciphersuite().kex_algo() != "RSA")
+ else if(state.ciphersuite().kex_algo() != "RSA")
{
- m_state->set_expected_next(SERVER_KEX);
+ state.set_expected_next(SERVER_KEX);
}
else
{
- m_state->set_expected_next(CERTIFICATE_REQUEST); // optional
- m_state->set_expected_next(SERVER_HELLO_DONE);
+ state.set_expected_next(CERTIFICATE_REQUEST); // optional
+ state.set_expected_next(SERVER_HELLO_DONE);
}
}
}
else if(type == CERTIFICATE)
{
- if(m_state->ciphersuite().kex_algo() != "RSA")
+ if(state.ciphersuite().kex_algo() != "RSA")
{
- m_state->set_expected_next(SERVER_KEX);
+ state.set_expected_next(SERVER_KEX);
}
else
{
- m_state->set_expected_next(CERTIFICATE_REQUEST); // optional
- m_state->set_expected_next(SERVER_HELLO_DONE);
+ state.set_expected_next(CERTIFICATE_REQUEST); // optional
+ state.set_expected_next(SERVER_HELLO_DONE);
}
- m_state->server_certs(new Certificate(contents));
+ state.server_certs(new Certificate(contents));
- m_peer_certs = m_state->server_certs()->cert_chain();
+ m_peer_certs = state.server_certs()->cert_chain();
if(m_peer_certs.empty())
throw TLS_Exception(Alert::HANDSHAKE_FAILURE,
"Client: No certificates sent by server");
@@ -334,25 +330,25 @@ void Client::process_handshake_msg(Handshake_Type type,
std::unique_ptr<Public_Key> peer_key(m_peer_certs[0].subject_public_key());
- if(peer_key->algo_name() != m_state->ciphersuite().sig_algo())
+ if(peer_key->algo_name() != state.ciphersuite().sig_algo())
throw TLS_Exception(Alert::ILLEGAL_PARAMETER,
"Certificate key type did not match ciphersuite");
}
else if(type == SERVER_KEX)
{
- m_state->set_expected_next(CERTIFICATE_REQUEST); // optional
- m_state->set_expected_next(SERVER_HELLO_DONE);
+ state.set_expected_next(CERTIFICATE_REQUEST); // optional
+ state.set_expected_next(SERVER_HELLO_DONE);
- m_state->server_kex(
+ state.server_kex(
new Server_Key_Exchange(contents,
- m_state->ciphersuite().kex_algo(),
- m_state->ciphersuite().sig_algo(),
- m_state->version())
+ state.ciphersuite().kex_algo(),
+ state.ciphersuite().sig_algo(),
+ state.version())
);
- if(m_state->ciphersuite().sig_algo() != "")
+ if(state.ciphersuite().sig_algo() != "")
{
- if(!m_state->server_kex()->verify(m_peer_certs[0], m_state.get()))
+ if(!state.server_kex()->verify(m_peer_certs[0], state))
{
throw TLS_Exception(Alert::DECRYPT_ERROR,
"Bad signature on server key exchange");
@@ -361,37 +357,37 @@ void Client::process_handshake_msg(Handshake_Type type,
}
else if(type == CERTIFICATE_REQUEST)
{
- m_state->set_expected_next(SERVER_HELLO_DONE);
- m_state->cert_req(
- new Certificate_Req(contents, m_state->version())
+ state.set_expected_next(SERVER_HELLO_DONE);
+ state.cert_req(
+ new Certificate_Req(contents, state.version())
);
}
else if(type == SERVER_HELLO_DONE)
{
- m_state->server_hello_done(
+ state.server_hello_done(
new Server_Hello_Done(contents)
);
- if(m_state->received_handshake_msg(CERTIFICATE_REQUEST))
+ if(state.received_handshake_msg(CERTIFICATE_REQUEST))
{
const std::vector<std::string>& types =
- m_state->cert_req()->acceptable_cert_types();
+ state.cert_req()->acceptable_cert_types();
std::vector<X509_Certificate> client_certs =
m_creds.cert_chain(types,
"tls-client",
m_hostname);
- m_state->client_certs(
- new Certificate(m_state->handshake_io(),
- m_state->hash(),
+ state.client_certs(
+ new Certificate(state.handshake_io(),
+ state.hash(),
client_certs)
);
}
- m_state->client_kex(
- new Client_Key_Exchange(m_state->handshake_io(),
- m_state.get(),
+ state.client_kex(
+ new Client_Key_Exchange(state.handshake_io(),
+ state,
m_policy,
m_creds,
m_peer_certs,
@@ -399,114 +395,114 @@ void Client::process_handshake_msg(Handshake_Type type,
m_rng)
);
- m_state->compute_session_keys();
+ state.compute_session_keys();
- if(m_state->received_handshake_msg(CERTIFICATE_REQUEST) &&
- !m_state->client_certs()->empty())
+ if(state.received_handshake_msg(CERTIFICATE_REQUEST) &&
+ !state.client_certs()->empty())
{
Private_Key* private_key =
- m_creds.private_key_for(m_state->client_certs()->cert_chain()[0],
+ m_creds.private_key_for(state.client_certs()->cert_chain()[0],
"tls-client",
m_hostname);
- m_state->client_verify(
- new Certificate_Verify(m_state->handshake_io(),
- m_state.get(),
+ state.client_verify(
+ new Certificate_Verify(state.handshake_io(),
+ state,
m_policy,
m_rng,
private_key)
);
}
- m_state->handshake_io().send(Change_Cipher_Spec());
+ state.handshake_io().send(Change_Cipher_Spec());
change_cipher_spec_writer(CLIENT);
- if(m_state->server_hello()->next_protocol_notification())
+ if(state.server_hello()->next_protocol_notification())
{
const std::string protocol =
- dynamic_cast<Client_Handshake_State&>(*m_state).client_npn_cb(
- m_state->server_hello()->next_protocols());
+ dynamic_cast<Client_Handshake_State&>(state).client_npn_cb(
+ state.server_hello()->next_protocols());
- m_state->next_protocol(
- new Next_Protocol(m_state->handshake_io(), m_state->hash(), protocol)
+ state.next_protocol(
+ new Next_Protocol(state.handshake_io(), state.hash(), protocol)
);
}
- m_state->client_finished(
- new Finished(m_state->handshake_io(), m_state.get(), CLIENT)
+ state.client_finished(
+ new Finished(state.handshake_io(), state, CLIENT)
);
- if(m_state->server_hello()->supports_session_ticket())
- m_state->set_expected_next(NEW_SESSION_TICKET);
+ if(state.server_hello()->supports_session_ticket())
+ state.set_expected_next(NEW_SESSION_TICKET);
else
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
}
else if(type == NEW_SESSION_TICKET)
{
- m_state->new_session_ticket(new New_Session_Ticket(contents));
+ state.new_session_ticket(new New_Session_Ticket(contents));
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
}
else if(type == HANDSHAKE_CCS)
{
- m_state->set_expected_next(FINISHED);
+ state.set_expected_next(FINISHED);
change_cipher_spec_reader(CLIENT);
}
else if(type == FINISHED)
{
- m_state->set_expected_next(HELLO_REQUEST);
+ state.set_expected_next(HELLO_REQUEST);
- m_state->server_finished(new Finished(contents));
+ state.server_finished(new Finished(contents));
- if(!m_state->server_finished()->verify(m_state.get(), SERVER))
+ if(!state.server_finished()->verify(state, SERVER))
throw TLS_Exception(Alert::DECRYPT_ERROR,
"Finished message didn't verify");
- m_state->hash().update(m_state->handshake_io().format(contents, type));
+ state.hash().update(state.handshake_io().format(contents, type));
- if(!m_state->client_finished()) // session resume case
+ if(!state.client_finished()) // session resume case
{
- m_state->handshake_io().send(Change_Cipher_Spec());
+ state.handshake_io().send(Change_Cipher_Spec());
change_cipher_spec_writer(CLIENT);
- if(m_state->server_hello()->next_protocol_notification())
+ if(state.server_hello()->next_protocol_notification())
{
const std::string protocol =
dynamic_cast<Client_Handshake_State&>(*m_state).client_npn_cb(
- m_state->server_hello()->next_protocols());
+ state.server_hello()->next_protocols());
- m_state->next_protocol(
- new Next_Protocol(m_state->handshake_io(), m_state->hash(), protocol)
+ state.next_protocol(
+ new Next_Protocol(state.handshake_io(), state.hash(), protocol)
);
}
- m_state->client_finished(
- new Finished(m_state->handshake_io(), m_state.get(), CLIENT)
+ state.client_finished(
+ new Finished(state.handshake_io(), state, CLIENT)
);
}
- m_secure_renegotiation.update(m_state->client_finished(),
- m_state->server_finished());
+ m_secure_renegotiation.update(state.client_finished(),
+ state.server_finished());
- std::vector<byte> session_id = m_state->server_hello()->session_id();
+ std::vector<byte> session_id = state.server_hello()->session_id();
- const std::vector<byte>& session_ticket = m_state->session_ticket();
+ const std::vector<byte>& session_ticket = state.session_ticket();
if(session_id.empty() && !session_ticket.empty())
session_id = make_hello_random(m_rng);
Session session_info(
session_id,
- m_state->session_keys().master_secret(),
- m_state->server_hello()->version(),
- m_state->server_hello()->ciphersuite(),
- m_state->server_hello()->compression_method(),
+ state.session_keys().master_secret(),
+ state.server_hello()->version(),
+ state.server_hello()->ciphersuite(),
+ state.server_hello()->compression_method(),
CLIENT,
m_secure_renegotiation.supported(),
- m_state->server_hello()->fragment_size(),
+ state.server_hello()->fragment_size(),
m_peer_certs,
session_ticket,
m_hostname,
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index b256a75e9..d7c7598b5 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -74,7 +74,8 @@ class BOTAN_DLL Client : public Channel
std::function<std::string (std::vector<std::string>)> next_protocol =
std::function<std::string (std::vector<std::string>)>());
- void process_handshake_msg(Handshake_Type type,
+ void process_handshake_msg(Handshake_State& state,
+ Handshake_Type type,
const std::vector<byte>& contents) override;
class Handshake_State* new_handshake_state() override;
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index 555520073..6fd7f675c 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -237,7 +237,7 @@ class Client_Key_Exchange : public Handshake_Message
{ return m_pre_master; }
Client_Key_Exchange(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
Credentials_Manager& creds,
const std::vector<X509_Certificate>& peer_certs,
@@ -245,7 +245,7 @@ class Client_Key_Exchange : public Handshake_Message
RandomNumberGenerator& rng);
Client_Key_Exchange(const std::vector<byte>& buf,
- const Handshake_State* state,
+ const Handshake_State& state,
const Private_Key* server_rsa_kex_key,
Credentials_Manager& creds,
const Policy& policy,
@@ -329,10 +329,10 @@ class Certificate_Verify : public Handshake_Message
* @param state the handshake state
*/
bool verify(const X509_Certificate& cert,
- const Handshake_State* state) const;
+ const Handshake_State& state) const;
Certificate_Verify(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* key);
@@ -358,11 +358,11 @@ class Finished : public Handshake_Message
std::vector<byte> verify_data() const
{ return m_verification_data; }
- bool verify(const Handshake_State* state,
+ bool verify(const Handshake_State& state,
Connection_Side side) const;
Finished(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
Connection_Side side);
Finished(const std::vector<byte>& buf);
@@ -398,7 +398,7 @@ class Server_Key_Exchange : public Handshake_Message
const std::vector<byte>& params() const { return m_params; }
bool verify(const X509_Certificate& cert,
- const Handshake_State* state) const;
+ const Handshake_State& state) const;
// Only valid for certain kex types
const Private_Key& server_kex_key() const;
@@ -407,7 +407,7 @@ class Server_Key_Exchange : public Handshake_Message
SRP6_Server_Session& server_srp_params() const;
Server_Key_Exchange(Handshake_IO& io,
- Handshake_State* state,
+ Handshake_State& state,
const Policy& policy,
Credentials_Manager& creds,
RandomNumberGenerator& rng,
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 21ea568e4..d858592d4 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -254,13 +254,11 @@ void Server::renegotiate(bool force_full_renegotiation)
/*
* Process a handshake message
*/
-void Server::process_handshake_msg(Handshake_Type type,
+void Server::process_handshake_msg(Handshake_State& state,
+ Handshake_Type type,
const std::vector<byte>& contents)
{
- if(!m_state)
- throw Unexpected_Message("Unexpected handshake message from client");
-
- m_state->confirm_transition_to(type);
+ state.confirm_transition_to(type);
/*
* The change cipher spec message isn't technically a handshake
@@ -272,9 +270,9 @@ void Server::process_handshake_msg(Handshake_Type type,
if(type != HANDSHAKE_CCS && type != FINISHED && type != CERTIFICATE_VERIFY)
{
if(type == CLIENT_HELLO_SSLV2)
- m_state->hash().update(contents);
+ state.hash().update(contents);
else
- m_state->hash().update(m_state->handshake_io().format(contents, type));
+ state.hash().update(state.handshake_io().format(contents, type));
}
if(type == CLIENT_HELLO || type == CLIENT_HELLO_SSLV2)
@@ -282,17 +280,16 @@ void Server::process_handshake_msg(Handshake_Type type,
if(!m_policy.allow_insecure_renegotiation() &&
!(m_secure_renegotiation.initial_handshake() || m_secure_renegotiation.supported()))
{
- m_state.reset();
send_alert(Alert(Alert::NO_RENEGOTIATION));
return;
}
- m_state->client_hello(new Client_Hello(contents, type));
+ state.client_hello(new Client_Hello(contents, type));
- if(m_state->client_hello()->sni_hostname() != "")
- m_hostname = m_state->client_hello()->sni_hostname();
+ if(state.client_hello()->sni_hostname() != "")
+ m_hostname = state.client_hello()->sni_hostname();
- Protocol_Version client_version = m_state->client_hello()->version();
+ Protocol_Version client_version = state.client_hello()->version();
const Protocol_Version prev_version = current_protocol_version();
const bool is_renegotiation = prev_version.valid();
@@ -346,20 +343,20 @@ void Server::process_handshake_msg(Handshake_Type type,
"Client version is unacceptable by policy");
}
- m_secure_renegotiation.update(m_state->client_hello());
+ m_secure_renegotiation.update(state.client_hello());
set_protocol_version(negotiated_version);
- heartbeat_support(m_state->client_hello()->supports_heartbeats(),
- m_state->client_hello()->peer_can_send_heartbeats());
+ heartbeat_support(state.client_hello()->supports_heartbeats(),
+ state.client_hello()->peer_can_send_heartbeats());
Session session_info;
const bool resuming =
- dynamic_cast<Server_Handshake_State&>(*m_state).allow_session_resumption &&
+ dynamic_cast<Server_Handshake_State&>(state).allow_session_resumption &&
check_for_resume(session_info,
m_session_manager,
m_creds,
- m_state->client_hello(),
+ state.client_hello(),
std::chrono::seconds(m_policy.session_ticket_lifetime()));
bool have_session_ticket_key = false;
@@ -376,15 +373,15 @@ void Server::process_handshake_msg(Handshake_Type type,
// resume session
const bool offer_new_session_ticket =
- (m_state->client_hello()->supports_session_ticket() &&
- m_state->client_hello()->session_ticket().empty() &&
+ (state.client_hello()->supports_session_ticket() &&
+ state.client_hello()->session_ticket().empty() &&
have_session_ticket_key);
- m_state->server_hello(
+ state.server_hello(
new Server_Hello(
- m_state->handshake_io(),
- m_state->hash(),
- m_state->client_hello()->session_id(),
+ state.handshake_io(),
+ state.hash(),
+ state.client_hello()->session_id(),
Protocol_Version(session_info.version()),
session_info.ciphersuite_code(),
session_info.compression_method(),
@@ -392,64 +389,64 @@ void Server::process_handshake_msg(Handshake_Type type,
m_secure_renegotiation.supported(),
m_secure_renegotiation.for_server_hello(),
offer_new_session_ticket,
- m_state->client_hello()->next_protocol_notification(),
+ state.client_hello()->next_protocol_notification(),
m_possible_protocols,
- m_state->client_hello()->supports_heartbeats(),
+ state.client_hello()->supports_heartbeats(),
m_rng)
);
- m_secure_renegotiation.update(m_state->server_hello());
+ m_secure_renegotiation.update(state.server_hello());
if(session_info.fragment_size())
set_maximum_fragment_size(session_info.fragment_size());
- m_state->compute_session_keys(session_info.master_secret());
+ state.compute_session_keys(session_info.master_secret());
if(!m_handshake_fn(session_info))
{
m_session_manager.remove_entry(session_info.session_id());
- if(m_state->server_hello()->supports_session_ticket()) // send an empty ticket
+ if(state.server_hello()->supports_session_ticket()) // send an empty ticket
{
- m_state->new_session_ticket(
- new New_Session_Ticket(m_state->handshake_io(),
- m_state->hash())
+ state.new_session_ticket(
+ new New_Session_Ticket(state.handshake_io(),
+ state.hash())
);
}
}
- if(m_state->server_hello()->supports_session_ticket() && !m_state->new_session_ticket())
+ if(state.server_hello()->supports_session_ticket() && !state.new_session_ticket())
{
try
{
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
- m_state->new_session_ticket(
- new New_Session_Ticket(m_state->handshake_io(),
- m_state->hash(),
+ state.new_session_ticket(
+ new New_Session_Ticket(state.handshake_io(),
+ state.hash(),
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime())
);
}
catch(...) {}
- if(!m_state->new_session_ticket())
+ if(!state.new_session_ticket())
{
- m_state->new_session_ticket(
- new New_Session_Ticket(m_state->handshake_io(), m_state->hash())
+ state.new_session_ticket(
+ new New_Session_Ticket(state.handshake_io(), state.hash())
);
}
}
- m_state->handshake_io().send(Change_Cipher_Spec());
+ state.handshake_io().send(Change_Cipher_Spec());
change_cipher_spec_writer(SERVER);
- m_state->server_finished(
- new Finished(m_state->handshake_io(), m_state.get(), SERVER)
+ state.server_finished(
+ new Finished(state.handshake_io(), state, SERVER)
);
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
}
else // new session
{
@@ -472,44 +469,44 @@ void Server::process_handshake_msg(Handshake_Type type,
send_alert(Alert(Alert::UNRECOGNIZED_NAME));
}
- m_state->server_hello(
+ state.server_hello(
new Server_Hello(
- m_state->handshake_io(),
- m_state->hash(),
+ state.handshake_io(),
+ state.hash(),
make_hello_random(m_rng), // new session ID
- m_state->version(),
+ state.version(),
choose_ciphersuite(m_policy,
- m_state->version(),
+ state.version(),
m_creds,
cert_chains,
- m_state->client_hello()),
- choose_compression(m_policy, m_state->client_hello()->compression_methods()),
- m_state->client_hello()->fragment_size(),
+ state.client_hello()),
+ choose_compression(m_policy, state.client_hello()->compression_methods()),
+ state.client_hello()->fragment_size(),
m_secure_renegotiation.supported(),
m_secure_renegotiation.for_server_hello(),
- m_state->client_hello()->supports_session_ticket() && have_session_ticket_key,
- m_state->client_hello()->next_protocol_notification(),
+ state.client_hello()->supports_session_ticket() && have_session_ticket_key,
+ state.client_hello()->next_protocol_notification(),
m_possible_protocols,
- m_state->client_hello()->supports_heartbeats(),
+ state.client_hello()->supports_heartbeats(),
m_rng)
);
- m_secure_renegotiation.update(m_state->server_hello());
+ m_secure_renegotiation.update(state.server_hello());
- if(m_state->client_hello()->fragment_size())
- set_maximum_fragment_size(m_state->client_hello()->fragment_size());
+ if(state.client_hello()->fragment_size())
+ set_maximum_fragment_size(state.client_hello()->fragment_size());
- const std::string sig_algo = m_state->ciphersuite().sig_algo();
- const std::string kex_algo = m_state->ciphersuite().kex_algo();
+ const std::string sig_algo = state.ciphersuite().sig_algo();
+ const std::string kex_algo = state.ciphersuite().kex_algo();
if(sig_algo != "")
{
BOTAN_ASSERT(!cert_chains[sig_algo].empty(),
"Attempting to send empty certificate chain");
- m_state->server_certs(
- new Certificate(m_state->handshake_io(),
- m_state->hash(),
+ state.server_certs(
+ new Certificate(state.handshake_io(),
+ state.hash(),
cert_chains[sig_algo])
);
}
@@ -519,7 +516,7 @@ void Server::process_handshake_msg(Handshake_Type type,
if(kex_algo == "RSA" || sig_algo != "")
{
private_key = m_creds.private_key_for(
- m_state->server_certs()->cert_chain()[0],
+ state.server_certs()->cert_chain()[0],
"tls-server",
m_hostname);
@@ -529,13 +526,13 @@ void Server::process_handshake_msg(Handshake_Type type,
if(kex_algo == "RSA")
{
- dynamic_cast<Server_Handshake_State&>(*m_state).server_rsa_kex_key = private_key;
+ dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key = private_key;
}
else
{
- m_state->server_kex(
- new Server_Key_Exchange(m_state->handshake_io(),
- m_state.get(),
+ state.server_kex(
+ new Server_Key_Exchange(state.handshake_io(),
+ state,
m_policy,
m_creds,
m_rng,
@@ -546,17 +543,17 @@ void Server::process_handshake_msg(Handshake_Type type,
std::vector<X509_Certificate> client_auth_CAs =
m_creds.trusted_certificate_authorities("tls-server", m_hostname);
- if(!client_auth_CAs.empty() && m_state->ciphersuite().sig_algo() != "")
+ if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "")
{
- m_state->cert_req(
- new Certificate_Req(m_state->handshake_io(),
- m_state->hash(),
+ state.cert_req(
+ new Certificate_Req(state.handshake_io(),
+ state.hash(),
m_policy,
client_auth_CAs,
- m_state->version())
+ state.version())
);
- m_state->set_expected_next(CERTIFICATE);
+ state.set_expected_next(CERTIFICATE);
}
/*
@@ -564,44 +561,44 @@ void Server::process_handshake_msg(Handshake_Type type,
* allowed to send either an empty cert message or proceed
* directly to the client key exchange, so allow either case.
*/
- m_state->set_expected_next(CLIENT_KEX);
+ state.set_expected_next(CLIENT_KEX);
- m_state->server_hello_done(
- new Server_Hello_Done(m_state->handshake_io(), m_state->hash())
+ state.server_hello_done(
+ new Server_Hello_Done(state.handshake_io(), state.hash())
);
}
}
else if(type == CERTIFICATE)
{
- m_state->client_certs(new Certificate(contents));
+ state.client_certs(new Certificate(contents));
- m_state->set_expected_next(CLIENT_KEX);
+ state.set_expected_next(CLIENT_KEX);
}
else if(type == CLIENT_KEX)
{
- if(m_state->received_handshake_msg(CERTIFICATE) && !m_state->client_certs()->empty())
- m_state->set_expected_next(CERTIFICATE_VERIFY);
+ if(state.received_handshake_msg(CERTIFICATE) && !state.client_certs()->empty())
+ state.set_expected_next(CERTIFICATE_VERIFY);
else
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
- m_state->client_kex(
- new Client_Key_Exchange(contents, m_state.get(),
- dynamic_cast<Server_Handshake_State&>(*m_state).server_rsa_kex_key,
+ state.client_kex(
+ new Client_Key_Exchange(contents, state,
+ dynamic_cast<Server_Handshake_State&>(state).server_rsa_kex_key,
m_creds, m_policy, m_rng)
);
- m_state->compute_session_keys();
+ state.compute_session_keys();
}
else if(type == CERTIFICATE_VERIFY)
{
- m_state->client_verify(new Certificate_Verify(contents, m_state->version()));
+ state.client_verify(new Certificate_Verify(contents, state.version()));
- m_peer_certs = m_state->client_certs()->cert_chain();
+ m_peer_certs = state.client_certs()->cert_chain();
const bool sig_valid =
- m_state->client_verify()->verify(m_peer_certs[0], m_state.get());
+ state.client_verify()->verify(m_peer_certs[0], state);
- m_state->hash().update(m_state->handshake_io().format(contents, type));
+ state.hash().update(state.handshake_io().format(contents, type));
/*
* Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for
@@ -620,68 +617,68 @@ void Server::process_handshake_msg(Handshake_Type type,
throw TLS_Exception(Alert::BAD_CERTIFICATE, e.what());
}
- m_state->set_expected_next(HANDSHAKE_CCS);
+ state.set_expected_next(HANDSHAKE_CCS);
}
else if(type == HANDSHAKE_CCS)
{
- if(m_state->server_hello()->next_protocol_notification())
- m_state->set_expected_next(NEXT_PROTOCOL);
+ if(state.server_hello()->next_protocol_notification())
+ state.set_expected_next(NEXT_PROTOCOL);
else
- m_state->set_expected_next(FINISHED);
+ state.set_expected_next(FINISHED);
change_cipher_spec_reader(SERVER);
}
else if(type == NEXT_PROTOCOL)
{
- m_state->set_expected_next(FINISHED);
+ state.set_expected_next(FINISHED);
- m_state->next_protocol(new Next_Protocol(contents));
+ state.next_protocol(new Next_Protocol(contents));
// should this be a callback?
- m_next_protocol = m_state->next_protocol()->protocol();
+ m_next_protocol = state.next_protocol()->protocol();
}
else if(type == FINISHED)
{
- m_state->set_expected_next(HANDSHAKE_NONE);
+ state.set_expected_next(HANDSHAKE_NONE);
- m_state->client_finished(new Finished(contents));
+ state.client_finished(new Finished(contents));
- if(!m_state->client_finished()->verify(m_state.get(), CLIENT))
+ if(!state.client_finished()->verify(state, CLIENT))
throw TLS_Exception(Alert::DECRYPT_ERROR,
"Finished message didn't verify");
- if(!m_state->server_finished())
+ if(!state.server_finished())
{
// already sent finished if resuming, so this is a new session
- m_state->hash().update(m_state->handshake_io().format(contents, type));
+ state.hash().update(state.handshake_io().format(contents, type));
Session session_info(
- m_state->server_hello()->session_id(),
- m_state->session_keys().master_secret(),
- m_state->server_hello()->version(),
- m_state->server_hello()->ciphersuite(),
- m_state->server_hello()->compression_method(),
+ state.server_hello()->session_id(),
+ state.session_keys().master_secret(),
+ state.server_hello()->version(),
+ state.server_hello()->ciphersuite(),
+ state.server_hello()->compression_method(),
SERVER,
m_secure_renegotiation.supported(),
- m_state->server_hello()->fragment_size(),
+ state.server_hello()->fragment_size(),
m_peer_certs,
std::vector<byte>(),
m_hostname,
- m_state->srp_identifier()
+ state.srp_identifier()
);
if(m_handshake_fn(session_info))
{
- if(m_state->server_hello()->supports_session_ticket())
+ if(state.server_hello()->supports_session_ticket())
{
try
{
const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", "");
- m_state->new_session_ticket(
- new New_Session_Ticket(m_state->handshake_io(),
- m_state->hash(),
+ state.new_session_ticket(
+ new New_Session_Ticket(state.handshake_io(),
+ state.hash(),
session_info.encrypt(ticket_key, m_rng),
m_policy.session_ticket_lifetime())
);
@@ -692,24 +689,24 @@ void Server::process_handshake_msg(Handshake_Type type,
m_session_manager.save(session_info);
}
- if(!m_state->new_session_ticket() &&
- m_state->server_hello()->supports_session_ticket())
+ if(!state.new_session_ticket() &&
+ state.server_hello()->supports_session_ticket())
{
- m_state->new_session_ticket(
- new New_Session_Ticket(m_state->handshake_io(), m_state->hash())
+ state.new_session_ticket(
+ new New_Session_Ticket(state.handshake_io(), state.hash())
);
}
- m_state->handshake_io().send(Change_Cipher_Spec());
+ state.handshake_io().send(Change_Cipher_Spec());
change_cipher_spec_writer(SERVER);
- m_state->server_finished(
- new Finished(m_state->handshake_io(), m_state.get(), SERVER)
+ state.server_finished(
+ new Finished(state.handshake_io(), state, SERVER)
);
}
- activate_session(m_state->server_hello()->session_id());
+ activate_session(state.server_hello()->session_id());
}
else
throw Unexpected_Message("Unknown handshake message received");
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 1add1ff40..e70989ad0 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -50,7 +50,9 @@ class BOTAN_DLL Server : public Channel
{ return m_next_protocol; }
private:
- void process_handshake_msg(Handshake_Type, const std::vector<byte>&) override;
+ void process_handshake_msg(Handshake_State& state,
+ Handshake_Type type,
+ const std::vector<byte>& contents) override;
class Handshake_State* new_handshake_state() override;