aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-01-27 17:48:11 +0000
committerlloyd <[email protected]>2012-01-27 17:48:11 +0000
commit8d0bbed5d5ccd0995d4794644172b6508959798e (patch)
tree53376ced602c4cfbafbe7683d49d2483b2a48274
parent133bda471c547842044bd66a44bfe64668e966da (diff)
Somewhat cleaner PSK handling
-rw-r--r--src/tls/c_kex.cpp50
-rw-r--r--src/tls/tls_messages.h7
-rw-r--r--src/tls/tls_server.cpp4
3 files changed, 34 insertions, 27 deletions
diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp
index f5afcc100..0cfe20a40 100644
--- a/src/tls/c_kex.cpp
+++ b/src/tls/c_kex.cpp
@@ -185,8 +185,7 @@ Client_Key_Exchange::Client_Key_Exchange(Record_Writer& writer,
* Read a Client Key Exchange message
*/
Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
- const Handshake_State* state,
- Credentials_Manager& creds)
+ const Handshake_State* state)
{
const std::string kex_algo = state->suite.kex_algo();
@@ -201,20 +200,7 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
else if(kex_algo == "ECDH")
key_material = reader.get_range<byte>(1, 1, 255);
else if(kex_algo == "PSK")
- {
- const std::string psk_identity = reader.get_string(2, 0, 65535);
-
- SymmetricKey psk = creds.psk("tls-server",
- state->client_hello->sni_hostname(),
- psk_identity);
-
- MemoryVector<byte> zeros(psk.length());
-
- append_tls_length_value(key_material, zeros, 2);
- append_tls_length_value(key_material, psk.bits_of(), 2);
-
- pre_master = key_material;
- }
+ key_material = reader.get_range<byte>(2, 0, 65535);
else
throw Internal_Error("Client_Key_Exchange received unknown kex type " + kex_algo);
}
@@ -225,13 +211,35 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents,
*/
SecureVector<byte>
Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
- const Handshake_State* state)
+ const Handshake_State* state,
+ Credentials_Manager& creds,
+ const Policy& policy)
{
const std::string kex_algo = state->suite.kex_algo();
if(kex_algo == "PSK")
{
- return key_material;
+ const std::string psk_identity(
+ reinterpret_cast<const char*>(&key_material[0]),
+ key_material.size());
+
+ SymmetricKey psk = creds.psk("tls-server",
+ state->client_hello->sni_hostname(),
+ psk_identity);
+
+ if(psk.length() == 0)
+ {
+ if(policy.hide_unknown_users())
+ throw TLS_Exception(Alert::UNKNOWN_PSK_IDENTITY,
+ "No PKS for identifier " + psk_identity);
+ else
+ psk = SymmetricKey(rng, 16);
+ }
+
+ MemoryVector<byte> zeros(psk.length());
+
+ append_tls_length_value(pre_master, zeros, 2);
+ append_tls_length_value(pre_master, psk.bits_of(), 2);
}
else if(kex_algo == "RSA")
{
@@ -267,8 +275,6 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
pre_master[0] = client_version.major_version();
pre_master[1] = client_version.minor_version();
}
-
- return pre_master;
}
else if(kex_algo == "DH" || kex_algo == "ECDH")
{
@@ -300,11 +306,11 @@ Client_Key_Exchange::pre_master_secret(RandomNumberGenerator& rng,
*/
pre_master = rng.random_vec(ka_key->public_value().size());
}
-
- return pre_master;
}
else
throw Internal_Error("Client_Key_Exchange: Unknown kex type " + kex_algo);
+
+ return pre_master;
}
}
diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h
index b29aac600..97d72a0cb 100644
--- a/src/tls/tls_messages.h
+++ b/src/tls/tls_messages.h
@@ -214,7 +214,9 @@ class Client_Key_Exchange : public Handshake_Message
{ return pre_master; }
SecureVector<byte> pre_master_secret(RandomNumberGenerator& rng,
- const Handshake_State* state);
+ const Handshake_State* state,
+ Credentials_Manager& creds,
+ const Policy& policy);
Client_Key_Exchange(Record_Writer& output,
Handshake_State* state,
@@ -223,8 +225,7 @@ class Client_Key_Exchange : public Handshake_Message
RandomNumberGenerator& rng);
Client_Key_Exchange(const MemoryRegion<byte>& buf,
- const Handshake_State* state,
- Credentials_Manager& creds);
+ const Handshake_State* state);
private:
MemoryVector<byte> serialize() const { return key_material; }
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 9c4410938..645001367 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -358,10 +358,10 @@ void Server::process_handshake_msg(Handshake_Type type,
else
state->set_expected_next(HANDSHAKE_CCS);
- state->client_kex = new Client_Key_Exchange(contents, state, creds);
+ state->client_kex = new Client_Key_Exchange(contents, state);
SecureVector<byte> pre_master =
- state->client_kex->pre_master_secret(rng, state);
+ state->client_kex->pre_master_secret(rng, state, creds, policy);
state->keys = Session_Keys(state, pre_master, false);
}