From c72b3f5afbebd8615884228f938c7cb270f5669e Mon Sep 17 00:00:00 2001 From: lloyd Date: Tue, 27 Dec 2011 16:17:31 +0000 Subject: Much smarter state transition checking: at each point in the handshake, keep track of exactly which handshake message type(s) we can expect and assert before processing that what we recieved is what we expected. Contrast with previous 'checking' which was more in the style 'could we perhaps plausibly do something with this message?' aka broken. --- src/tls/tls_client.cpp | 120 +++++++++++++++---------------------------------- src/tls/tls_magic.h | 6 +-- src/tls/tls_server.cpp | 85 +++++++++-------------------------- src/tls/tls_state.cpp | 103 ++++++++++++++++++++++++++++++++++++++++-- src/tls/tls_state.h | 15 +++++-- 5 files changed, 170 insertions(+), 159 deletions(-) diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index ee9c397c1..21c97751c 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -13,69 +13,6 @@ namespace Botan { -namespace { - -/* -* Verify the state transition is allowed -* FIXME: checks are wrong for session reuse (add a flag for that) -*/ -void client_check_state(Handshake_Type new_msg, Handshake_State* state) - { - class State_Transition_Error : public Unexpected_Message - { - public: - State_Transition_Error(const std::string& err) : - Unexpected_Message("State transition error from " + err) {} - }; - - if(new_msg == HELLO_REQUEST) - { - if(state->client_hello) - throw State_Transition_Error("HelloRequest"); - } - else if(new_msg == SERVER_HELLO) - { - if(!state->client_hello || state->server_hello) - throw State_Transition_Error("ServerHello"); - } - else if(new_msg == CERTIFICATE) - { - if(!state->server_hello || state->server_kex || - state->cert_req || state->server_hello_done) - throw State_Transition_Error("ServerCertificate"); - } - else if(new_msg == SERVER_KEX) - { - if(!state->server_hello || state->server_kex || - state->cert_req || state->server_hello_done) - throw State_Transition_Error("ServerKeyExchange"); - } - else if(new_msg == CERTIFICATE_REQUEST) - { - if(!state->server_certs || state->cert_req || state->server_hello_done) - throw State_Transition_Error("CertificateRequest"); - } - else if(new_msg == SERVER_HELLO_DONE) - { - if(!state->server_hello || state->server_hello_done) - throw State_Transition_Error("ServerHelloDone"); - } - else if(new_msg == HANDSHAKE_CCS) - { - if(!state->client_finished || state->server_finished) - throw State_Transition_Error("ServerChangeCipherSpec"); - } - else if(new_msg == FINISHED) - { - if(!state->got_server_ccs) - throw State_Transition_Error("ServerFinished"); - } - else - throw Unexpected_Message("Unexpected message in handshake"); - } - -} - /* * TLS Client Constructor */ @@ -90,6 +27,7 @@ TLS_Client::TLS_Client(std::tr1::function output_fn writer.set_version(policy.pref_version()); state = new Handshake_State; + state->set_expected_next(SERVER_HELLO); state->client_hello = new Client_Hello(rng, writer, policy, state->hash); } @@ -121,12 +59,14 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(state == 0) state = new Handshake_State(); else - return; + return; // hello request in middle of handshake? } if(state == 0) throw Unexpected_Message("Unexpected handshake message"); + state->confirm_transition_to(type); + if(type != HANDSHAKE_CCS && type != HELLO_REQUEST && type != FINISHED) { state->hash.update(static_cast(type)); @@ -138,15 +78,11 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, if(type == HELLO_REQUEST) { - client_check_state(type, state); - Hello_Request hello_request(contents); state->client_hello = new Client_Hello(rng, writer, policy, state->hash); } else if(type == SERVER_HELLO) { - client_check_state(type, state); - state->server_hello = new Server_Hello(contents); if(!state->client_hello->offered_suite( @@ -170,13 +106,32 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, reader.set_version(state->version); state->suite = CipherSuite(state->server_hello->ciphersuite()); + + if(state->suite.sig_type() != TLS_ALGO_SIGNER_ANON) + { + state->set_expected_next(CERTIFICATE); + } + else if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + { + state->set_expected_next(SERVER_KEX); + } + else + { + state->set_expected_next(CERTIFICATE_REQUEST); // optional + state->set_expected_next(SERVER_HELLO_DONE); + } } else if(type == CERTIFICATE) { - client_check_state(type, state); - - if(state->suite.sig_type() == TLS_ALGO_SIGNER_ANON) - throw Unexpected_Message("Recived certificate from anonymous server"); + if(state->suite.kex_type() != TLS_ALGO_KEYEXCH_NOKEX) + { + state->set_expected_next(SERVER_KEX); + } + else + { + state->set_expected_next(CERTIFICATE_REQUEST); // optional + state->set_expected_next(SERVER_HELLO_DONE); + } state->server_certs = new Certificate(contents); @@ -208,10 +163,8 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } else if(type == SERVER_KEX) { - client_check_state(type, state); - - if(state->suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX) - throw Unexpected_Message("Unexpected key exchange from server"); + state->set_expected_next(CERTIFICATE_REQUEST); // optional + state->set_expected_next(SERVER_HELLO_DONE); state->server_kex = new Server_Key_Exchange(contents); @@ -246,18 +199,16 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } else if(type == CERTIFICATE_REQUEST) { - client_check_state(type, state); - + state->set_expected_next(SERVER_HELLO_DONE); state->cert_req = new Certificate_Req(contents); - state->do_client_auth = true; } else if(type == SERVER_HELLO_DONE) { - client_check_state(type, state); + state->set_expected_next(HANDSHAKE_CCS); state->server_hello_done = new Server_Hello_Done(contents); - if(state->do_client_auth) + if(state->received_handshake_msg(CERTIFICATE_REQUEST)) { std::vector send_certs; @@ -274,7 +225,7 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, state->kex_pub, state->version, state->client_hello->version()); - if(state->do_client_auth) + if(state->received_handshake_msg(CERTIFICATE_REQUEST)) { Private_Key* key_matching_cert = 0; // FIXME state->client_verify = new Certificate_Verify(rng, @@ -298,14 +249,13 @@ void TLS_Client::process_handshake_msg(Handshake_Type type, } else if(type == HANDSHAKE_CCS) { - client_check_state(type, state); + state->set_expected_next(FINISHED); reader.set_keys(state->suite, state->keys, CLIENT); - state->got_server_ccs = true; } else if(type == FINISHED) { - client_check_state(type, state); + state->set_expected_next(HELLO_REQUEST); state->server_finished = new Finished(contents); diff --git a/src/tls/tls_magic.h b/src/tls/tls_magic.h index 4dd9b2bb4..7913b576c 100644 --- a/src/tls/tls_magic.h +++ b/src/tls/tls_magic.h @@ -40,7 +40,7 @@ enum Record_Type { enum Handshake_Type { HELLO_REQUEST = 0, CLIENT_HELLO = 1, - CLIENT_HELLO_SSLV2 = 255, // not a wire value + CLIENT_HELLO_SSLV2 = 200, // Not a wire value SERVER_HELLO = 2, CERTIFICATE = 11, SERVER_KEX = 12, @@ -50,8 +50,8 @@ enum Handshake_Type { CLIENT_KEX = 16, FINISHED = 20, - HANDSHAKE_CCS = 100, - HANDSHAKE_NONE = 101 + HANDSHAKE_CCS = 100, // Not a wire value + HANDSHAKE_NONE = 255 // Null value }; enum Alert_Level { diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index e2f994224..141ff6cba 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -30,56 +30,6 @@ Version_Code choose_version(Version_Code client, Version_Code minimum) return TLS_V11; } -// FIXME: checks are wrong for session reuse (add a flag for that) -/* -* Verify the state transition is allowed -*/ -void server_check_state(Handshake_Type new_msg, Handshake_State* state) - { - class State_Transition_Error : public Unexpected_Message - { - public: - State_Transition_Error(const std::string& err) : - Unexpected_Message("State transition error from " + err) {} - }; - - if(new_msg == CLIENT_HELLO || new_msg == CLIENT_HELLO_SSLV2) - { - if(state->server_hello) - throw State_Transition_Error("ClientHello"); - } - else if(new_msg == CERTIFICATE) - { - if(!state->do_client_auth || !state->cert_req || - !state->server_hello_done || state->client_kex) - throw State_Transition_Error("ClientCertificate"); - } - else if(new_msg == CLIENT_KEX) - { - if(!state->server_hello_done || state->client_verify || - state->got_client_ccs) - throw State_Transition_Error("ClientKeyExchange"); - } - else if(new_msg == CERTIFICATE_VERIFY) - { - if(!state->cert_req || !state->client_certs || !state->client_kex || - state->got_client_ccs) - throw State_Transition_Error("CertificateVerify"); - } - else if(new_msg == HANDSHAKE_CCS) - { - if(!state->client_kex || state->client_finished) - throw State_Transition_Error("ClientChangeCipherSpec"); - } - else if(new_msg == FINISHED) - { - if(!state->got_client_ccs) - throw State_Transition_Error("ClientFinished"); - } - else - throw Unexpected_Message("Unexpected message in handshake"); - } - } /* @@ -118,7 +68,10 @@ void TLS_Server::read_handshake(byte rec_type, const MemoryRegion& rec_buf) { if(rec_type == HANDSHAKE && !state) + { state = new Handshake_State; + state->set_expected_next(CLIENT_HELLO); + } TLS_Channel::read_handshake(rec_type, rec_buf); } @@ -134,6 +87,8 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, if(state == 0) throw Unexpected_Message("Unexpected handshake message"); + state->confirm_transition_to(type); + if(type != HANDSHAKE_CCS && type != FINISHED) { if(type != CLIENT_HELLO_SSLV2) @@ -150,8 +105,6 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, if(type == CLIENT_HELLO || type == CLIENT_HELLO_SSLV2) { - server_check_state(type, state); - state->client_hello = new Client_Hello(contents, type); client_requested_hostname = state->client_hello->hostname(); @@ -169,13 +122,13 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, if(found && params.connection_side == SERVER) { + // resume session - - - + state->set_expected_next(HANDSHAKE_CCS); } - else // new session + else { + // new session MemoryVector sess_id = rng.random_vec(32); state->server_hello = new Server_Hello(rng, writer, @@ -218,22 +171,27 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, if(policy.require_client_auth()) { - state->do_client_auth = true; throw Internal_Error("Client auth not implemented"); // FIXME: send client auth request here + state->set_expected_next(CERTIFICATE); } + else + state->set_expected_next(CLIENT_KEX); } state->server_hello_done = new Server_Hello_Done(writer, state->hash); } else if(type == CERTIFICATE) { - server_check_state(type, state); + state->set_expected_next(CLIENT_KEX); // FIXME: process this } else if(type == CLIENT_KEX) { - server_check_state(type, state); + if(state->received_handshake_msg(CERTIFICATE)) + state->set_expected_next(CERTIFICATE_VERIFY); + else + state->set_expected_next(HANDSHAKE_CCS); state->client_kex = new Client_Key_Exchange(contents, state->suite, state->version); @@ -245,22 +203,23 @@ void TLS_Server::process_handshake_msg(Handshake_Type type, state->keys = SessionKeys(state->suite, state->version, pre_master, state->client_hello->random(), state->server_hello->random()); + } else if(type == CERTIFICATE_VERIFY) { - server_check_state(type, state); // FIXME: process this + + state->set_expected_next(HANDSHAKE_CCS); } else if(type == HANDSHAKE_CCS) { - server_check_state(type, state); + state->set_expected_next(FINISHED); reader.set_keys(state->suite, state->keys, SERVER); - state->got_client_ccs = true; } else if(type == FINISHED) { - server_check_state(type, state); + state->set_expected_next(HANDSHAKE_NONE); state->client_finished = new Finished(contents); diff --git a/src/tls/tls_state.cpp b/src/tls/tls_state.cpp index 6aaf5e201..61f087dec 100644 --- a/src/tls/tls_state.cpp +++ b/src/tls/tls_state.cpp @@ -1,15 +1,73 @@ /* * TLS Handshaking -* (C) 2004-2006 Jack Lloyd +* (C) 2004-2006,2011 Jack Lloyd * * Released under the terms of the Botan license */ #include +#include + namespace Botan { -/** +namespace { + +u32bit bitmask_for_handshake_type(Handshake_Type type) + { + switch(type) + { + case HELLO_REQUEST: + return (1 << 0); + + /* + * Same code point for both client hello styles + */ + case CLIENT_HELLO: + case CLIENT_HELLO_SSLV2: + return (1 << 1); + + case SERVER_HELLO: + return (1 << 2); + + case CERTIFICATE: + return (1 << 3); + + case SERVER_KEX: + return (1 << 4); + + case CERTIFICATE_REQUEST: + return (1 << 5); + + case SERVER_HELLO_DONE: + return (1 << 6); + + case CERTIFICATE_VERIFY: + return (1 << 7); + + case CLIENT_KEX: + return (1 << 8); + + case FINISHED: + return (1 << 9); + + case HANDSHAKE_CCS: + return (1 << 10); + + // allow explicitly disabling new handshakes + case HANDSHAKE_NONE: + return 0; + + default: + throw Internal_Error("Unknown handshake type " + to_string(type)); + } + + return 0; + } + +} + +/* * Initialize the SSL/TLS Handshake State */ Handshake_State::Handshake_State() @@ -30,11 +88,48 @@ Handshake_State::Handshake_State() kex_pub = 0; kex_priv = 0; - do_client_auth = got_client_ccs = got_server_ccs = false; + //do_client_auth = got_client_ccs = got_server_ccs = false; version = SSL_V3; + + hand_expecting_mask = 0; + hand_received_mask = 0; } -/** +void Handshake_State::confirm_transition_to(Handshake_Type handshake_msg) + { + const u32bit mask = bitmask_for_handshake_type(handshake_msg); + + hand_received_mask |= mask; + + const bool ok = (hand_expecting_mask & mask); // overlap? + + if(!ok) + printf("Bad handshake transition, got %d expected %08X\n", + handshake_msg, hand_expecting_mask); + + /* We don't know what to expect next, so force a call to + set_expected_next; if it doesn't happen, the next transition + check will always fail which is what we want. + */ + hand_expecting_mask = 0; + + if(!ok) + throw Unexpected_Message("Unexpected state transition in handshake"); + } + +void Handshake_State::set_expected_next(Handshake_Type handshake_msg) + { + hand_expecting_mask |= bitmask_for_handshake_type(handshake_msg); + } + +bool Handshake_State::received_handshake_msg(Handshake_Type handshake_msg) const + { + const u32bit mask = bitmask_for_handshake_type(handshake_msg); + + return (hand_received_mask & mask); + } + +/* * Destroy the SSL/TLS Handshake State */ Handshake_State::~Handshake_State() diff --git a/src/tls/tls_state.h b/src/tls/tls_state.h index e2728198f..523dfed9c 100644 --- a/src/tls/tls_state.h +++ b/src/tls/tls_state.h @@ -19,6 +19,14 @@ namespace Botan { class Handshake_State { public: + Handshake_State(); + ~Handshake_State(); + + bool received_handshake_msg(Handshake_Type handshake_msg) const; + + void confirm_transition_to(Handshake_Type handshake_msg); + void set_expected_next(Handshake_Type handshake_msg); + Client_Hello* client_hello; Server_Hello* server_hello; Certificate* server_certs; @@ -42,10 +50,9 @@ class Handshake_State SecureQueue queue; Version_Code version; - bool got_client_ccs, got_server_ccs, do_client_auth; - - Handshake_State(); - ~Handshake_State(); + //bool got_client_ccs, got_server_ccs, do_client_auth; + private: + u32bit hand_expecting_mask, hand_received_mask; }; } -- cgit v1.2.3