From e2d844d3dbea46a54a86a03126772956b715b9c9 Mon Sep 17 00:00:00 2001 From: lloyd Date: Tue, 4 Sep 2012 20:43:35 +0000 Subject: Use a std::function so handshake_io only has access Record_Writer's send function. --- src/tls/rec_wri.cpp | 4 ++-- src/tls/tls_channel.cpp | 2 +- src/tls/tls_client.cpp | 6 +++++- src/tls/tls_handshake_io.cpp | 22 +++++++++++----------- src/tls/tls_handshake_io.h | 13 ++++++++----- src/tls/tls_record.h | 4 ++-- src/tls/tls_server.cpp | 7 ++++++- 7 files changed, 35 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp index fdecaa919..4eff52f78 100644 --- a/src/tls/rec_wri.cpp +++ b/src/tls/rec_wri.cpp @@ -148,7 +148,7 @@ void Record_Writer::change_cipher_spec(Connection_Side side, /* * Send one or more records to the other side */ -void Record_Writer::send(byte type, const byte input[], size_t length) +void Record_Writer::send_array(byte type, const byte input[], size_t length) { if(length == 0) return; @@ -288,7 +288,7 @@ void Record_Writer::send_alert(const Alert& alert) const byte alert_bits[2] = { static_cast(alert.is_fatal() ? 2 : 1), static_cast(alert.type()) }; - send(ALERT, alert_bits, sizeof(alert_bits)); + send_array(ALERT, alert_bits, sizeof(alert_bits)); } } diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 3e5bdbabd..63a496358 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -207,7 +207,7 @@ void Channel::send(const byte buf[], size_t buf_size) if(!is_active()) throw std::runtime_error("Data cannot be sent on inactive TLS connection"); - m_writer.send(APPLICATION_DATA, buf, buf_size); + m_writer.send_array(APPLICATION_DATA, buf, buf_size); } void Channel::send(const std::string& string) diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index c6c7a1765..7fa1ad8bc 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -64,7 +64,11 @@ Client::Client(std::function output_fn, Handshake_State* Client::new_handshake_state() { - return new Client_Handshake_State(new Stream_Handshake_IO(m_writer)); + using namespace std::placeholders; + + return new Client_Handshake_State( + new Stream_Handshake_IO(std::bind(&Record_Writer::send, + std::ref(m_writer), _1, _2))); } /* diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp index cc2e4f7d1..b1f6a0eb5 100644 --- a/src/tls/tls_handshake_io.cpp +++ b/src/tls/tls_handshake_io.cpp @@ -106,12 +106,12 @@ std::vector Stream_Handshake_IO::send(const Handshake_Message& msg) if(msg.type() == HANDSHAKE_CCS) { - m_writer.send(CHANGE_CIPHER_SPEC, msg_bits); + m_writer(CHANGE_CIPHER_SPEC, msg_bits); return std::vector(); // not included in handshake hashes } const std::vector buf = format(msg_bits, msg.type()); - m_writer.send(HANDSHAKE, buf); + m_writer(HANDSHAKE, buf); return buf; } @@ -323,7 +323,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) if(msg.type() == HANDSHAKE_CCS) { - m_writer.send(CHANGE_CIPHER_SPEC, msg_bits); + m_writer(CHANGE_CIPHER_SPEC, msg_bits); return std::vector(); // not included in handshake hashes } @@ -333,7 +333,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) m_mtu = 64; if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu) - m_writer.send(HANDSHAKE, no_fragment); + m_writer(HANDSHAKE, no_fragment); else { const size_t parts = split_for_mtu(m_mtu, msg_bits.size()); @@ -348,13 +348,13 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg) std::min(msg_bits.size() - frag_offset, parts_size); - m_writer.send(HANDSHAKE, - format_fragment(&msg_bits[frag_offset], - frag_len, - frag_offset, - msg_bits.size(), - msg.type(), - m_out_message_seq)); + m_writer(HANDSHAKE, + format_fragment(&msg_bits[frag_offset], + frag_len, + frag_offset, + msg_bits.size(), + msg.type(), + m_out_message_seq)); frag_offset += frag_len; } diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h index 6463e5638..9ff580cf8 100644 --- a/src/tls/tls_handshake_io.h +++ b/src/tls/tls_handshake_io.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -21,9 +22,10 @@ namespace Botan { namespace TLS { -class Record_Writer; class Handshake_Message; +typedef std::function&)> handshake_write_fn; + /** * Handshake IO Interface */ @@ -64,7 +66,8 @@ class Handshake_IO class Stream_Handshake_IO : public Handshake_IO { public: - Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {} + Stream_Handshake_IO(handshake_write_fn writer) : + m_writer(writer) {} Protocol_Version initial_record_version() const override; @@ -83,7 +86,7 @@ class Stream_Handshake_IO : public Handshake_IO get_next_record(bool expecting_ccs) override; private: std::deque m_queue; - Record_Writer& m_writer; + handshake_write_fn m_writer; }; /** @@ -92,7 +95,7 @@ class Stream_Handshake_IO : public Handshake_IO class Datagram_Handshake_IO : public Handshake_IO { public: - Datagram_Handshake_IO(Record_Writer& writer, u16bit mtu) : + Datagram_Handshake_IO(handshake_write_fn writer, u16bit mtu) : m_flights(1), m_mtu(mtu), m_writer(writer) {} Protocol_Version initial_record_version() const override; @@ -155,7 +158,7 @@ class Datagram_Handshake_IO : public Handshake_IO u16bit m_mtu = 0; u16bit m_in_message_seq = 0; u16bit m_out_message_seq = 0; - Record_Writer& m_writer; + handshake_write_fn m_writer; }; } diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h index 5de17033a..3d3e4d7a7 100644 --- a/src/tls/tls_record.h +++ b/src/tls/tls_record.h @@ -30,10 +30,10 @@ class Session_Keys; class BOTAN_DLL Record_Writer { public: - void send(byte type, const byte input[], size_t length); + void send_array(byte type, const byte input[], size_t length); void send(byte type, const std::vector& input) - { send(type, &input[0], input.size()); } + { send_array(type, &input[0], input.size()); } void send_alert(const Alert& alert); diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 426da353b..9aea56b1c 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -224,7 +224,12 @@ Server::Server(std::function output_fn, Handshake_State* Server::new_handshake_state() { - Handshake_State* state = new Server_Handshake_State(new Stream_Handshake_IO(m_writer)); + using namespace std::placeholders; + + Handshake_State* state = new Server_Handshake_State( + new Stream_Handshake_IO(std::bind(&Record_Writer::send, + std::ref(m_writer), _1, _2))); + state->set_expected_next(CLIENT_HELLO); return state; } -- cgit v1.2.3