diff options
-rw-r--r-- | src/lib/tls/asio/asio_stream.h | 111 | ||||
-rw-r--r-- | src/lib/tls/asio/asio_stream_base.h | 47 | ||||
-rw-r--r-- | src/tests/unit_asio_stream.cpp | 23 |
3 files changed, 98 insertions, 83 deletions
diff --git a/src/lib/tls/asio/asio_stream.h b/src/lib/tls/asio/asio_stream.h index 72f363a88..1cd2af87b 100644 --- a/src/lib/tls/asio/asio_stream.h +++ b/src/lib/tls/asio/asio_stream.h @@ -45,31 +45,7 @@ class Stream final : public StreamBase<Channel> using executor_type = typename next_layer_type::executor_type; using native_handle_type = typename std::add_pointer<Channel>::type; - enum handshake_type - { - client, - server - }; - - private: - void validate_handshake_type(handshake_type type) - { - if(type != handshake_type::client) - { - throw Not_Implemented("server-side TLS stream is not implemented"); - } - } - - bool validate_handshake_type(handshake_type type, boost::system::error_code& ec) - { - if(type != handshake_type::client) - { - ec = make_error_code(Botan::TLS::error::not_implemented); - return false; - } - - return true; - } + using StreamBase<Channel>::validate_handshake_type; public: template <typename... Args> @@ -93,7 +69,6 @@ class Stream final : public StreamBase<Channel> throw Not_Implemented("cannot handle an asio::ssl::context"); } - Stream(Stream&& other) = default; Stream& operator=(Stream&& other) = default; @@ -167,22 +142,15 @@ class Stream final : public StreamBase<Channel> // -- -- handshake methods // - void handshake(handshake_type type) + void handshake() { - validate_handshake_type(type); - boost::system::error_code ec; - handshake(type, ec); + handshake(ec); boost::asio::detail::throw_error(ec, "handshake"); } - void handshake(handshake_type type, boost::system::error_code& ec) + void handshake(boost::system::error_code& ec) { - if(!validate_handshake_type(type, ec)) - { - return; - } - while(!native_handle()->is_active()) { writePendingTlsData(ec); @@ -214,11 +182,47 @@ class Stream final : public StreamBase<Channel> } } + template <typename HandshakeHandler> + BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, + void(boost::system::error_code)) + async_handshake(HandshakeHandler&& handler) + { + BOOST_ASIO_HANDSHAKE_HANDLER_CHECK(HandshakeHandler, handler) type_check; + + boost::asio::async_completion<HandshakeHandler, + void(boost::system::error_code)> + init(handler); + + auto op = create_async_handshake_op(std::move(init.completion_handler)); + op(boost::system::error_code{}, 0, 1); + + return init.result.get(); + } + + // + // -- -- asio::ssl::stream compatibility methods + // + + void handshake(handshake_type type) + { + validate_handshake_type(type); + handshake(); + } + + void handshake(handshake_type type, boost::system::error_code& ec) + { + if(validate_handshake_type(type, ec)) + { + handshake(ec); + } + } + template<typename ConstBufferSequence> void handshake(handshake_type type, const ConstBufferSequence& buffers) { - BOTAN_UNUSED(type, buffers); - throw Not_Implemented("server-side TLS stream is not implemented"); + BOTAN_UNUSED(buffers); + validate_handshake_type(type); + throw Not_Implemented("buffered handshake is not implemented"); } template<typename ConstBufferSequence> @@ -226,8 +230,11 @@ class Stream final : public StreamBase<Channel> const ConstBufferSequence& buffers, boost::system::error_code& ec) { - BOTAN_UNUSED(type, buffers); - ec = make_error_code(Botan::TLS::error::not_implemented); + BOTAN_UNUSED(buffers); + if(validate_handshake_type(type, ec)) + { + ec = make_error_code(Botan::TLS::error::not_implemented); + } } template <typename HandshakeHandler> @@ -235,33 +242,19 @@ class Stream final : public StreamBase<Channel> void(boost::system::error_code)) async_handshake(handshake_type type, HandshakeHandler&& handler) { - // If you get an error on the following line it means that your handler does - // not meet the documented type requirements for a HandshakeHandler. - BOOST_ASIO_HANDSHAKE_HANDLER_CHECK(HandshakeHandler, handler) type_check; - validate_handshake_type(type); - - boost::asio::async_completion<HandshakeHandler, - void(boost::system::error_code)> - init(handler); - - auto op = create_async_handshake_op(std::move(init.completion_handler)); - op(boost::system::error_code{}, 0, 1); - - return init.result.get(); + return async_handshake(handler); } template <typename ConstBufferSequence, typename BufferedHandshakeHandler> BOOST_ASIO_INITFN_RESULT_TYPE(BufferedHandshakeHandler, void(boost::system::error_code, std::size_t)) - async_handshake(handshake_type type, - const ConstBufferSequence& buffers, + async_handshake(handshake_type type, const ConstBufferSequence& buffers, BufferedHandshakeHandler&& handler) { - // If you get an error on the following line it means that your handler does - // not meet the documented type requirements for a BufferedHandshakeHandler. + BOTAN_UNUSED(buffers, handler); BOOST_ASIO_HANDSHAKE_HANDLER_CHECK(BufferedHandshakeHandler, handler) type_check; - BOTAN_UNUSED(type, buffers, handler); + validate_handshake_type(type); throw Not_Implemented("buffered async handshake is not implemented"); } @@ -293,8 +286,6 @@ class Stream final : public StreamBase<Channel> template <typename ShutdownHandler> void async_shutdown(ShutdownHandler&& handler) { - // If you get an error on the following line it means that your handler does - // not meet the documented type requirements for a ShutdownHandler. BOOST_ASIO_HANDSHAKE_HANDLER_CHECK(ShutdownHandler, handler) type_check; BOTAN_UNUSED(handler); throw Not_Implemented("async shutdown is not implemented"); diff --git a/src/lib/tls/asio/asio_stream_base.h b/src/lib/tls/asio/asio_stream_base.h index 0fb5353ed..161392ad8 100644 --- a/src/lib/tls/asio/asio_stream_base.h +++ b/src/lib/tls/asio/asio_stream_base.h @@ -12,11 +12,23 @@ #include <botan/auto_rng.h> #include <botan/tls_client.h> #include <botan/tls_server.h> +#include <botan/asio_error.h> namespace Botan { namespace TLS { +enum handshake_type + { + client, + server + }; + + +/* Base class for all Botan::TLS::Stream implementations. + * + * + */ template <class Channel> class StreamBase { @@ -43,30 +55,31 @@ class StreamBase<Botan::TLS::Client> StreamBase(const StreamBase&) = delete; StreamBase& operator=(const StreamBase&) = delete; - protected: - Botan::TLS::StreamCore m_core; - Botan::AutoSeeded_RNG m_rng; - Botan::TLS::Client m_channel; - }; + using handshake_type = Botan::TLS::handshake_type; -template <> -class StreamBase<Botan::TLS::Server> - { - public: - StreamBase(Botan::TLS::Session_Manager& sessionManager, - Botan::Credentials_Manager& credentialsManager, - const Botan::TLS::Policy& policy = Botan::TLS::Strict_Policy{}) - : m_channel(m_core, sessionManager, credentialsManager, policy, m_rng) + protected: + void validate_handshake_type(handshake_type type) { + if(type != handshake_type::client) + { + throw Invalid_Argument("wrong handshake_type"); + } } - StreamBase(const StreamBase&) = delete; - StreamBase& operator=(const StreamBase&) = delete; + bool validate_handshake_type(handshake_type type, boost::system::error_code& ec) + { + if(type != handshake_type::client) + { + ec = make_error_code(Botan::TLS::error::invalid_argument); + return false; + } + + return true; + } - protected: Botan::TLS::StreamCore m_core; Botan::AutoSeeded_RNG m_rng; - Botan::TLS::Server m_channel; + Botan::TLS::Client m_channel; }; } // namespace TLS diff --git a/src/tests/unit_asio_stream.cpp b/src/tests/unit_asio_stream.cpp index 7da7467f4..7551c44d9 100644 --- a/src/tests/unit_asio_stream.cpp +++ b/src/tests/unit_asio_stream.cpp @@ -142,10 +142,21 @@ class StreamBase<Botan_Tests::MockChannel> StreamBase(const StreamBase&) = delete; StreamBase& operator=(const StreamBase&) = delete; + using handshake_type = Botan::TLS::handshake_type; + protected: StreamCore m_core; Botan::AutoSeeded_RNG m_rng; Botan_Tests::MockChannel m_channel; + + void validate_handshake_type(handshake_type) + { + } + + bool validate_handshake_type(handshake_type, boost::system::error_code&) + { + return true; + } }; } // namespace TLS @@ -171,7 +182,7 @@ class ASIO_Stream_Tests final : public Test MockSocket socket; AsioStream ssl{socket}; - ssl.handshake(AsioStream::client); + ssl.handshake(AsioStream::handshake_type::client); Test::Result result("sync TLS handshake"); result.test_eq("feeds data into channel until active", ssl.native_handle()->is_active(), true); @@ -187,7 +198,7 @@ class ASIO_Stream_Tests final : public Test socket.error = expected_ec; error_code ec; - ssl.handshake(AsioStream::client, ec); + ssl.handshake(AsioStream::handshake_type::client, ec); Test::Result result("sync TLS handshake error"); result.test_eq("does not activate channel", ssl.native_handle()->is_active(), false); @@ -207,7 +218,7 @@ class ASIO_Stream_Tests final : public Test result.test_eq("feeds data into channel until active", ssl.native_handle()->is_active(), true); }; - ssl.async_handshake(AsioStream::client, handler); + ssl.async_handshake(AsioStream::handshake_type::client, handler); results.push_back(result); } @@ -227,7 +238,7 @@ class ASIO_Stream_Tests final : public Test result.confirm("propagates error code", ec == expected_ec); }; - ssl.async_handshake(AsioStream::client, handler); + ssl.async_handshake(AsioStream::handshake_type::client, handler); results.push_back(result); } @@ -537,7 +548,7 @@ class Async_Asio_Stream_Tests final : public Test result.test_eq("feeds data into channel until active", ssl.native_handle()->is_active(), true); }; - ssl.async_handshake(AsioStream::client, handler); + ssl.async_handshake(AsioStream::handshake_type::client, handler); socket.close_remote(); ioc.run(); @@ -563,7 +574,7 @@ class Async_Asio_Stream_Tests final : public Test result.confirm("propagates error code", (bool)ec); }; - ssl.async_handshake(AsioStream::client, handler); + ssl.async_handshake(AsioStream::handshake_type::client, handler); ioc.run(); results.push_back(result); |