aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlloyd <[email protected]>2012-09-09 20:42:39 +0000
committerlloyd <[email protected]>2012-09-09 20:42:39 +0000
commit04559cbc1f8969c623ba9f601ba7933f77cc9a97 (patch)
tree45b448baa943e879f35f1e5c02e7e3fd279345ad
parent9bc3561ef578dad00d8af8541e2003962ca1ae45 (diff)
Create the IO in Channel and then pass it down to new_handshake_state
as the logic is the same for both cases.
-rw-r--r--src/tls/tls_channel.cpp31
-rw-r--r--src/tls/tls_channel.h11
-rw-r--r--src/tls/tls_client.cpp13
-rw-r--r--src/tls/tls_client.h2
-rw-r--r--src/tls/tls_record.cpp3
-rw-r--r--src/tls/tls_server.cpp14
-rw-r--r--src/tls/tls_server.h2
7 files changed, 39 insertions, 37 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index eea9edb74..d95e8bbf7 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -46,12 +46,25 @@ std::vector<X509_Certificate> Channel::peer_cert_chain() const
return get_peer_cert_chain(*m_active_state);
}
-Handshake_State& Channel::create_handshake_state()
+Handshake_State& Channel::create_handshake_state(Protocol_Version version)
{
if(m_pending_state)
throw Internal_Error("create_handshake_state called during handshake");
- m_pending_state.reset(new_handshake_state());
+ const size_t dtls_mtu = 1400;
+
+ std::unique_ptr<Handshake_IO> handshake_io;
+
+ auto send_rec = std::bind(&Channel::send_record, this,
+ std::placeholders::_1,
+ std::placeholders::_2);
+
+ if(version.is_datagram_protocol())
+ handshake_io.reset(new Datagram_Handshake_IO(send_rec, dtls_mtu));
+ else
+ handshake_io.reset(new Stream_Handshake_IO(send_rec));
+
+ m_pending_state.reset(new_handshake_state(handshake_io.release()));
return *m_pending_state.get();
}
@@ -61,9 +74,11 @@ void Channel::renegotiate(bool force_full_renegotiation)
if(m_pending_state) // currently in handshake?
return;
- m_pending_state.reset(new_handshake_state());
+ if(!m_active_state)
+ throw std::runtime_error("Cannot renegotiate on inactive connection");
- initiate_handshake(*m_pending_state.get(), force_full_renegotiation);
+ initiate_handshake(create_handshake_state(m_active_state->version()),
+ force_full_renegotiation);
}
void Channel::set_protocol_version(Protocol_Version version)
@@ -196,7 +211,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size)
if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC)
{
if(!m_pending_state)
- m_pending_state.reset(new_handshake_state());
+ create_handshake_state(m_current_version); // fixme
m_pending_state->handshake_io().add_input(
rec_type, &record[0], record.size(), record_number);
@@ -324,7 +339,7 @@ void Channel::heartbeat(const byte payload[], size_t payload_size)
}
}
-void Channel::send_record(byte type, const byte input[], size_t length)
+void Channel::send_record_array(byte type, const byte input[], size_t length)
{
if(length == 0)
return;
@@ -361,7 +376,7 @@ void Channel::send_record(byte type, const byte input[], size_t length)
void Channel::send_record(byte record_type, const std::vector<byte>& record)
{
- send_record(record_type, &record[0], record.size());
+ send_record_array(record_type, &record[0], record.size());
}
void Channel::write_record(byte record_type, const byte input[], size_t length)
@@ -396,7 +411,7 @@ void Channel::send(const byte buf[], size_t buf_size)
if(!is_active())
throw std::runtime_error("Data cannot be sent on inactive TLS connection");
- send_record(APPLICATION_DATA, buf, buf_size);
+ send_record_array(APPLICATION_DATA, buf, buf_size);
}
void Channel::send(const std::string& string)
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index d5ae4b1cd..87062523c 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -21,6 +21,7 @@ namespace Botan {
namespace TLS {
+class Handshake_IO;
class Handshake_State;
/**
@@ -120,9 +121,9 @@ class BOTAN_DLL Channel
virtual std::vector<X509_Certificate>
get_peer_cert_chain(const Handshake_State& state) const = 0;
- virtual Handshake_State* new_handshake_state() = 0;
+ virtual Handshake_State* new_handshake_state(Handshake_IO* io) = 0;
- Handshake_State& create_handshake_state();
+ Handshake_State& create_handshake_state(Protocol_Version version);
/**
* Send a TLS alert message. If the alert is fatal, the internal
@@ -144,8 +145,6 @@ class BOTAN_DLL Channel
void change_cipher_spec_writer(Connection_Side side);
- void send_record(byte record_type, const std::vector<byte>& record);
-
/* secure renegotiation handling */
void secure_renegotiation_check(const class Client_Hello* client_hello);
@@ -163,7 +162,9 @@ class BOTAN_DLL Channel
bool save_session(const Session& session) const { return m_handshake_fn(session); }
private:
- void send_record(byte type, const byte input[], size_t length);
+ void send_record(byte record_type, const std::vector<byte>& record);
+
+ void send_record_array(byte type, const byte input[], size_t length);
void write_record(byte type, const byte input[], size_t length);
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index b40c86f5c..d63d05cab 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -66,21 +66,14 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
{
const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname);
- Handshake_State& state = create_handshake_state();
const Protocol_Version version = m_policy.pref_version();
+ Handshake_State& state = create_handshake_state(version);
initiate_handshake(state, false, version, srp_identifier, next_protocol);
}
-Handshake_State* Client::new_handshake_state()
+Handshake_State* Client::new_handshake_state(Handshake_IO* io)
{
- using namespace std::placeholders;
-
- return new Client_Handshake_State(
- new Stream_Handshake_IO(
- [this](byte type, const std::vector<byte>& rec)
- { this->send_record(type, rec); }
- )
- );
+ return new Client_Handshake_State(io);
}
std::vector<X509_Certificate>
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index 3edcfa495..a2b32cadd 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -84,7 +84,7 @@ class BOTAN_DLL Client : public Channel
Handshake_Type type,
const std::vector<byte>& contents) override;
- Handshake_State* new_handshake_state() override;
+ Handshake_State* new_handshake_state(Handshake_IO* io) override;
const Policy& m_policy;
Credentials_Manager& m_creds;
diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp
index 9a439b86d..33903f1df 100644
--- a/src/tls/tls_record.cpp
+++ b/src/tls/tls_record.cpp
@@ -274,6 +274,7 @@ size_t read_record(std::vector<byte>& readbuf,
byte& msg_type,
std::vector<byte>& msg,
u64bit msg_sequence,
+ Protocol_Version& record_version,
Connection_Cipher_State* cipherstate)
{
consumed = 0;
@@ -335,7 +336,7 @@ size_t read_record(std::vector<byte>& readbuf,
" from counterparty");
}
- Protocol_Version record_version(readbuf[1], readbuf[2]);
+ record_version = Protocol_Version(readbuf[1], readbuf[2]);
if(record_version.is_datagram_protocol() && readbuf_pos < DTLS_HEADER_SIZE)
{
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index 313b23a0a..81cb4940c 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -221,19 +221,11 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
{
}
-Handshake_State* Server::new_handshake_state()
+Handshake_State* Server::new_handshake_state(Handshake_IO* io)
{
- using namespace std::placeholders;
-
- Handshake_State* state = new Server_Handshake_State(
- new Stream_Handshake_IO(
- [this](byte type, const std::vector<byte>& rec)
- { this->send_record(type, rec); }
- )
- );
-
+ std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io));
state->set_expected_next(CLIENT_HELLO);
- return state;
+ return state.release();
}
std::vector<X509_Certificate>
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 94127e0d0..761ff6028 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -59,7 +59,7 @@ class BOTAN_DLL Server : public Channel
Handshake_Type type,
const std::vector<byte>& contents) override;
- Handshake_State* new_handshake_state() override;
+ Handshake_State* new_handshake_state(Handshake_IO* io) override;
const Policy& m_policy;
Credentials_Manager& m_creds;