aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-04 20:43:35 +0000
committerlloyd <[email protected]>2012-09-04 20:43:35 +0000
commite2d844d3dbea46a54a86a03126772956b715b9c9 (patch)
tree59c48e5346f506996733eae28d275ce02ff53de3
parent7f888f4bbcf8d11ff5916b1056ab181962da5b90 (diff)
Use a std::function so handshake_io only has access Record_Writer's
send function.
-rw-r--r--src/tls/rec_wri.cpp4
-rw-r--r--src/tls/tls_channel.cpp2
-rw-r--r--src/tls/tls_client.cpp6
-rw-r--r--src/tls/tls_handshake_io.cpp22
-rw-r--r--src/tls/tls_handshake_io.h13
-rw-r--r--src/tls/tls_record.h4
-rw-r--r--src/tls/tls_server.cpp7
7 files changed, 35 insertions, 23 deletions
diff --git a/src/tls/rec_wri.cpp b/src/tls/rec_wri.cpp
index fdecaa919..4eff52f78 100644
--- a/src/tls/rec_wri.cpp
+++ b/src/tls/rec_wri.cpp
@@ -148,7 +148,7 @@ void Record_Writer::change_cipher_spec(Connection_Side side,
/*
* Send one or more records to the other side
*/
-void Record_Writer::send(byte type, const byte input[], size_t length)
+void Record_Writer::send_array(byte type, const byte input[], size_t length)
{
if(length == 0)
return;
@@ -288,7 +288,7 @@ void Record_Writer::send_alert(const Alert& alert)
const byte alert_bits[2] = { static_cast<byte>(alert.is_fatal() ? 2 : 1),
static_cast<byte>(alert.type()) };
- send(ALERT, alert_bits, sizeof(alert_bits));
+ send_array(ALERT, alert_bits, sizeof(alert_bits));
}
}
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 3e5bdbabd..63a496358 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -207,7 +207,7 @@ void Channel::send(const byte buf[], size_t buf_size)
if(!is_active())
throw std::runtime_error("Data cannot be sent on inactive TLS connection");
- m_writer.send(APPLICATION_DATA, buf, buf_size);
+ m_writer.send_array(APPLICATION_DATA, buf, buf_size);
}
void Channel::send(const std::string& string)
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index c6c7a1765..7fa1ad8bc 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -64,7 +64,11 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
Handshake_State* Client::new_handshake_state()
{
- return new Client_Handshake_State(new Stream_Handshake_IO(m_writer));
+ using namespace std::placeholders;
+
+ return new Client_Handshake_State(
+ new Stream_Handshake_IO(std::bind(&Record_Writer::send,
+ std::ref(m_writer), _1, _2)));
}
/*
diff --git a/src/tls/tls_handshake_io.cpp b/src/tls/tls_handshake_io.cpp
index cc2e4f7d1..b1f6a0eb5 100644
--- a/src/tls/tls_handshake_io.cpp
+++ b/src/tls/tls_handshake_io.cpp
@@ -106,12 +106,12 @@ std::vector<byte> Stream_Handshake_IO::send(const Handshake_Message& msg)
if(msg.type() == HANDSHAKE_CCS)
{
- m_writer.send(CHANGE_CIPHER_SPEC, msg_bits);
+ m_writer(CHANGE_CIPHER_SPEC, msg_bits);
return std::vector<byte>(); // not included in handshake hashes
}
const std::vector<byte> buf = format(msg_bits, msg.type());
- m_writer.send(HANDSHAKE, buf);
+ m_writer(HANDSHAKE, buf);
return buf;
}
@@ -323,7 +323,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg)
if(msg.type() == HANDSHAKE_CCS)
{
- m_writer.send(CHANGE_CIPHER_SPEC, msg_bits);
+ m_writer(CHANGE_CIPHER_SPEC, msg_bits);
return std::vector<byte>(); // not included in handshake hashes
}
@@ -333,7 +333,7 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg)
m_mtu = 64;
if(no_fragment.size() + DTLS_HEADER_SIZE <= m_mtu)
- m_writer.send(HANDSHAKE, no_fragment);
+ m_writer(HANDSHAKE, no_fragment);
else
{
const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
@@ -348,13 +348,13 @@ Datagram_Handshake_IO::send(const Handshake_Message& msg)
std::min<size_t>(msg_bits.size() - frag_offset,
parts_size);
- m_writer.send(HANDSHAKE,
- format_fragment(&msg_bits[frag_offset],
- frag_len,
- frag_offset,
- msg_bits.size(),
- msg.type(),
- m_out_message_seq));
+ m_writer(HANDSHAKE,
+ format_fragment(&msg_bits[frag_offset],
+ frag_len,
+ frag_offset,
+ msg_bits.size(),
+ msg.type(),
+ m_out_message_seq));
frag_offset += frag_len;
}
diff --git a/src/tls/tls_handshake_io.h b/src/tls/tls_handshake_io.h
index 6463e5638..9ff580cf8 100644
--- a/src/tls/tls_handshake_io.h
+++ b/src/tls/tls_handshake_io.h
@@ -11,6 +11,7 @@
#include <botan/tls_magic.h>
#include <botan/tls_version.h>
#include <botan/loadstor.h>
+#include <functional>
#include <vector>
#include <deque>
#include <map>
@@ -21,9 +22,10 @@ namespace Botan {
namespace TLS {
-class Record_Writer;
class Handshake_Message;
+typedef std::function<void (byte, const std::vector<byte>&)> handshake_write_fn;
+
/**
* Handshake IO Interface
*/
@@ -64,7 +66,8 @@ class Handshake_IO
class Stream_Handshake_IO : public Handshake_IO
{
public:
- Stream_Handshake_IO(Record_Writer& writer) : m_writer(writer) {}
+ Stream_Handshake_IO(handshake_write_fn writer) :
+ m_writer(writer) {}
Protocol_Version initial_record_version() const override;
@@ -83,7 +86,7 @@ class Stream_Handshake_IO : public Handshake_IO
get_next_record(bool expecting_ccs) override;
private:
std::deque<byte> m_queue;
- Record_Writer& m_writer;
+ handshake_write_fn m_writer;
};
/**
@@ -92,7 +95,7 @@ class Stream_Handshake_IO : public Handshake_IO
class Datagram_Handshake_IO : public Handshake_IO
{
public:
- Datagram_Handshake_IO(Record_Writer& writer, u16bit mtu) :
+ Datagram_Handshake_IO(handshake_write_fn writer, u16bit mtu) :
m_flights(1), m_mtu(mtu), m_writer(writer) {}
Protocol_Version initial_record_version() const override;
@@ -155,7 +158,7 @@ class Datagram_Handshake_IO : public Handshake_IO
u16bit m_mtu = 0;
u16bit m_in_message_seq = 0;
u16bit m_out_message_seq = 0;
- Record_Writer& m_writer;
+ handshake_write_fn m_writer;
};
}
diff --git a/src/tls/tls_record.h b/src/tls/tls_record.h
index 5de17033a..3d3e4d7a7 100644
--- a/src/tls/tls_record.h
+++ b/src/tls/tls_record.h
@@ -30,10 +30,10 @@ class Session_Keys;
class BOTAN_DLL Record_Writer
{
public:
- void send(byte type, const byte input[], size_t length);
+ void send_array(byte type, const byte input[], size_t length);
void send(byte type, const std::vector<byte>& input)
- { send(type, &input[0], input.size()); }
+ { send_array(type, &input[0], input.size()); }
void send_alert(const Alert& alert);
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 426da353b..9aea56b1c 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -224,7 +224,12 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
Handshake_State* Server::new_handshake_state()
{
- Handshake_State* state = new Server_Handshake_State(new Stream_Handshake_IO(m_writer));
+ using namespace std::placeholders;
+
+ Handshake_State* state = new Server_Handshake_State(
+ new Stream_Handshake_IO(std::bind(&Record_Writer::send,
+ std::ref(m_writer), _1, _2)));
+
state->set_expected_next(CLIENT_HELLO);
return state;
}