aboutsummaryrefslogtreecommitdiffstats
path: root/src/tests/unit_tls.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tests/unit_tls.cpp')
-rw-r--r--src/tests/unit_tls.cpp551
1 files changed, 431 insertions, 120 deletions
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp
index 116eb2cdf..d4abef119 100644
--- a/src/tests/unit_tls.cpp
+++ b/src/tests/unit_tls.cpp
@@ -11,6 +11,7 @@
#include <botan/tls_server.h>
#include <botan/tls_client.h>
+#include <botan/tls_handshake_msg.h>
#include <botan/pkcs10.h>
#include <botan/x509self.h>
#include <botan/rsa.h>
@@ -21,6 +22,7 @@
#include <iostream>
#include <vector>
#include <memory>
+#include <thread>
using namespace Botan;
@@ -155,158 +157,441 @@ Credentials_Manager* create_creds()
return new Credentials_Manager_Test(server_cert, ca_cert, server_key);
}
-size_t basic_test_handshake(RandomNumberGenerator& rng,
- TLS::Protocol_Version offer_version,
- Credentials_Manager& creds,
- TLS::Policy& policy)
+std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q)
{
- TLS::Session_Manager_In_Memory server_sessions(rng);
- TLS::Session_Manager_In_Memory client_sessions(rng);
-
- std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data;
+ return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); };
+ }
- auto handshake_complete = [&](const TLS::Session& session) -> bool
+void print_alert(TLS::Alert alert, const byte[], size_t)
{
- if(session.version() != offer_version)
- std::cout << "Offered " << offer_version.to_string()
- << " got " << session.version().to_string() << std::endl;
- return true;
+ //std::cout << "Alert " << alert.type_string() << std::endl;
};
- auto print_alert = [&](TLS::Alert alert, const byte[], size_t)
+void mutate(std::vector<byte>& v, RandomNumberGenerator& rng)
{
- if(alert.is_valid())
- std::cout << "Server recvd alert " << alert.type_string() << std::endl;
- };
+ if(v.empty())
+ return;
- auto save_server_data = [&](const byte buf[], size_t sz)
- {
- c2s_data.insert(c2s_data.end(), buf, buf+sz);
- };
+ size_t voff = rng.get_random<size_t>() % v.size();
+ v[voff] ^= rng.next_nonzero_byte();
+ }
- auto save_client_data = [&](const byte buf[], size_t sz)
+size_t test_tls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
{
- s2c_data.insert(s2c_data.end(), buf, buf+sz);
- };
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
- auto next_protocol_chooser = [&](std::vector<std::string> protos) {
- if(protos.size() != 2)
- std::cout << "Bad protocol size" << std::endl;
- if(protos[0] != "test/1" || protos[1] != "test/2")
- std::cout << "Bad protocol values" << std::endl;
- return "test/3";
- };
- const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
-
- TLS::Server server([&](const byte buf[], size_t sz)
- { s2c_q.insert(s2c_q.end(), buf, buf+sz); },
- save_server_data,
- print_alert,
- handshake_complete,
- server_sessions,
- creds,
- policy,
- rng,
- next_protocol_chooser,
- offer_version.is_datagram_protocol());
-
- TLS::Client client([&](const byte buf[], size_t sz)
- { c2s_q.insert(c2s_q.end(), buf, buf+sz); },
- save_client_data,
- print_alert,
- handshake_complete,
- client_sessions,
- creds,
- policy,
- rng,
- TLS::Server_Information("server.example.com"),
- offer_version,
- protocols_offered);
-
- while(true)
+ for(size_t r = 1; r <= 4; ++r)
{
- if(client.is_closed() && server.is_closed())
- break;
+ //std::cout << offer_version.to_string() << " r " << r << "\n";
- if(client.is_active())
- client.send("1");
- if(server.is_active())
- {
- if(server.next_protocol() != "test/3")
- std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
- server.send("2");
- }
+ bool handshake_done = false;
- /*
- * Use this as a temp value to hold the queues as otherwise they
- * might end up appending more in response to messages during the
- * handshake.
- */
- std::vector<byte> input;
- std::swap(c2s_q, input);
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- try
- {
- server.received_data(input.data(), input.size());
- }
- catch(std::exception& e)
- {
- std::cout << "Server error - " << e.what() << std::endl;
- return 1;
- }
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
- input.clear();
- std::swap(s2c_q, input);
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
try
{
- client.received_data(input.data(), input.size());
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ false);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
+ {
+ ++rounds;
+
+ if(rounds > 25)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send in several records
+ client.send(client_sent);
+ }
+
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server recv " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client recv " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
}
catch(std::exception& e)
{
- std::cout << "Client error - " << e.what() << std::endl;
+ std::cout << e.what() << "\n";
return 1;
}
+ }
- if(c2s_data.size())
- {
- if(c2s_data[0] != '1')
- {
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ return 0;
+ }
+
+size_t test_dtls_handshake(RandomNumberGenerator& rng,
+ TLS::Protocol_Version offer_version,
+ Credentials_Manager& creds,
+ TLS::Policy& policy)
+ {
+ BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version");
+
+ TLS::Session_Manager_In_Memory server_sessions(rng);
+ TLS::Session_Manager_In_Memory client_sessions(rng);
+
+ for(size_t r = 1; r <= 2; ++r)
+ {
+ //std::cout << offer_version.to_string() << " round " << r << "\n";
+
+ bool handshake_done = false;
+
+ auto handshake_complete = [&](const TLS::Session& session) -> bool {
+ handshake_done = true;
- if(s2c_data.size())
+ /*
+ std::cout << "Session established " << session.version().to_string() << " "
+ << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n";
+ */
+
+ if(session.version() != offer_version)
+ std::cout << "Offered " << offer_version.to_string()
+ << " got " << session.version().to_string() << std::endl;
+ return true;
+ };
+
+ auto next_protocol_chooser = [&](std::vector<std::string> protos) {
+ if(protos.size() != 2)
+ std::cout << "Bad protocol size" << std::endl;
+ if(protos[0] != "test/1" || protos[1] != "test/2")
+ std::cout << "Bad protocol values" << std::endl;
+ return "test/3";
+ };
+
+ const std::vector<std::string> protocols_offered = { "test/1", "test/2" };
+
+ try
{
- if(s2c_data[0] != '2')
+ std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent;
+
+ TLS::Server server(queue_inserter(s2c_traffic),
+ queue_inserter(server_recv),
+ print_alert,
+ handshake_complete,
+ server_sessions,
+ creds,
+ policy,
+ rng,
+ next_protocol_chooser,
+ true);
+
+ TLS::Client client(queue_inserter(c2s_traffic),
+ queue_inserter(client_recv),
+ print_alert,
+ handshake_complete,
+ client_sessions,
+ creds,
+ policy,
+ rng,
+ TLS::Server_Information("server.example.com"),
+ offer_version,
+ protocols_offered);
+
+ size_t rounds = 0;
+
+ while(true)
{
- std::cout << "Error" << std::endl;
- return 1;
- }
- }
+ // TODO: client and server should be in different threads
+ std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2));
+ ++rounds;
- if(s2c_data.size() && c2s_data.size())
- {
- SymmetricKey client_key = client.key_material_export("label", "context", 32);
- SymmetricKey server_key = server.key_material_export("label", "context", 32);
+ if(rounds > 100)
+ {
+ std::cout << "Still here, something went wrong\n";
+ return 1;
+ }
+
+ if(handshake_done && (client.is_closed() || server.is_closed()))
+ break;
+
+ if(client.is_active() && client_sent.empty())
+ {
+ // Choose a len between 1 and 511 and send random chunks:
+ const size_t c_len = 1 + rng.next_byte() + rng.next_byte();
+ client_sent = unlock(rng.random_vec(c_len));
+
+ // TODO send multiple parts
+ //std::cout << "Sending " << client_sent.size() << " bytes to server\n";
+ client.send(client_sent);
+ }
- if(client_key != server_key)
- return 1;
+ if(server.is_active() && server_sent.empty())
+ {
+ if(server.next_protocol() != "test/3")
+ std::cout << "Wrong protocol " << server.next_protocol() << std::endl;
+
+ const size_t s_len = 1 + rng.next_byte() + rng.next_byte();
+ server_sent = unlock(rng.random_vec(s_len));
+ //std::cout << "Sending " << server_sent.size() << " bytes to client\n";
+ server.send(server_sent);
+ }
+
+ const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+ const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10);
+
+ try
+ {
+ /*
+ * Use this as a temp value to hold the queues as otherwise they
+ * might end up appending more in response to messages during the
+ * handshake.
+ */
+ //std::cout << "server got " << c2s_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(c2s_traffic, input);
+
+ if(corrupt_client_data)
+ {
+ //std::cout << "Corrupting client data\n";
+ mutate(input, rng);
+ }
+
+ server.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Server error - " << e.what() << std::endl;
+ continue;
+ }
+
+ try
+ {
+ //std::cout << "client got " << s2c_traffic.size() << " bytes\n";
+ std::vector<byte> input;
+ std::swap(s2c_traffic, input);
+
+ if(corrupt_server_data)
+ {
+ //std::cout << "Corrupting server data\n";
+ mutate(input, rng);
+ }
+ client.received_data(input.data(), input.size());
+ }
+ catch(std::exception& e)
+ {
+ std::cout << "Client error - " << e.what() << std::endl;
+ continue;
+ }
- server.close();
- client.close();
+ // If we corrupted a DTLS application message, resend it:
+ if(client.is_active() && corrupt_client_data && server_recv.empty())
+ client.send(client_sent);
+ if(server.is_active() && corrupt_server_data && client_recv.empty())
+ server.send(server_sent);
+
+ if(client_recv.size())
+ {
+ if(client_recv != server_sent)
+ {
+ std::cout << "Error in client recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(server_recv.size())
+ {
+ if(server_recv != client_sent)
+ {
+ std::cout << "Error in server recv" << std::endl;
+ return 1;
+ }
+ }
+
+ if(client.is_closed() && server.is_closed())
+ break;
+
+ if(server_recv.size() && client_recv.size())
+ {
+ SymmetricKey client_key = client.key_material_export("label", "context", 32);
+ SymmetricKey server_key = server.key_material_export("label", "context", 32);
+
+ if(client_key != server_key)
+ {
+ std::cout << "TLS key material export mismatch: "
+ << client_key.as_string() << " != "
+ << server_key.as_string() << "\n";
+ return 1;
+ }
+
+ if(r % 2 == 0)
+ client.close();
+ else
+ server.close();
+ }
+ }
+ }
+ catch(std::exception& e)
+ {
+ std::cout << e.what() << "\n";
+ return 1;
}
}
return 0;
}
-class Test_Policy : public TLS::Policy
+class Test_Policy : public TLS::Text_Policy
{
public:
+ Test_Policy() : Text_Policy("") {}
bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; }
bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; }
+
+ size_t dtls_initial_timeout() const override { return 1; }
+ size_t dtls_maximum_timeout() const override { return 8; }
};
}
@@ -315,17 +600,43 @@ size_t test_tls()
{
size_t errors = 0;
- Test_Policy default_policy;
auto& rng = test_rng();
std::unique_ptr<Credentials_Manager> basic_creds(create_creds());
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy);
- errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy);
-
- test_report("TLS", 5, errors);
+ Test_Policy policy;
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "RSA");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("key_exchange_methods", "DH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ policy.set("key_exchange_methods", "ECDH");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "AES-128");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy);
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ policy.set("ciphers", "ChaCha20Poly1305");
+ errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy);
+ errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy);
+
+ test_report("TLS", 22, errors);
return errors;
}