diff options
author | lloyd <[email protected]> | 2012-09-09 20:42:39 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-09-09 20:42:39 +0000 |
commit | 04559cbc1f8969c623ba9f601ba7933f77cc9a97 (patch) | |
tree | 45b448baa943e879f35f1e5c02e7e3fd279345ad | |
parent | 9bc3561ef578dad00d8af8541e2003962ca1ae45 (diff) |
Create the IO in Channel and then pass it down to new_handshake_state
as the logic is the same for both cases.
-rw-r--r-- | src/tls/tls_channel.cpp | 31 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 11 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 13 | ||||
-rw-r--r-- | src/tls/tls_client.h | 2 | ||||
-rw-r--r-- | src/tls/tls_record.cpp | 3 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 14 | ||||
-rw-r--r-- | src/tls/tls_server.h | 2 |
7 files changed, 39 insertions, 37 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index eea9edb74..d95e8bbf7 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -46,12 +46,25 @@ std::vector<X509_Certificate> Channel::peer_cert_chain() const return get_peer_cert_chain(*m_active_state); } -Handshake_State& Channel::create_handshake_state() +Handshake_State& Channel::create_handshake_state(Protocol_Version version) { if(m_pending_state) throw Internal_Error("create_handshake_state called during handshake"); - m_pending_state.reset(new_handshake_state()); + const size_t dtls_mtu = 1400; + + std::unique_ptr<Handshake_IO> handshake_io; + + auto send_rec = std::bind(&Channel::send_record, this, + std::placeholders::_1, + std::placeholders::_2); + + if(version.is_datagram_protocol()) + handshake_io.reset(new Datagram_Handshake_IO(send_rec, dtls_mtu)); + else + handshake_io.reset(new Stream_Handshake_IO(send_rec)); + + m_pending_state.reset(new_handshake_state(handshake_io.release())); return *m_pending_state.get(); } @@ -61,9 +74,11 @@ void Channel::renegotiate(bool force_full_renegotiation) if(m_pending_state) // currently in handshake? return; - m_pending_state.reset(new_handshake_state()); + if(!m_active_state) + throw std::runtime_error("Cannot renegotiate on inactive connection"); - initiate_handshake(*m_pending_state.get(), force_full_renegotiation); + initiate_handshake(create_handshake_state(m_active_state->version()), + force_full_renegotiation); } void Channel::set_protocol_version(Protocol_Version version) @@ -196,7 +211,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) - m_pending_state.reset(new_handshake_state()); + create_handshake_state(m_current_version); // fixme m_pending_state->handshake_io().add_input( rec_type, &record[0], record.size(), record_number); @@ -324,7 +339,7 @@ void Channel::heartbeat(const byte payload[], size_t payload_size) } } -void Channel::send_record(byte type, const byte input[], size_t length) +void Channel::send_record_array(byte type, const byte input[], size_t length) { if(length == 0) return; @@ -361,7 +376,7 @@ void Channel::send_record(byte type, const byte input[], size_t length) void Channel::send_record(byte record_type, const std::vector<byte>& record) { - send_record(record_type, &record[0], record.size()); + send_record_array(record_type, &record[0], record.size()); } void Channel::write_record(byte record_type, const byte input[], size_t length) @@ -396,7 +411,7 @@ void Channel::send(const byte buf[], size_t buf_size) if(!is_active()) throw std::runtime_error("Data cannot be sent on inactive TLS connection"); - send_record(APPLICATION_DATA, buf, buf_size); + send_record_array(APPLICATION_DATA, buf, buf_size); } void Channel::send(const std::string& string) diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index d5ae4b1cd..87062523c 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -21,6 +21,7 @@ namespace Botan { namespace TLS { +class Handshake_IO; class Handshake_State; /** @@ -120,9 +121,9 @@ class BOTAN_DLL Channel virtual std::vector<X509_Certificate> get_peer_cert_chain(const Handshake_State& state) const = 0; - virtual Handshake_State* new_handshake_state() = 0; + virtual Handshake_State* new_handshake_state(Handshake_IO* io) = 0; - Handshake_State& create_handshake_state(); + Handshake_State& create_handshake_state(Protocol_Version version); /** * Send a TLS alert message. If the alert is fatal, the internal @@ -144,8 +145,6 @@ class BOTAN_DLL Channel void change_cipher_spec_writer(Connection_Side side); - void send_record(byte record_type, const std::vector<byte>& record); - /* secure renegotiation handling */ void secure_renegotiation_check(const class Client_Hello* client_hello); @@ -163,7 +162,9 @@ class BOTAN_DLL Channel bool save_session(const Session& session) const { return m_handshake_fn(session); } private: - void send_record(byte type, const byte input[], size_t length); + void send_record(byte record_type, const std::vector<byte>& record); + + void send_record_array(byte type, const byte input[], size_t length); void write_record(byte type, const byte input[], size_t length); diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index b40c86f5c..d63d05cab 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -66,21 +66,14 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn, { const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_hostname); - Handshake_State& state = create_handshake_state(); const Protocol_Version version = m_policy.pref_version(); + Handshake_State& state = create_handshake_state(version); initiate_handshake(state, false, version, srp_identifier, next_protocol); } -Handshake_State* Client::new_handshake_state() +Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - using namespace std::placeholders; - - return new Client_Handshake_State( - new Stream_Handshake_IO( - [this](byte type, const std::vector<byte>& rec) - { this->send_record(type, rec); } - ) - ); + return new Client_Handshake_State(io); } std::vector<X509_Certificate> diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h index 3edcfa495..a2b32cadd 100644 --- a/src/tls/tls_client.h +++ b/src/tls/tls_client.h @@ -84,7 +84,7 @@ class BOTAN_DLL Client : public Channel Handshake_Type type, const std::vector<byte>& contents) override; - Handshake_State* new_handshake_state() override; + Handshake_State* new_handshake_state(Handshake_IO* io) override; const Policy& m_policy; Credentials_Manager& m_creds; diff --git a/src/tls/tls_record.cpp b/src/tls/tls_record.cpp index 9a439b86d..33903f1df 100644 --- a/src/tls/tls_record.cpp +++ b/src/tls/tls_record.cpp @@ -274,6 +274,7 @@ size_t read_record(std::vector<byte>& readbuf, byte& msg_type, std::vector<byte>& msg, u64bit msg_sequence, + Protocol_Version& record_version, Connection_Cipher_State* cipherstate) { consumed = 0; @@ -335,7 +336,7 @@ size_t read_record(std::vector<byte>& readbuf, " from counterparty"); } - Protocol_Version record_version(readbuf[1], readbuf[2]); + record_version = Protocol_Version(readbuf[1], readbuf[2]); if(record_version.is_datagram_protocol() && readbuf_pos < DTLS_HEADER_SIZE) { diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 313b23a0a..81cb4940c 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -221,19 +221,11 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn, { } -Handshake_State* Server::new_handshake_state() +Handshake_State* Server::new_handshake_state(Handshake_IO* io) { - using namespace std::placeholders; - - Handshake_State* state = new Server_Handshake_State( - new Stream_Handshake_IO( - [this](byte type, const std::vector<byte>& rec) - { this->send_record(type, rec); } - ) - ); - + std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io)); state->set_expected_next(CLIENT_HELLO); - return state; + return state.release(); } std::vector<X509_Certificate> diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h index 94127e0d0..761ff6028 100644 --- a/src/tls/tls_server.h +++ b/src/tls/tls_server.h @@ -59,7 +59,7 @@ class BOTAN_DLL Server : public Channel Handshake_Type type, const std::vector<byte>& contents) override; - Handshake_State* new_handshake_state() override; + Handshake_State* new_handshake_state(Handshake_IO* io) override; const Policy& m_policy; Credentials_Manager& m_creds; |