diff options
-rw-r--r-- | src/tls/c_kex.cpp | 50 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 7 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 4 |
3 files changed, 34 insertions, 27 deletions
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index f5afcc100..0cfe20a40 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -185,8 +185,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer, * Read a Client Key Exchange message */ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, - const Handshake_State* state, - Credentials_Manager& creds) + const Handshake_State* state) { const std::string kex_algo = state->suite.kex_algo(); @@ -201,20 +200,7 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, else if(kex_algo == "ECDH") key_material = reader.get_range<byte>(1, 1, 255); else if(kex_algo == "PSK") - { - const std::string psk_identity = reader.get_string(2, 0, 65535); - - SymmetricKey psk = creds.psk("tls-server", - state->client_hello->sni_hostname(), - psk_identity); - - MemoryVector<byte> zeros(psk.length()); - - append_tls_length_value(key_material, zeros, 2); - append_tls_length_value(key_material, psk.bits_of(), 2); - - pre_master = key_material; - } + key_material = reader.get_range<byte>(2, 0, 65535); else throw Internal_Error("Client_Key_Exchange received unknown kex type " + kex_algo); } @@ -225,13 +211,35 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, */ SecureVector<byte> Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, - const Handshake_State* state) + const Handshake_State* state, + Credentials_Manager& creds, + const Policy& policy) { const std::string kex_algo = state->suite.kex_algo(); if(kex_algo == "PSK") { - return key_material; + const std::string psk_identity( + reinterpret_cast<const char*>(&key_material[0]), + key_material.size()); + + SymmetricKey psk = creds.psk("tls-server", + state->client_hello->sni_hostname(), + psk_identity); + + if(psk.length() == 0) + { + if(policy.hide_unknown_users()) + throw TLS_Exception(Alert::UNKNOWN_PSK_IDENTITY, + "No PKS for identifier " + psk_identity); + else + psk = SymmetricKey(rng, 16); + } + + MemoryVector<byte> zeros(psk.length()); + + append_tls_length_value(pre_master, zeros, 2); + append_tls_length_value(pre_master, psk.bits_of(), 2); } else if(kex_algo == "RSA") { @@ -267,8 +275,6 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, pre_master[0] = client_version.major_version(); pre_master[1] = client_version.minor_version(); } - - return pre_master; } else if(kex_algo == "DH" || kex_algo == "ECDH") { @@ -300,11 +306,11 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng, */ pre_master = rng.random_vec(ka_key->public_value().size()); } - - return pre_master; } else throw Internal_Error("Client_Key_Exchange: Unknown kex type " + kex_algo); + + return pre_master; } } diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index b29aac600..97d72a0cb 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -214,7 +214,9 @@ class Client_Key_Exchange : public Handshake_Message { return pre_master; } SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng, - const Handshake_State* state); + const Handshake_State* state, + Credentials_Manager& creds, + const Policy& policy); Client_Key_Exchange(Record_Writer& output, Handshake_State* state, @@ -223,8 +225,7 @@ class Client_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng); Client_Key_Exchange(const MemoryRegion<byte>& buf, - const Handshake_State* state, - Credentials_Manager& creds); + const Handshake_State* state); private: MemoryVector<byte> serialize() const { return key_material; } diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 9c4410938..645001367 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -358,10 +358,10 @@ void Server::process_handshake_msg(Handshake_Type type, else state->set_expected_next(HANDSHAKE_CCS); - state->client_kex = new Client_Key_Exchange(contents, state, creds); + state->client_kex = new Client_Key_Exchange(contents, state); SecureVector<byte> pre_master = - state->client_kex->pre_master_secret(rng, state); + state->client_kex->pre_master_secret(rng, state, creds, policy); state->keys = Session_Keys(state, pre_master, false); } |