diff options
-rw-r--r-- | src/tls/tls_channel.cpp | 29 | ||||
-rw-r--r-- | src/tls/tls_channel.h | 19 | ||||
-rw-r--r-- | src/tls/tls_client.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_server.cpp | 6 |
4 files changed, 27 insertions, 29 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 17cf6098e..1be336fc5 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -128,12 +128,20 @@ void Channel::renegotiate(bool force_full_renegotiation) throw std::runtime_error("Cannot renegotiate on inactive connection"); } -void Channel::set_maximum_fragment_size(size_t max_fragment) +size_t Channel::maximum_fragment_size() const { - if(max_fragment == 0) - m_max_fragment = MAX_PLAINTEXT_SIZE; - else - m_max_fragment = clamp(max_fragment, 128, MAX_PLAINTEXT_SIZE); + // should we be caching this value? + + if(auto pending = pending_state()) + if(auto server_hello = pending->server_hello()) + if(size_t frag = server_hello->fragment_size()) + return frag; + + if(auto active = active_state()) + if(size_t frag = active->server_hello()->fragment_size()) + return frag; + + return MAX_PLAINTEXT_SIZE; } void Channel::change_cipher_spec_reader(Connection_Side side) @@ -249,6 +257,8 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) const auto get_cipherstate = [this](u16bit epoch) { return this->read_cipher_state_epoch(epoch).get(); }; + const size_t max_fragment_size = maximum_fragment_size(); + try { while(!is_closed() && buf_size) @@ -287,7 +297,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(rec_type == NO_RECORD) continue; - if(record.size() > m_max_fragment) + if(record.size() > max_fragment_size) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); @@ -446,9 +456,11 @@ void Channel::send_record_array(byte type, const byte input[], size_t length) length -= 1; } + const size_t max_fragment_size = maximum_fragment_size(); + while(length) { - const size_t sending = std::min(length, m_max_fragment); + const size_t sending = std::min(length, max_fragment_size); write_record(cipher_state.get(), type, &input[0], sending); input += sending; @@ -464,9 +476,6 @@ void Channel::send_record(byte record_type, const std::vector<byte>& record) void Channel::write_record(Connection_Cipher_State* cipher_state, byte record_type, const byte input[], size_t length) { - if(length > m_max_fragment) - throw Internal_Error("Record is larger than allowed fragment size"); - BOTAN_ASSERT(m_pending_state || m_active_state, "Some connection state exists"); diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h index a320a5f3c..10ecd296f 100644 --- a/src/tls/tls_channel.h +++ b/src/tls/tls_channel.h @@ -152,8 +152,6 @@ class BOTAN_DLL Channel void activate_session(); - void set_maximum_fragment_size(size_t maximum); - void change_cipher_spec_reader(Connection_Side side); void change_cipher_spec_writer(Connection_Side side); @@ -173,6 +171,8 @@ class BOTAN_DLL Channel bool save_session(const Session& session) const { return m_handshake_fn(session); } private: + size_t maximum_fragment_size() const; + void send_record(byte record_type, const std::vector<byte>& record); void send_record_array(byte type, const byte input[], size_t length); @@ -202,9 +202,9 @@ class BOTAN_DLL Channel /* sequence number state */ std::unique_ptr<Connection_Sequence_Numbers> m_sequence_numbers; - /* I/O buffers */ - std::vector<byte> m_writebuf; - std::vector<byte> m_readbuf; + /* pending and active connection states */ + std::unique_ptr<Handshake_State> m_active_state; + std::unique_ptr<Handshake_State> m_pending_state; /* cipher states for each epoch - epoch 0 is plaintext, thus null cipher state */ std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_write_cipher_states = @@ -212,12 +212,9 @@ class BOTAN_DLL Channel std::map<u16bit, std::shared_ptr<Connection_Cipher_State>> m_read_cipher_states = { { 0, nullptr } }; - /* pending and active connection states */ - std::unique_ptr<Handshake_State> m_active_state; - std::unique_ptr<Handshake_State> m_pending_state; - - /* misc, should be removed? */ - size_t m_max_fragment = MAX_PLAINTEXT_SIZE; + /* I/O buffers */ + std::vector<byte> m_writebuf; + std::vector<byte> m_readbuf; }; } diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp index 0e1d84bed..aae3a65c5 100644 --- a/src/tls/tls_client.cpp +++ b/src/tls/tls_client.cpp @@ -145,8 +145,6 @@ void Client::send_client_hello(Handshake_State& state_base, } secure_renegotiation_check(state.client_hello()); - - set_maximum_fragment_size(state.client_hello()->fragment_size()); } /* diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp index 1189019bc..b91dfc9aa 100644 --- a/src/tls/tls_server.cpp +++ b/src/tls/tls_server.cpp @@ -394,9 +394,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state, secure_renegotiation_check(state.server_hello()); - if(session_info.fragment_size()) - set_maximum_fragment_size(session_info.fragment_size()); - state.compute_session_keys(session_info.master_secret()); if(!save_session(session_info)) @@ -493,9 +490,6 @@ void Server::process_handshake_msg(const Handshake_State* active_state, secure_renegotiation_check(state.server_hello()); - if(state.client_hello()->fragment_size()) - set_maximum_fragment_size(state.client_hello()->fragment_size()); - const std::string sig_algo = state.ciphersuite().sig_algo(); const std::string kex_algo = state.ciphersuite().kex_algo(); |