aboutsummaryrefslogtreecommitdiffstats
path: root/src/tls/tls_client.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls/tls_client.cpp')
-rw-r--r--src/tls/tls_client.cpp120
1 files changed, 35 insertions, 85 deletions
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index ee9c397c1..21c97751c 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -13,69 +13,6 @@
namespace Botan {
-namespace {
-
-/*
-* Verify the state transition is allowed
-* FIXME: checks are wrong for session reuse (add a flag for that)
-*/
-void client_check_state(Handshake_Type new_msg, Handshake_State* state)
- {
- class State_Transition_Error : public Unexpected_Message
- {
- public:
- State_Transition_Error(const std::string& err) :
- Unexpected_Message("State transition error from " + err) {}
- };
-
- if(new_msg == HELLO_REQUEST)
- {
- if(state->client_hello)
- throw State_Transition_Error("HelloRequest");
- }
- else if(new_msg == SERVER_HELLO)
- {
- if(!state->client_hello || state->server_hello)
- throw State_Transition_Error("ServerHello");
- }
- else if(new_msg == CERTIFICATE)
- {
- if(!state->server_hello || state->server_kex ||
- state->cert_req || state->server_hello_done)
- throw State_Transition_Error("ServerCertificate");
- }
- else if(new_msg == SERVER_KEX)
- {
- if(!state->server_hello || state->server_kex ||
- state->cert_req || state->server_hello_done)
- throw State_Transition_Error("ServerKeyExchange");
- }
- else if(new_msg == CERTIFICATE_REQUEST)
- {
- if(!state->server_certs || state->cert_req || state->server_hello_done)
- throw State_Transition_Error("CertificateRequest");
- }
- else if(new_msg == SERVER_HELLO_DONE)
- {
- if(!state->server_hello || state->server_hello_done)
- throw State_Transition_Error("ServerHelloDone");
- }
- else if(new_msg == HANDSHAKE_CCS)
- {
- if(!state->client_finished || state->server_finished)
- throw State_Transition_Error("ServerChangeCipherSpec");
- }
- else if(new_msg == FINISHED)
- {
- if(!state->got_server_ccs)
- throw State_Transition_Error("ServerFinished");
- }
- else
- throw Unexpected_Message("Unexpected message in handshake");
- }
-
-}
-
/*
* TLS Client Constructor
*/
@@ -90,6 +27,7 @@ TLS_Client::TLS_Client(std::tr1::function<void (const byte[], size_t)> output_fn
writer.set_version(policy.pref_version());
state = new Handshake_State;
+ state->set_expected_next(SERVER_HELLO);
state->client_hello = new Client_Hello(rng, writer, policy, state->hash);
}
@@ -121,12 +59,14 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
if(state == 0)
state = new Handshake_State();
else
- return;
+ return; // hello request in middle of handshake?
}
if(state == 0)
throw Unexpected_Message("Unexpected handshake message");
+ state->confirm_transition_to(type);
+
if(type != HANDSHAKE_CCS && type != HELLO_REQUEST && type != FINISHED)
{
state->hash.update(static_cast<byte>(type));
@@ -138,15 +78,11 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
if(type == HELLO_REQUEST)
{
- client_check_state(type, state);
-
Hello_Request hello_request(contents);
state->client_hello = new Client_Hello(rng, writer, policy, state->hash);
}
else if(type == SERVER_HELLO)
{
- client_check_state(type, state);
-
state->server_hello = new Server_Hello(contents);
if(!state->client_hello->offered_suite(
@@ -170,13 +106,32 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
reader.set_version(state->version);
state->suite = CipherSuite(state->server_hello->ciphersuite());
+
+ if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON)
+ {
+ state->set_expected_next(CERTIFICATE);
+ }
+ else if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX)
+ {
+ state->set_expected_next(SERVER_KEX);
+ }
+ else
+ {
+ state->set_expected_next(CERTIFICATE_REQUEST); // optional
+ state->set_expected_next(SERVER_HELLO_DONE);
+ }
}
else if(type == CERTIFICATE)
{
- client_check_state(type, state);
-
- if(state->suite.sig_type() == TLS_ALGO_SIGNER_ANON)
- throw Unexpected_Message("Recived certificate from anonymous server");
+ if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX)
+ {
+ state->set_expected_next(SERVER_KEX);
+ }
+ else
+ {
+ state->set_expected_next(CERTIFICATE_REQUEST); // optional
+ state->set_expected_next(SERVER_HELLO_DONE);
+ }
state->server_certs = new Certificate(contents);
@@ -208,10 +163,8 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
}
else if(type == SERVER_KEX)
{
- client_check_state(type, state);
-
- if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX)
- throw Unexpected_Message("Unexpected key exchange from server");
+ state->set_expected_next(CERTIFICATE_REQUEST); // optional
+ state->set_expected_next(SERVER_HELLO_DONE);
state->server_kex = new Server_Key_Exchange(contents);
@@ -246,18 +199,16 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
}
else if(type == CERTIFICATE_REQUEST)
{
- client_check_state(type, state);
-
+ state->set_expected_next(SERVER_HELLO_DONE);
state->cert_req = new Certificate_Req(contents);
- state->do_client_auth = true;
}
else if(type == SERVER_HELLO_DONE)
{
- client_check_state(type, state);
+ state->set_expected_next(HANDSHAKE_CCS);
state->server_hello_done = new Server_Hello_Done(contents);
- if(state->do_client_auth)
+ if(state->received_handshake_msg(CERTIFICATE_REQUEST))
{
std::vector<X509_Certificate> send_certs;
@@ -274,7 +225,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
state->kex_pub, state->version,
state->client_hello->version());
- if(state->do_client_auth)
+ if(state->received_handshake_msg(CERTIFICATE_REQUEST))
{
Private_Key* key_matching_cert = 0; // FIXME
state->client_verify = new Certificate_Verify(rng,
@@ -298,14 +249,13 @@ void TLS_Client::process_handshake_msg(Handshake_Type type,
}
else if(type == HANDSHAKE_CCS)
{
- client_check_state(type, state);
+ state->set_expected_next(FINISHED);
reader.set_keys(state->suite, state->keys, CLIENT);
- state->got_server_ccs = true;
}
else if(type == FINISHED)
{
- client_check_state(type, state);
+ state->set_expected_next(HELLO_REQUEST);
state->server_finished = new Finished(contents);