diff options
-rw-r--r-- | src/tests/unit_tls.cpp | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index c3355b118..c6114b010 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -1032,6 +1032,241 @@ class TLS_Unit_Tests final : public Test BOTAN_REGISTER_TEST("tls", TLS_Unit_Tests); +class DTLS_Reconnection_Test : public Test + { + public: + std::vector<Test::Result> run() override + { + class Test_Callbacks : public Botan::TLS::Callbacks + { + public: + Test_Callbacks(Test::Result& results, + std::vector<uint8_t>& outbound, + std::vector<uint8_t>& recv_buf) : + m_results(results), + m_outbound(outbound), + m_recv(recv_buf) + {} + + void tls_emit_data(const uint8_t bits[], size_t len) override + { + m_outbound.insert(m_outbound.end(), bits, bits + len); + } + + void tls_record_received(uint64_t /*seq*/, const uint8_t bits[], size_t len) override + { + m_recv.insert(m_recv.end(), bits, bits + len); + } + + void tls_alert(Botan::TLS::Alert /*alert*/) override + { + // ignore + } + + bool tls_session_established(const Botan::TLS::Session& /*session*/) override + { + m_results.test_success("Established a session"); + return true; + } + + private: + Test::Result& m_results; + std::vector<uint8_t>& m_outbound; + std::vector<uint8_t>& m_recv; + }; + + class Credentials_PSK : public Botan::Credentials_Manager + { + public: + Botan::SymmetricKey psk(const std::string& type, + const std::string& context, + const std::string&) override + { + if(type == "tls-server" && context == "session-ticket") + { + return Botan::SymmetricKey("AABBCCDDEEFF012345678012345678"); + } + + if(type == "tls-server" && context == "dtls-cookie-secret") + { + return Botan::SymmetricKey("4AEA5EAD279CADEB537A594DA0E9DE3A"); + } + + if(context == "localhost" && type == "tls-client") + { + return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD"); + } + + if(context == "localhost" && type == "tls-server") + { + return Botan::SymmetricKey("20B602D1475F2DF888FCB60D2AE03AFD"); + } + + throw Test_Error("No PSK set for " + type + "/" + context); + } + }; + + class Datagram_PSK_Policy : public Botan::TLS::Policy + { + public: + std::vector<std::string> allowed_macs() const override + { return std::vector<std::string>({"AEAD"}); } + + std::vector<std::string> allowed_key_exchange_methods() const override + { return {"PSK"}; } + + bool allow_tls10() const override { return false; } + bool allow_tls11() const override { return false; } + bool allow_tls12() const override { return false; } + bool allow_dtls10() const override { return false; } + bool allow_dtls12() const override { return true; } + + bool allow_dtls_epoch0_restart() const override { return true; } + }; + + Test::Result result("DTLS reconnection"); + + Datagram_PSK_Policy server_policy; + Datagram_PSK_Policy client_policy; + Credentials_PSK creds; + Botan::TLS::Session_Manager_In_Memory server_sessions(rng()); + //Botan::TLS::Session_Manager_In_Memory client_sessions(rng()); + Botan::TLS::Session_Manager_Noop client_sessions; + + std::vector<uint8_t> s2c, server_recv; + Test_Callbacks server_callbacks(result, s2c, server_recv); + Botan::TLS::Server server(server_callbacks, server_sessions, creds, server_policy, rng(), true); + + std::vector<uint8_t> c1_c2s, client1_recv; + Test_Callbacks client1_callbacks(result, c1_c2s, client1_recv); + Botan::TLS::Client client1(client1_callbacks, client_sessions, creds, client_policy, rng(), + Botan::TLS::Server_Information("localhost"), + Botan::TLS::Protocol_Version::latest_dtls_version()); + + bool c1_to_server_sent = false; + bool server_to_c1_sent = false; + + const std::vector<uint8_t> c1_to_server_magic(16, 0xC1); + const std::vector<uint8_t> server_to_c1_magic(16, 0x42); + + size_t c1_rounds = 0; + for(;;) + { + c1_rounds++; + + if(c1_rounds > 64) + { + result.test_failure("Still spinning in client1 loop after 64 rounds"); + return {result}; + } + + if(c1_c2s.size() > 0) + { + std::vector<uint8_t> input; + std::swap(c1_c2s, input); + server.received_data(input.data(), input.size()); + continue; + } + + if(s2c.size() > 0) + { + std::vector<uint8_t> input; + std::swap(s2c, input); + client1.received_data(input.data(), input.size()); + continue; + } + + if(!c1_to_server_sent && client1.is_active()) + { + client1.send(c1_to_server_magic); + c1_to_server_sent = true; + } + + if(!server_to_c1_sent && server.is_active()) + { + server.send(server_to_c1_magic); + } + + if(server_recv.size() > 0 && client1_recv.size() > 0) + { + result.test_eq("Expected message from client1", server_recv, c1_to_server_magic); + result.test_eq("Expected message to client1", client1_recv, server_to_c1_magic); + break; + } + } + + // Now client1 "goes away" (goes silent) and new client + // connects to same server context (ie due to reuse of client source port) + // See RFC 6347 section 4.2.8 + + server_recv.clear(); + s2c.clear(); + + std::vector<uint8_t> c2_c2s, client2_recv; + Test_Callbacks client2_callbacks(result, c2_c2s, client2_recv); + Botan::TLS::Client client2(client2_callbacks, client_sessions, creds, client_policy, rng(), + Botan::TLS::Server_Information("localhost"), + Botan::TLS::Protocol_Version::latest_dtls_version()); + + bool c2_to_server_sent = false; + bool server_to_c2_sent = false; + + const std::vector<uint8_t> c2_to_server_magic(16, 0xC2); + const std::vector<uint8_t> server_to_c2_magic(16, 0x66); + + size_t c2_rounds = 0; + + for(;;) + { + c2_rounds++; + + if(c2_rounds > 64) + { + result.test_failure("Still spinning in client2 loop after 64 rounds"); + return {result}; + } + + if(c2_c2s.size() > 0) + { + std::vector<uint8_t> input; + std::swap(c2_c2s, input); + server.received_data(input.data(), input.size()); + continue; + } + + if(s2c.size() > 0) + { + std::vector<uint8_t> input; + std::swap(s2c, input); + client2.received_data(input.data(), input.size()); + continue; + } + + if(!c2_to_server_sent && client2.is_active()) + { + client2.send(c2_to_server_magic); + c2_to_server_sent = true; + } + + if(!server_to_c2_sent && server.is_active()) + { + server.send(server_to_c2_magic); + } + + if(server_recv.size() > 0 && client2_recv.size() > 0) + { + result.test_eq("Expected message from client2", server_recv, c2_to_server_magic); + result.test_eq("Expected message to client2", client2_recv, server_to_c2_magic); + break; + } + } + + return {result}; + } + }; + +BOTAN_REGISTER_TEST("tls_dtls_reconnect", DTLS_Reconnection_Test); + #endif } |