aboutsummaryrefslogtreecommitdiffstats
path: root/src/tests/unit_tls.cpp
diff options
context:
space:
mode:
authorJack Lloyd <[email protected]>2015-10-25 22:25:40 -0400
committerJack Lloyd <[email protected]>2015-10-25 22:25:40 -0400
commitb2da74ca508745f00bb3d6b35cbe34d5031e27e7 (patch)
tree032fafd34f178af3b66877d52897f2e14359adaf /src/tests/unit_tls.cpp
parent2d078053b1ac7c1e2316892d8634c386288ee159 (diff)
TLS improvements
Use constant time operations when checking CBC padding in TLS decryption Fix a bug in decoding ClientHellos that prevented DTLS rehandshakes from working: on decode the session id and hello cookie would be swapped, causing confusion between client and server. Various changes in the service of finding the above DTLS bug that should have been done before now anyway - better control of handshake timeouts (via TLS::Policy), better reporting of handshake state in the case of an error, and finally expose the facility for per-message application callbacks.
Diffstat (limited to 'src/tests/unit_tls.cpp')
-rw-r--r--src/tests/unit_tls.cpp551
1 files changed, 431 insertions, 120 deletions
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp
index 116eb2cdf..d4abef119 100644
--- a/src/tests/unit_tls.cpp
+++ b/src/tests/unit_tls.cpp
@@ -11,6 +11,7 @@
#include <botan/tls_server.h>
#include <botan/tls_client.h>
+#include <botan/tls_handshake_msg.h>
#include <botan/pkcs10.h>
#include <botan/x509self.h>
#include <botan/rsa.h>
@@ -21,6 +22,7 @@
#include <iostream>
#include <vector>
#include <memory>
+#include <thread>
using namespace Botan;
@@ -155,158 +157,441 @@ Credentials_Manager* create_creds()
return new Credentials_Manager_Test(server_cert, ca_cert, server_key);
}
-size_t basic_test_handshake(RandomNumberGenerator& rng,
- TLS::Protocol_Version offer_version,
- Credentials_Manager& creds,
- TLS::Policy& policy)
+std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q)
{
- TLS::Session_Manager_In_Memory server_sessions(rng);
- TLS::Session_Manager_In_Memory client_sessions(rng);
-
- std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data;
+ return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); };
+ }
- auto handshake_complete = [&](const TLS::Session& session) -> bool
+void print_alert(TLS::Alert alert, const byte[], size_t)
{
- if(session.version() != offer_version)
- std::cout << "Offered " << offer_version.to_string()
- << " got " << session.version().to_string() << std::endl;
- return true;
+ //std::cout << "Alert " << alert.type_string() << std::endl;
};
- auto print_alert = [&](TLS::Alert alert, const byte[], size_t)
+void mutate(std::vector<byte>& v, RandomNumberGenerator& rng)
{
- if(alert.is_valid())
- std::cout << "Server recvd alert " << alert.type_string() << std::endl;
- };
+ if(v.empty())
+ return;
- auto save_server_data = [&](const byte buf[], size_t sz)
- {
- c2s_data.insert(c2s_data.end(), buf, buf+sz);
- };
+ size_t voff = rng.get_random<size_t>() % v.size();
+ v[voff] ^= rng.next_nonzero_byte();
+ }
- auto save_client_data = [&](const byte buf[], size_t sz)
+size_t test_tls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
{
- s2c_data.insert(s2c_data.end(), buf, buf+sz);
- };
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
- auto next_protocol_chooser = [&](std::vector<std::string> protos) {
- if(protos.size() != 2)
- std::cout << "Bad protocol size" << std::endl;
- if(protos[0] != "test/1" || protos[1] != "test/2")
- std::cout << "Bad protocol values" << std::endl;
- return "test/3";
- };
- const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
-
- TLS::Server server([&](const byte buf[], size_t sz)
- { s2c_q.insert(s2c_q.end(), buf, buf+sz); },
- save_server_data,
- print_alert,
- handshake_complete,
- server_sessions,
- creds,
- policy,
- rng,
- next_protocol_chooser,
- offer_version.is_datagram_protocol());
-
- TLS::Client client([&](const byte buf[], size_t sz)
- { c2s_q.insert(c2s_q.end(), buf, buf+sz); },
- save_client_data,
- print_alert,
- handshake_complete,
- client_sessions,
- creds,
- policy,
- rng,
- TLS::Server_Information("server.example.com"),
- offer_version,
- protocols_offered);
-
- while(true)
+ for(size_t r = 1; r <= 4; ++r)
{
- if(client.is_closed() && server.is_closed())
- break;
+ //std::cout << offer_version.to_string() << " r " << r << "\n";
- if(client.is_active())
- client.send("1");
- if(server.is_active())
- {
- if(server.next_protocol() != "test/3")
- std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
- server.send("2");
- }
+ bool handshake_done = false;
- /*
- * Use this as a temp value to hold the queues as otherwise they
- * might end up appending more in response to messages during the
- * handshake.
- */
- std::vector<byte> input;
- std::swap(c2s_q, input);
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- try
- {
- server.received_data(input.data(), input.size());
- }
- catch(std::exception& e)
- {
- std::cout << "Server error - " << e.what() << std::endl;
- return 1;
- }
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
- input.clear();
- std::swap(s2c_q, input);
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
try
{
- client.received_data(input.data(), input.size());
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ false);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
+ {
+ ++rounds;
+
+ if(rounds > 25)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send in several records
+ client.send(client_sent);
+ }
+
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server recv " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client recv " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
}
catch(std::exception& e)
{
- std::cout << "Client error - " << e.what() << std::endl;
+ std::cout << e.what() << "\n";
return 1;
}
+ }
- if(c2s_data.size())
- {
- if(c2s_data[0] != '1')
- {
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ return 0;
+ }
+
+size_t test_dtls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
+ {
+ BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version");
+
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
+
+ for(size_t r = 1; r <= 2; ++r)
+ {
+ //std::cout << offer_version.to_string() << " round " << r << "\n";
+
+ bool handshake_done = false;
+
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- if(s2c_data.size())
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
+
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
+
+ try
{
- if(s2c_data[0] != '2')
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ true);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
{
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ // TODO: client and server should be in different threads
+ std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2));
+ ++rounds;
- if(s2c_data.size() && c2s_data.size())
- {
- SymmetricKey client_key = client.key_material_export("label", "context", 32);
- SymmetricKey server_key = server.key_material_export("label", "context", 32);
+ if(rounds > 100)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511 and send random chunks:
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send multiple parts
+ //std::cout << "Sending " << client_sent.size() << " bytes to server\n";
+ client.send(client_sent);
+ }
- if(client_key != server_key)
- return 1;
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ //std::cout << "Sending " << server_sent.size() << " bytes to client\n";
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server got " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client got " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
- server.close();
- client.close();
+ // If we corrupted a DTLS application message, resend it:
+ if(client.is_active() && corrupt_client_data && server_recv.empty())
+ client.send(client_sent);
+ if(server.is_active() && corrupt_server_data && client_recv.empty())
+ server.send(server_sent);
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
+ }
+ catch(std::exception& e)
+ {
+ std::cout << e.what() << "\n";
+ return 1;
}
}
return 0;
}
-class Test_Policy : public TLS::Policy
+class Test_Policy : public TLS::Text_Policy
{
public:
+ Test_Policy() : Text_Policy("") {}
bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; }
bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; }
+
+ size_t dtls_initial_timeout() const override { return 1; }
+ size_t dtls_maximum_timeout() const override { return 8; }
};
}
@@ -315,17 +600,43 @@ size_t test_tls()
{
size_t errors = 0;
- Test_Policy default_policy;
auto& rng = test_rng();
std::unique_ptr<Credentials_Manager> basic_creds(create_creds());
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy);
-
- test_report("TLS", 5, errors);
+ Test_Policy policy;
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "RSA");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "DH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ policy.set("key_exchange_methods", "ECDH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "AES-128");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "ChaCha20Poly1305");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ test_report("TLS", 22, errors);
return errors;
}