diff options
Diffstat (limited to 'src/tests/unit_tls.cpp')
-rw-r--r-- | src/tests/unit_tls.cpp | 551 |
1 files changed, 431 insertions, 120 deletions
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index 116eb2cdf..d4abef119 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -11,6 +11,7 @@ #include <botan/tls_server.h> #include <botan/tls_client.h> +#include <botan/tls_handshake_msg.h> #include <botan/pkcs10.h> #include <botan/x509self.h> #include <botan/rsa.h> @@ -21,6 +22,7 @@ #include <iostream> #include <vector> #include <memory> +#include <thread> using namespace Botan; @@ -155,158 +157,441 @@ Credentials_Manager* create_creds() return new Credentials_Manager_Test(server_cert, ca_cert, server_key); } -size_t basic_test_handshake(RandomNumberGenerator& rng, - TLS::Protocol_Version offer_version, - Credentials_Manager& creds, - TLS::Policy& policy) +std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q) { - TLS::Session_Manager_In_Memory server_sessions(rng); - TLS::Session_Manager_In_Memory client_sessions(rng); - - std::vector<byte> c2s_q, s2c_q, c2s_data, s2c_data; + return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); }; + } - auto handshake_complete = [&](const TLS::Session& session) -> bool +void print_alert(TLS::Alert alert, const byte[], size_t) { - if(session.version() != offer_version) - std::cout << "Offered " << offer_version.to_string() - << " got " << session.version().to_string() << std::endl; - return true; + //std::cout << "Alert " << alert.type_string() << std::endl; }; - auto print_alert = [&](TLS::Alert alert, const byte[], size_t) +void mutate(std::vector<byte>& v, RandomNumberGenerator& rng) { - if(alert.is_valid()) - std::cout << "Server recvd alert " << alert.type_string() << std::endl; - }; + if(v.empty()) + return; - auto save_server_data = [&](const byte buf[], size_t sz) - { - c2s_data.insert(c2s_data.end(), buf, buf+sz); - }; + size_t voff = rng.get_random<size_t>() % v.size(); + v[voff] ^= rng.next_nonzero_byte(); + } - auto save_client_data = [&](const byte buf[], size_t sz) +size_t test_tls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) { - s2c_data.insert(s2c_data.end(), buf, buf+sz); - }; + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); - auto next_protocol_chooser = [&](std::vector<std::string> protos) { - if(protos.size() != 2) - std::cout << "Bad protocol size" << std::endl; - if(protos[0] != "test/1" || protos[1] != "test/2") - std::cout << "Bad protocol values" << std::endl; - return "test/3"; - }; - const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; - - TLS::Server server([&](const byte buf[], size_t sz) - { s2c_q.insert(s2c_q.end(), buf, buf+sz); }, - save_server_data, - print_alert, - handshake_complete, - server_sessions, - creds, - policy, - rng, - next_protocol_chooser, - offer_version.is_datagram_protocol()); - - TLS::Client client([&](const byte buf[], size_t sz) - { c2s_q.insert(c2s_q.end(), buf, buf+sz); }, - save_client_data, - print_alert, - handshake_complete, - client_sessions, - creds, - policy, - rng, - TLS::Server_Information("server.example.com"), - offer_version, - protocols_offered); - - while(true) + for(size_t r = 1; r <= 4; ++r) { - if(client.is_closed() && server.is_closed()) - break; + //std::cout << offer_version.to_string() << " r " << r << "\n"; - if(client.is_active()) - client.send("1"); - if(server.is_active()) - { - if(server.next_protocol() != "test/3") - std::cout << "Wrong protocol " << server.next_protocol() << std::endl; - server.send("2"); - } + bool handshake_done = false; - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<byte> input; - std::swap(c2s_q, input); + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - try - { - server.received_data(input.data(), input.size()); - } - catch(std::exception& e) - { - std::cout << "Server error - " << e.what() << std::endl; - return 1; - } + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; - input.clear(); - std::swap(s2c_q, input); + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; try { - client.received_data(input.data(), input.size()); + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + false); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) + { + ++rounds; + + if(rounds > 25) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send in several records + client.send(client_sent); + } + + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds > 1); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server recv " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client recv " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } } catch(std::exception& e) { - std::cout << "Client error - " << e.what() << std::endl; + std::cout << e.what() << "\n"; return 1; } + } - if(c2s_data.size()) - { - if(c2s_data[0] != '1') - { - std::cout << "Error" << std::endl; - return 1; - } - } + return 0; + } + +size_t test_dtls_handshake(RandomNumberGenerator& rng, + TLS::Protocol_Version offer_version, + Credentials_Manager& creds, + TLS::Policy& policy) + { + BOTAN_ASSERT(offer_version.is_datagram_protocol(), "Test is for datagram version"); + + TLS::Session_Manager_In_Memory server_sessions(rng); + TLS::Session_Manager_In_Memory client_sessions(rng); + + for(size_t r = 1; r <= 2; ++r) + { + //std::cout << offer_version.to_string() << " round " << r << "\n"; + + bool handshake_done = false; + + auto handshake_complete = [&](const TLS::Session& session) -> bool { + handshake_done = true; - if(s2c_data.size()) + /* + std::cout << "Session established " << session.version().to_string() << " " + << session.ciphersuite().to_string() << " " << hex_encode(session.session_id()) << "\n"; + */ + + if(session.version() != offer_version) + std::cout << "Offered " << offer_version.to_string() + << " got " << session.version().to_string() << std::endl; + return true; + }; + + auto next_protocol_chooser = [&](std::vector<std::string> protos) { + if(protos.size() != 2) + std::cout << "Bad protocol size" << std::endl; + if(protos[0] != "test/1" || protos[1] != "test/2") + std::cout << "Bad protocol values" << std::endl; + return "test/3"; + }; + + const std::vector<std::string> protocols_offered = { "test/1", "test/2" }; + + try { - if(s2c_data[0] != '2') + std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; + + TLS::Server server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + print_alert, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + true); + + TLS::Client client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + print_alert, + handshake_complete, + client_sessions, + creds, + policy, + rng, + TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered); + + size_t rounds = 0; + + while(true) { - std::cout << "Error" << std::endl; - return 1; - } - } + // TODO: client and server should be in different threads + std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2)); + ++rounds; - if(s2c_data.size() && c2s_data.size()) - { - SymmetricKey client_key = client.key_material_export("label", "context", 32); - SymmetricKey server_key = server.key_material_export("label", "context", 32); + if(rounds > 100) + { + std::cout << "Still here, something went wrong\n"; + return 1; + } + + if(handshake_done && (client.is_closed() || server.is_closed())) + break; + + if(client.is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 and send random chunks: + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); + + // TODO send multiple parts + //std::cout << "Sending " << client_sent.size() << " bytes to server\n"; + client.send(client_sent); + } - if(client_key != server_key) - return 1; + if(server.is_active() && server_sent.empty()) + { + if(server.next_protocol() != "test/3") + std::cout << "Wrong protocol " << server.next_protocol() << std::endl; + + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + //std::cout << "Sending " << server_sent.size() << " bytes to client\n"; + server.send(server_sent); + } + + const bool corrupt_client_data = (r == 3 && c2s_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + const bool corrupt_server_data = (r == 4 && s2c_traffic.size() && rng.next_byte() % 3 == 0 && rounds < 10); + + try + { + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + //std::cout << "server got " << c2s_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_client_data) + { + //std::cout << "Corrupting client data\n"; + mutate(input, rng); + } + + server.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Server error - " << e.what() << std::endl; + continue; + } + + try + { + //std::cout << "client got " << s2c_traffic.size() << " bytes\n"; + std::vector<byte> input; + std::swap(s2c_traffic, input); + + if(corrupt_server_data) + { + //std::cout << "Corrupting server data\n"; + mutate(input, rng); + } + client.received_data(input.data(), input.size()); + } + catch(std::exception& e) + { + std::cout << "Client error - " << e.what() << std::endl; + continue; + } - server.close(); - client.close(); + // If we corrupted a DTLS application message, resend it: + if(client.is_active() && corrupt_client_data && server_recv.empty()) + client.send(client_sent); + if(server.is_active() && corrupt_server_data && client_recv.empty()) + server.send(server_sent); + + if(client_recv.size()) + { + if(client_recv != server_sent) + { + std::cout << "Error in client recv" << std::endl; + return 1; + } + } + + if(server_recv.size()) + { + if(server_recv != client_sent) + { + std::cout << "Error in server recv" << std::endl; + return 1; + } + } + + if(client.is_closed() && server.is_closed()) + break; + + if(server_recv.size() && client_recv.size()) + { + SymmetricKey client_key = client.key_material_export("label", "context", 32); + SymmetricKey server_key = server.key_material_export("label", "context", 32); + + if(client_key != server_key) + { + std::cout << "TLS key material export mismatch: " + << client_key.as_string() << " != " + << server_key.as_string() << "\n"; + return 1; + } + + if(r % 2 == 0) + client.close(); + else + server.close(); + } + } + } + catch(std::exception& e) + { + std::cout << e.what() << "\n"; + return 1; } } return 0; } -class Test_Policy : public TLS::Policy +class Test_Policy : public TLS::Text_Policy { public: + Test_Policy() : Text_Policy("") {} bool acceptable_protocol_version(TLS::Protocol_Version) const override { return true; } bool send_fallback_scsv(TLS::Protocol_Version) const override { return false; } + + size_t dtls_initial_timeout() const override { return 1; } + size_t dtls_maximum_timeout() const override { return 8; } }; } @@ -315,17 +600,43 @@ size_t test_tls() { size_t errors = 0; - Test_Policy default_policy; auto& rng = test_rng(); std::unique_ptr<Credentials_Manager> basic_creds(create_creds()); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, default_policy); - errors += basic_test_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, default_policy); - - test_report("TLS", 5, errors); + Test_Policy policy; + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "RSA"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("key_exchange_methods", "DH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + policy.set("key_exchange_methods", "ECDH"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "AES-128"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V10, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V11, *basic_creds, policy); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V10, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + policy.set("ciphers", "ChaCha20Poly1305"); + errors += test_tls_handshake(rng, TLS::Protocol_Version::TLS_V12, *basic_creds, policy); + errors += test_dtls_handshake(rng, TLS::Protocol_Version::DTLS_V12, *basic_creds, policy); + + test_report("TLS", 22, errors); return errors; } |