diff options
author | Jack Lloyd <[email protected]> | 2016-08-31 10:31:58 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2016-08-31 10:31:58 -0400 |
commit | 148262088c117ba849efc42432f2d2510ce25349 (patch) | |
tree | 663bf183b9ce54f08530f319fd4b491473514f61 /src | |
parent | 5e946f93e8e751d2104f58583d4f209ca631aff1 (diff) | |
parent | ee60a29088fc6dd712c1651af1e7f56a26f40d63 (diff) |
Merge GH #567/GH #457 TLS refactoring and Callbacks interface
Diffstat (limited to 'src')
32 files changed, 1738 insertions, 1063 deletions
diff --git a/src/cli/tls_client.cpp b/src/cli/tls_client.cpp index 6af2f56f8..082daf4ac 100644 --- a/src/cli/tls_client.cpp +++ b/src/cli/tls_client.cpp @@ -1,5 +1,6 @@ /* * (C) 2014,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -35,7 +36,7 @@ namespace Botan_CLI { -class TLS_Client final : public Command +class TLS_Client final : public Command, public Botan::TLS::Callbacks { public: TLS_Client() : Command("tls_client host --port=443 --print-certs --policy= " @@ -98,15 +99,10 @@ class TLS_Client final : public Command const std::vector<std::string> protocols_to_offer = Botan::split_on("next-protocols", ','); - int sockfd = connect_to_host(host, port, use_tcp); + m_sockfd = connect_to_host(host, port, use_tcp); using namespace std::placeholders; - auto socket_write = - use_tcp ? - std::bind(stream_socket_write, sockfd, _1, _2) : - std::bind(dgram_socket_write, sockfd, _1, _2); - auto version = policy->latest_supported_version(!use_tcp); if(flag_set("tls1.0")) @@ -118,10 +114,7 @@ class TLS_Client final : public Command version = Botan::TLS::Protocol_Version::TLS_V11; } - Botan::TLS::Client client(socket_write, - std::bind(&TLS_Client::process_data, this, _1, _2), - std::bind(&TLS_Client::alert_received, this, _1, _2, _3), - std::bind(&TLS_Client::handshake_complete, this, _1), + Botan::TLS::Client client(*this, *session_mgr, creds, *policy, @@ -136,7 +129,7 @@ class TLS_Client final : public Command { fd_set readfds; FD_ZERO(&readfds); - FD_SET(sockfd, &readfds); + FD_SET(m_sockfd, &readfds); if(client.is_active()) { @@ -152,13 +145,13 @@ class TLS_Client final : public Command struct timeval timeout = { 1, 0 }; - ::select(sockfd + 1, &readfds, nullptr, nullptr, &timeout); + ::select(m_sockfd + 1, &readfds, nullptr, nullptr, &timeout); - if(FD_ISSET(sockfd, &readfds)) + if(FD_ISSET(m_sockfd, &readfds)) { uint8_t buf[4*1024] = { 0 }; - ssize_t got = ::read(sockfd, buf, sizeof(buf)); + ssize_t got = ::read(m_sockfd, buf, sizeof(buf)); if(got == 0) { @@ -216,7 +209,7 @@ class TLS_Client final : public Command } } - ::close(sockfd); + ::close(m_sockfd); } private: @@ -256,7 +249,7 @@ class TLS_Client final : public Command return fd; } - bool handshake_complete(const Botan::TLS::Session& session) + bool tls_session_established(const Botan::TLS::Session& session) override { output() << "Handshake complete, " << session.version().to_string() << " using " << session.ciphersuite().to_string() << "\n"; @@ -290,13 +283,13 @@ class TLS_Client final : public Command throw CLI_Error("Socket write failed errno=" + std::to_string(errno)); } - static void stream_socket_write(int sockfd, const uint8_t buf[], size_t length) + void tls_emit_data(const uint8_t buf[], size_t length) override { size_t offset = 0; while(length) { - ssize_t sent = ::send(sockfd, (const char*)buf + offset, + ssize_t sent = ::send(m_sockfd, (const char*)buf + offset, length, MSG_NOSIGNAL); if(sent == -1) @@ -312,16 +305,19 @@ class TLS_Client final : public Command } } - void alert_received(Botan::TLS::Alert alert, const uint8_t [], size_t ) + void tls_alert(Botan::TLS::Alert alert) { output() << "Alert: " << alert.type_string() << "\n"; } - void process_data(const uint8_t buf[], size_t buf_size) + void tls_record_received(uint64_t seq_no, const uint8_t buf[], size_t buf_size) { for(size_t i = 0; i != buf_size; ++i) output() << buf[i]; } + + private: + int m_sockfd; }; BOTAN_REGISTER_COMMAND("tls_client", TLS_Client); diff --git a/src/cli/tls_proxy.cpp b/src/cli/tls_proxy.cpp index 2929e473d..5140654de 100644 --- a/src/cli/tls_proxy.cpp +++ b/src/cli/tls_proxy.cpp @@ -1,6 +1,7 @@ /* * TLS Server Proxy * (C) 2014,2015 Jack Lloyd +* (C) 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -59,7 +60,7 @@ void log_text_message(const char* where, const uint8_t buf[], size_t buf_len) //std::cout << where << ' ' << std::string(c, c + buf_len) << std::endl; } -class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_session> +class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_session>, public Botan::TLS::Callbacks { public: enum { readbuf_size = 4 * 1024 }; @@ -111,10 +112,7 @@ class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_sessio m_server_endpoints(endpoints), m_client_socket(io), m_server_socket(io), - m_tls(boost::bind(&tls_proxy_session::tls_proxy_write_to_client, this, _1, _2), - boost::bind(&tls_proxy_session::tls_client_write_to_proxy, this, _1, _2), - boost::bind(&tls_proxy_session::tls_alert_cb, this, _1, _2, _3), - boost::bind(&tls_proxy_session::tls_handshake_complete, this, _1), + m_tls(*this, session_manager, credentials, policy, @@ -167,7 +165,7 @@ class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_sessio { m_client_socket.close(); } - tls_proxy_write_to_client(nullptr, 0); // initiate another write if needed + tls_emit_data(nullptr, 0); // initiate another write if needed } void handle_server_write_completion(const boost::system::error_code& error) @@ -183,13 +181,13 @@ class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_sessio proxy_write_to_server(nullptr, 0); // initiate another write if needed } - void tls_client_write_to_proxy(const uint8_t buf[], size_t buf_len) + void tls_record_received(uint64_t /*rec_no*/, const uint8_t buf[], size_t buf_len) override { // Immediately bounce message to server proxy_write_to_server(buf, buf_len); } - void tls_proxy_write_to_client(const uint8_t buf[], size_t buf_len) + void tls_emit_data(const uint8_t buf[], size_t buf_len) override { if(buf_len > 0) m_p2c_pending.insert(m_p2c_pending.end(), buf, buf + buf_len); @@ -268,7 +266,7 @@ class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_sessio boost::asio::placeholders::bytes_transferred))); } - bool tls_handshake_complete(const Botan::TLS::Session& session) + bool tls_session_established(const Botan::TLS::Session& session) override { //std::cout << "Handshake from client complete" << std::endl; @@ -292,7 +290,7 @@ class tls_proxy_session : public boost::enable_shared_from_this<tls_proxy_sessio return true; } - void tls_alert_cb(Botan::TLS::Alert alert, const uint8_t[], size_t) + void tls_alert(Botan::TLS::Alert alert) override { if(alert.type() == Botan::TLS::Alert::CLOSE_NOTIFY) { diff --git a/src/lib/cert/x509/x509_ca.h b/src/lib/cert/x509/x509_ca.h index 6ea51cd06..ba3724f5e 100644 --- a/src/lib/cert/x509/x509_ca.h +++ b/src/lib/cert/x509/x509_ca.h @@ -22,7 +22,6 @@ namespace Botan { class BOTAN_DLL X509_CA { public: - /** * Sign a PKCS#10 Request. * @param req the request to sign diff --git a/src/lib/math/bigint/big_ops2.cpp b/src/lib/math/bigint/big_ops2.cpp index 9a3408247..6e234f036 100644 --- a/src/lib/math/bigint/big_ops2.cpp +++ b/src/lib/math/bigint/big_ops2.cpp @@ -1,6 +1,7 @@ /* * BigInt Assignment Operators * (C) 1999-2007 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -118,10 +119,7 @@ BigInt& BigInt::operator*=(const BigInt& y) secure_vector<word> z(data(), data() + x_sw); secure_vector<word> workspace(size()); - - bigint_mul(mutable_data(), size(), workspace.data(), - z.data(), z.size(), x_sw, - y.data(), y.size(), y_sw); + bigint_mul(*this, BigInt(*this), y, workspace.data()); } return (*this); diff --git a/src/lib/math/bigint/big_ops3.cpp b/src/lib/math/bigint/big_ops3.cpp index 6cf837020..24927b4fc 100644 --- a/src/lib/math/bigint/big_ops3.cpp +++ b/src/lib/math/bigint/big_ops3.cpp @@ -1,6 +1,7 @@ /* * BigInt Binary Operators * (C) 1999-2007 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -93,9 +94,7 @@ BigInt operator*(const BigInt& x, const BigInt& y) else if(x_sw && y_sw) { secure_vector<word> workspace(z.size()); - bigint_mul(z.mutable_data(), z.size(), workspace.data(), - x.data(), x.size(), x_sw, - y.data(), y.size(), y_sw); + bigint_mul(z, x, y, workspace.data()); } if(x_sw && y_sw && x.sign() != y.sign()) diff --git a/src/lib/math/ec_gfp/curve_gfp.cpp b/src/lib/math/ec_gfp/curve_gfp.cpp index 9bf2191c6..96593e601 100644 --- a/src/lib/math/ec_gfp/curve_gfp.cpp +++ b/src/lib/math/ec_gfp/curve_gfp.cpp @@ -1,6 +1,7 @@ /* * Elliptic curves over GF(p) Montgomery Representation * (C) 2014,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -80,20 +81,14 @@ void CurveGFp_Montgomery::curve_mul(BigInt& z, const BigInt& x, const BigInt& y, return; } - const size_t x_sw = x.sig_words(); - const size_t y_sw = y.sig_words(); - const size_t output_size = 2*m_p_words + 1; ws.resize(2*(m_p_words+2)); z.grow_to(output_size); z.clear(); - bigint_monty_mul(z.mutable_data(), output_size, - x.data(), x.size(), x_sw, - y.data(), y.size(), y_sw, - m_p.data(), m_p_words, m_p_dash, - ws.data()); + bigint_monty_mul(z, x, y, m_p.data(), m_p_words, m_p_dash, ws.data()); + } void CurveGFp_Montgomery::curve_sqr(BigInt& z, const BigInt& x, @@ -115,9 +110,7 @@ void CurveGFp_Montgomery::curve_sqr(BigInt& z, const BigInt& x, z.grow_to(output_size); z.clear(); - bigint_monty_sqr(z.mutable_data(), output_size, - x.data(), x.size(), x_sw, - m_p.data(), m_p_words, m_p_dash, + bigint_monty_sqr(z, x, m_p.data(), m_p_words, m_p_dash, ws.data()); } @@ -174,9 +167,7 @@ void CurveGFp_NIST::curve_mul(BigInt& z, const BigInt& x, const BigInt& y, z.grow_to(output_size); z.clear(); - bigint_mul(z.mutable_data(), output_size, ws.data(), - x.data(), x.size(), x.sig_words(), - y.data(), y.size(), y.sig_words()); + bigint_mul(z, x, y, ws.data()); this->redc(z, ws); } diff --git a/src/lib/math/mp/mp_core.h b/src/lib/math/mp/mp_core.h index 73f13742c..c4ce005ba 100644 --- a/src/lib/math/mp/mp_core.h +++ b/src/lib/math/mp/mp_core.h @@ -2,6 +2,7 @@ * MPI Algorithms * (C) 1999-2010 Jack Lloyd * 2006 Luca Piccarreta +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -9,6 +10,7 @@ #ifndef BOTAN_MP_CORE_OPS_H__ #define BOTAN_MP_CORE_OPS_H__ +#include <botan/bigint.h> #include <botan/mp_types.h> namespace Botan { @@ -134,17 +136,14 @@ void bigint_monty_redc(word z[], /* * Montgomery Multiplication */ -void bigint_monty_mul(word z[], size_t z_size, - const word x[], size_t x_size, size_t x_sw, - const word y[], size_t y_size, size_t y_sw, +void bigint_monty_mul(BigInt& z, const BigInt& x, const BigInt& y, const word p[], size_t p_size, word p_dash, word workspace[]); /* * Montgomery Squaring */ -void bigint_monty_sqr(word z[], size_t z_size, - const word x[], size_t x_size, size_t x_sw, +void bigint_monty_sqr(BigInt& z, const BigInt& x, const word p[], size_t p_size, word p_dash, word workspace[]); @@ -182,9 +181,7 @@ void bigint_comba_sqr16(word out[32], const word in[16]); /* * High Level Multiplication/Squaring Interfaces */ -void bigint_mul(word z[], size_t z_size, word workspace[], - const word x[], size_t x_size, size_t x_sw, - const word y[], size_t y_size, size_t y_sw); +void bigint_mul(BigInt& z, const BigInt& x, const BigInt& y, word workspace[]); void bigint_sqr(word z[], size_t z_size, word workspace[], const word x[], size_t x_size, size_t x_sw); diff --git a/src/lib/math/mp/mp_karat.cpp b/src/lib/math/mp/mp_karat.cpp index 9135fdd6a..7a763e2a9 100644 --- a/src/lib/math/mp/mp_karat.cpp +++ b/src/lib/math/mp/mp_karat.cpp @@ -1,6 +1,7 @@ /* * Multiplication and Squaring * (C) 1999-2010 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -252,60 +253,55 @@ size_t karatsuba_size(size_t z_size, size_t x_size, size_t x_sw) /* * Multiplication Algorithm Dispatcher */ -void bigint_mul(word z[], size_t z_size, word workspace[], - const word x[], size_t x_size, size_t x_sw, - const word y[], size_t y_size, size_t y_sw) +void bigint_mul(BigInt& z, const BigInt& x, const BigInt& y, word workspace[]) { - // checking that z_size >= x_sw + y_sw without overflow - BOTAN_ASSERT(z_size > x_sw && z_size > y_sw && z_size-x_sw >= y_sw, "Output size is sufficient"); - - if(x_sw == 1) + if(x.sig_words() == 1) { - bigint_linmul3(z, y, y_sw, x[0]); + bigint_linmul3(z.mutable_data(), y.data(), y.sig_words(), x.data()[0]); } - else if(y_sw == 1) + else if(y.sig_words() == 1) { - bigint_linmul3(z, x, x_sw, y[0]); + bigint_linmul3(z.mutable_data(), x.data(), x.sig_words(), y.data()[0]); } - else if(x_sw <= 4 && x_size >= 4 && - y_sw <= 4 && y_size >= 4 && z_size >= 8) + else if(x.sig_words() <= 4 && x.size() >= 4 && + y.sig_words() <= 4 && y.size() >= 4 && z.size() >= 8) { - bigint_comba_mul4(z, x, y); + bigint_comba_mul4(z.mutable_data(), x.data(), y.data()); } - else if(x_sw <= 6 && x_size >= 6 && - y_sw <= 6 && y_size >= 6 && z_size >= 12) + else if(x.sig_words() <= 6 && x.size() >= 6 && + y.sig_words() <= 6 && y.size() >= 6 && z.size() >= 12) { - bigint_comba_mul6(z, x, y); + bigint_comba_mul6(z.mutable_data(), x.data(), y.data()); } - else if(x_sw <= 8 && x_size >= 8 && - y_sw <= 8 && y_size >= 8 && z_size >= 16) + else if(x.sig_words() <= 8 && x.size() >= 8 && + y.sig_words() <= 8 && y.size() >= 8 && z.size() >= 16) { - bigint_comba_mul8(z, x, y); + bigint_comba_mul8(z.mutable_data(), x.data(), y.data()); } - else if(x_sw <= 9 && x_size >= 9 && - y_sw <= 9 && y_size >= 9 && z_size >= 18) + else if(x.sig_words() <= 9 && x.size() >= 9 && + y.sig_words() <= 9 && y.size() >= 9 && z.size() >= 18) { - bigint_comba_mul9(z, x, y); + bigint_comba_mul9(z.mutable_data(), x.data(), y.data()); } - else if(x_sw <= 16 && x_size >= 16 && - y_sw <= 16 && y_size >= 16 && z_size >= 32) + else if(x.sig_words() <= 16 && x.size() >= 16 && + y.sig_words() <= 16 && y.size() >= 16 && z.size() >= 32) { - bigint_comba_mul16(z, x, y); + bigint_comba_mul16(z.mutable_data(), x.data(), y.data()); } - else if(x_sw < KARATSUBA_MULTIPLY_THRESHOLD || - y_sw < KARATSUBA_MULTIPLY_THRESHOLD || + else if(x.sig_words() < KARATSUBA_MULTIPLY_THRESHOLD || + y.sig_words() < KARATSUBA_MULTIPLY_THRESHOLD || !workspace) { - basecase_mul(z, x, x_sw, y, y_sw); + basecase_mul(z.mutable_data(), x.data(), x.sig_words(), y.data(), y.sig_words()); } else { - const size_t N = karatsuba_size(z_size, x_size, x_sw, y_size, y_sw); + const size_t N = karatsuba_size(z.size(), x.size(), x.sig_words(), y.size(), y.sig_words()); if(N) - karatsuba_mul(z, x, y, N, workspace); + karatsuba_mul(z.mutable_data(), x.data(), y.data(), N, workspace); else - basecase_mul(z, x, x_sw, y, y_sw); + basecase_mul(z.mutable_data(), x.data(), x.sig_words(), y.data(), y.sig_words()); } } diff --git a/src/lib/math/mp/mp_monty.cpp b/src/lib/math/mp/mp_monty.cpp index 7e427b540..88b5de715 100644 --- a/src/lib/math/mp/mp_monty.cpp +++ b/src/lib/math/mp/mp_monty.cpp @@ -2,10 +2,12 @@ * Montgomery Reduction * (C) 1999-2011 Jack Lloyd * 2006 Luca Piccarreta +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ +#include <botan/bigint.h> #include <botan/internal/mp_core.h> #include <botan/internal/mp_madd.h> #include <botan/internal/mp_asmi.h> @@ -92,30 +94,25 @@ void bigint_monty_redc(word z[], BOTAN_ASSERT(borrow == 0 || borrow == 1, "Expected borrow"); } -void bigint_monty_mul(word z[], size_t z_size, - const word x[], size_t x_size, size_t x_sw, - const word y[], size_t y_size, size_t y_sw, +void bigint_monty_mul(BigInt& z, const BigInt& x, const BigInt& y, const word p[], size_t p_size, word p_dash, word ws[]) { - bigint_mul(&z[0], z_size, &ws[0], - &x[0], x_size, x_sw, - &y[0], y_size, y_sw); + bigint_mul(z, x, y, &ws[0]); - bigint_monty_redc(&z[0], + bigint_monty_redc(z.mutable_data(), &p[0], p_size, p_dash, &ws[0]); + } -void bigint_monty_sqr(word z[], size_t z_size, - const word x[], size_t x_size, size_t x_sw, - const word p[], size_t p_size, word p_dash, - word ws[]) +void bigint_monty_sqr(BigInt& z, const BigInt& x, const word p[], + size_t p_size, word p_dash, word ws[]) { - bigint_sqr(&z[0], z_size, &ws[0], - &x[0], x_size, x_sw); + bigint_sqr(z.mutable_data(), z.size(), &ws[0], + x.data(), x.size(), x.sig_words()); - bigint_monty_redc(&z[0], + bigint_monty_redc(z.mutable_data(), &p[0], p_size, p_dash, &ws[0]); } diff --git a/src/lib/math/numbertheory/mp_numth.cpp b/src/lib/math/numbertheory/mp_numth.cpp index 3373b9ee7..d78d21128 100644 --- a/src/lib/math/numbertheory/mp_numth.cpp +++ b/src/lib/math/numbertheory/mp_numth.cpp @@ -1,6 +1,7 @@ /* * Fused and Important MP Algorithms * (C) 1999-2007 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -40,20 +41,13 @@ BigInt mul_add(const BigInt& a, const BigInt& b, const BigInt& c) if(a.sign() != b.sign()) sign = BigInt::Negative; - const size_t a_sw = a.sig_words(); - const size_t b_sw = b.sig_words(); - const size_t c_sw = c.sig_words(); - - BigInt r(sign, std::max(a.size() + b.size(), c_sw) + 1); + BigInt r(sign, std::max(a.size() + b.size(), c.sig_words()) + 1); secure_vector<word> workspace(r.size()); - bigint_mul(r.mutable_data(), r.size(), - workspace.data(), - a.data(), a.size(), a_sw, - b.data(), b.size(), b_sw); + bigint_mul(r, a, b, workspace.data()); - const size_t r_size = std::max(r.sig_words(), c_sw); - bigint_add2(r.mutable_data(), r_size, c.data(), c_sw); + const size_t r_size = std::max(r.sig_words(), c.sig_words()); + bigint_add2(r.mutable_data(), r_size, c.data(), c.sig_words()); return r; } diff --git a/src/lib/math/numbertheory/powm_mnt.cpp b/src/lib/math/numbertheory/powm_mnt.cpp index 5c441db3a..572f0de98 100644 --- a/src/lib/math/numbertheory/powm_mnt.cpp +++ b/src/lib/math/numbertheory/powm_mnt.cpp @@ -1,6 +1,7 @@ /* * Montgomery Exponentiation * (C) 1999-2010,2012 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -8,6 +9,7 @@ #include <botan/internal/def_powm.h> #include <botan/numthry.h> #include <botan/internal/mp_core.h> +#include <iostream> namespace Botan { @@ -34,36 +36,26 @@ void Montgomery_Exponentiator::set_base(const BigInt& base) m_g[0] = 1; - bigint_monty_mul(z.mutable_data(), z.size(), - m_g[0].data(), m_g[0].size(), m_g[0].sig_words(), - m_R2_mod.data(), m_R2_mod.size(), m_R2_mod.sig_words(), + bigint_monty_mul(z, m_g[0], m_R2_mod, m_modulus.data(), m_mod_words, m_mod_prime, workspace.data()); - m_g[0] = z; m_g[1] = (base >= m_modulus) ? (base % m_modulus) : base; - bigint_monty_mul(z.mutable_data(), z.size(), - m_g[1].data(), m_g[1].size(), m_g[1].sig_words(), - m_R2_mod.data(), m_R2_mod.size(), m_R2_mod.sig_words(), + bigint_monty_mul(z, m_g[1], m_R2_mod, m_modulus.data(), m_mod_words, m_mod_prime, workspace.data()); m_g[1] = z; const BigInt& x = m_g[1]; - const size_t x_sig = x.sig_words(); for(size_t i = 2; i != m_g.size(); ++i) { const BigInt& y = m_g[i-1]; - const size_t y_sig = y.sig_words(); - bigint_monty_mul(z.mutable_data(), z.size(), - x.data(), x.size(), x_sig, - y.data(), y.size(), y_sig, - m_modulus.data(), m_mod_words, m_mod_prime, + bigint_monty_mul(z, x, y, m_modulus.data(), m_mod_words, m_mod_prime, workspace.data()); m_g[i] = z; @@ -82,15 +74,13 @@ BigInt Montgomery_Exponentiator::execute() const const size_t z_size = 2*(m_mod_words + 1); BigInt z(BigInt::Positive, z_size); - secure_vector<word> workspace(z_size); + secure_vector<word> workspace(z.size()); for(size_t i = exp_nibbles; i > 0; --i) { for(size_t k = 0; k != m_window_bits; ++k) { - bigint_monty_sqr(z.mutable_data(), z_size, - x.data(), x.size(), x.sig_words(), - m_modulus.data(), m_mod_words, m_mod_prime, + bigint_monty_sqr(z, x, m_modulus.data(), m_mod_words, m_mod_prime, workspace.data()); x = z; @@ -100,9 +90,7 @@ BigInt Montgomery_Exponentiator::execute() const const BigInt& y = m_g[nibble]; - bigint_monty_mul(z.mutable_data(), z_size, - x.data(), x.size(), x.sig_words(), - y.data(), y.size(), y.sig_words(), + bigint_monty_mul(z, x, y, m_modulus.data(), m_mod_words, m_mod_prime, workspace.data()); diff --git a/src/lib/pubkey/curve25519/donna.cpp b/src/lib/pubkey/curve25519/donna.cpp index 9b28e412c..a0e4d249f 100644 --- a/src/lib/pubkey/curve25519/donna.cpp +++ b/src/lib/pubkey/curve25519/donna.cpp @@ -39,6 +39,26 @@ typedef byte u8; typedef u64bit limb; typedef limb felem[5]; +typedef struct + { + limb* x; + limb* z; + } fmonty_pair_t; + +typedef struct + { + fmonty_pair_t q; + fmonty_pair_t q_dash; + const limb* q_minus_q_dash; + } fmonty_in_t; + +typedef struct + { + fmonty_pair_t two_q; + fmonty_pair_t q_plus_q_dash; + } fmonty_out_t; + + #if !defined(BOTAN_TARGET_HAS_NATIVE_UINT128) typedef donna128 uint128_t; #endif @@ -273,44 +293,41 @@ fcontract(u8 *output, const felem input) { /* Input: Q, Q', Q-Q' * Output: 2Q, Q+Q' * - * x2 z3: long form - * x3 z3: long form - * x z: short form, destroyed - * xprime zprime: short form, destroyed - * qmqp: short form, preserved + * result.two_q (2*Q): long form + * result.q_plus_q_dash (Q + Q): long form + * in.q: short form, destroyed + * in.q_dash: short form, destroyed + * in.q_minus_q_dash: short form, preserved */ static void -fmonty(limb *x2, limb *z2, /* output 2Q */ - limb *x3, limb *z3, /* output Q + Q' */ - limb *x, limb *z, /* input Q */ - limb *xprime, limb *zprime, /* input Q' */ - const limb *qmqp /* input Q - Q' */) { +fmonty(fmonty_out_t& result, fmonty_in_t& in) +{ limb origx[5], origxprime[5], zzz[5], xx[5], zz[5], xxprime[5], - zzprime[5], zzzprime[5]; + zzprime[5], zzzprime[5]; - copy_mem(origx, x, 5); - fsum(x, z); - fdifference_backwards(z, origx); // does x - z + copy_mem(origx, in.q.x, 5); + fsum(in.q.x, in.q.z); + fdifference_backwards(in.q.z, origx); // does x - z - copy_mem(origxprime, xprime, 5); - fsum(xprime, zprime); - fdifference_backwards(zprime, origxprime); - fmul(xxprime, xprime, z); - fmul(zzprime, x, zprime); + copy_mem(origxprime, in.q_dash.x, 5); + fsum(in.q_dash.x, in.q_dash.z); + fdifference_backwards(in.q_dash.z, origxprime); + fmul(xxprime, in.q_dash.x, in.q.z); + fmul(zzprime, in.q.x, in.q_dash.z); copy_mem(origxprime, xxprime, 5); fsum(xxprime, zzprime); fdifference_backwards(zzprime, origxprime); - fsquare_times(x3, xxprime, 1); + fsquare_times(result.q_plus_q_dash.x, xxprime, 1); fsquare_times(zzzprime, zzprime, 1); - fmul(z3, zzzprime, qmqp); + fmul(result.q_plus_q_dash.z, zzzprime, in.q_minus_q_dash); - fsquare_times(xx, x, 1); - fsquare_times(zz, z, 1); - fmul(x2, xx, zz); + fsquare_times(xx, in.q.x, 1); + fsquare_times(zz, in.q.z, 1); + fmul(result.two_q.x, xx, zz); fdifference_backwards(zz, xx); // does zz = xx - zz fscalar_product(zzz, zz, 121665); fsum(zzz, xx); - fmul(z2, zz, zzz); + fmul(result.two_q.z, zz, zzz); } // ----------------------------------------------------------------------------- @@ -356,11 +373,10 @@ cmult(limb *resultx, limb *resultz, const u8 *n, const limb *q) { swap_conditional(nqx, nqpqx, bit); swap_conditional(nqz, nqpqz, bit); - fmonty(nqx2, nqz2, - nqpqx2, nqpqz2, - nqx, nqz, - nqpqx, nqpqz, - q); + + fmonty_out_t result { nqx2, nqz2, nqpqx2, nqpqz2 }; + fmonty_in_t in { nqx, nqz, nqpqx, nqpqz, q }; + fmonty(result, in); swap_conditional(nqx2, nqpqx2, bit); swap_conditional(nqz2, nqpqz2, bit); diff --git a/src/lib/tls/info.txt b/src/lib/tls/info.txt index b26179226..667726318 100644 --- a/src/lib/tls/info.txt +++ b/src/lib/tls/info.txt @@ -6,6 +6,7 @@ load_on auto credentials_manager.h tls_alert.h tls_blocking.h +tls_callbacks.h tls_channel.h tls_ciphersuite.h tls_client.h diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 23807215f..41a6f5c02 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -1,6 +1,7 @@ /* * TLS Hello Request and Client Hello Messages * (C) 2004-2011,2015,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -68,22 +69,20 @@ std::vector<byte> Hello_Request::serialize() const */ Client_Hello::Client_Hello(Handshake_IO& io, Handshake_Hash& hash, - Protocol_Version version, const Policy& policy, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, - const std::vector<std::string>& next_protocols, - const std::string& hostname, - const std::string& srp_identifier) : - m_version(version), + const Client_Hello::Settings& client_settings, + const std::vector<std::string>& next_protocols) : + m_version(client_settings.protocol_version()), m_random(make_hello_random(rng, policy)), - m_suites(policy.ciphersuite_list(m_version, (srp_identifier != ""))), + m_suites(policy.ciphersuite_list(m_version, + client_settings.srp_identifier() != "")), m_comp_methods(policy.compression()) { m_extensions.add(new Extended_Master_Secret); m_extensions.add(new Renegotiation_Extension(reneg_info)); - - m_extensions.add(new Server_Name_Indicator(hostname)); + m_extensions.add(new Server_Name_Indicator(client_settings.hostname())); m_extensions.add(new Session_Ticket()); m_extensions.add(new Supported_Elliptic_Curves(policy.allowed_ecc_curves())); @@ -98,7 +97,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, m_extensions.add(new Application_Layer_Protocol_Notification(next_protocols)); #if defined(BOTAN_HAS_SRP6) - m_extensions.add(new SRP_Identifier(srp_identifier)); + m_extensions.add(new SRP_Identifier(client_settings.srp_identifier())); #else if(!srp_identifier.empty()) { @@ -106,10 +105,10 @@ Client_Hello::Client_Hello(Handshake_IO& io, } #endif - BOTAN_ASSERT(policy.acceptable_protocol_version(version), + BOTAN_ASSERT(policy.acceptable_protocol_version(client_settings.protocol_version()), "Our policy accepts the version we are offering"); - if(policy.send_fallback_scsv(version)) + if(policy.send_fallback_scsv(client_settings.protocol_version())) m_suites.push_back(TLS_FALLBACK_SCSV); hash.update(io.send(*this)); diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index f8d0c63c7..f32625508 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -1,6 +1,7 @@ /* * TLS Server Hello and Server Hello Done * (C) 2004-2011,2015,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -23,17 +24,13 @@ Server_Hello::Server_Hello(Handshake_IO& io, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, const Client_Hello& client_hello, - const std::vector<byte>& new_session_id, - Protocol_Version new_session_version, - u16bit ciphersuite, - byte compression, - bool offer_session_ticket, - const std::string& next_protocol) : - m_version(new_session_version), - m_session_id(new_session_id), + const Server_Hello::Settings& server_settings, + const std::string next_protocol) : + m_version(server_settings.protocol_version()), + m_session_id(server_settings.session_id()), m_random(make_hello_random(rng, policy)), - m_ciphersuite(ciphersuite), - m_comp_method(compression) + m_ciphersuite(server_settings.ciphersuite()), + m_comp_method(server_settings.compression()) { if(client_hello.supports_extended_master_secret()) m_extensions.add(new Extended_Master_Secret); @@ -41,7 +38,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, if(client_hello.secure_renegotiation()) m_extensions.add(new Renegotiation_Extension(reneg_info)); - if(client_hello.supports_session_ticket() && offer_session_ticket) + if(client_hello.supports_session_ticket() && server_settings.offer_session_ticket()) m_extensions.add(new Session_Ticket()); if(!next_protocol.empty() && client_hello.supports_alpn()) diff --git a/src/lib/tls/tls_blocking.cpp b/src/lib/tls/tls_blocking.cpp index a1867b6b5..34ebc6c91 100644 --- a/src/lib/tls/tls_blocking.cpp +++ b/src/lib/tls/tls_blocking.cpp @@ -1,6 +1,7 @@ /* * TLS Blocking API * (C) 2013 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -23,10 +24,13 @@ Blocking_Client::Blocking_Client(read_fn reader, const Protocol_Version& offer_version, const std::vector<std::string>& next) : m_read(reader), - m_channel(writer, - std::bind(&Blocking_Client::data_cb, this, _1, _2), - std::bind(&Blocking_Client::alert_cb, this, _1, _2, _3), - std::bind(&Blocking_Client::handshake_cb, this, _1), + m_callbacks(new TLS::Compat_Callbacks( + writer, + std::bind(&Blocking_Client::data_cb, this, _1, _2), + std::function<void (Alert)>(std::bind(&Blocking_Client::alert_cb, this, _1)), + std::bind(&Blocking_Client::handshake_cb, this, _1) + )), + m_channel(*m_callbacks.get(), session_manager, creds, policy, @@ -35,6 +39,7 @@ Blocking_Client::Blocking_Client(read_fn reader, offer_version, next) { + printf("hi\n"); } bool Blocking_Client::handshake_cb(const Session& session) @@ -42,7 +47,7 @@ bool Blocking_Client::handshake_cb(const Session& session) return this->handshake_complete(session); } -void Blocking_Client::alert_cb(const Alert& alert, const byte[], size_t) +void Blocking_Client::alert_cb(const Alert& alert) { this->alert_notification(alert); } diff --git a/src/lib/tls/tls_blocking.h b/src/lib/tls/tls_blocking.h index 00e65cbaf..0f2986710 100644 --- a/src/lib/tls/tls_blocking.h +++ b/src/lib/tls/tls_blocking.h @@ -1,6 +1,7 @@ /* * TLS Blocking API * (C) 2013 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -32,6 +33,7 @@ class BOTAN_DLL Blocking_Client typedef std::function<size_t (byte[], size_t)> read_fn; typedef std::function<void (const byte[], size_t)> write_fn; + BOTAN_DEPRECATED("Use the regular TLS::Client interface") Blocking_Client(read_fn reader, write_fn writer, Session_Manager& session_manager, @@ -89,9 +91,10 @@ class BOTAN_DLL Blocking_Client void data_cb(const byte data[], size_t data_len); - void alert_cb(const Alert& alert, const byte data[], size_t data_len); + void alert_cb(const Alert& alert); read_fn m_read; + std::unique_ptr<Compat_Callbacks> m_callbacks; TLS::Client m_channel; secure_vector<byte> m_plaintext; }; diff --git a/src/lib/tls/tls_callbacks.h b/src/lib/tls/tls_callbacks.h new file mode 100644 index 000000000..d7a68da31 --- /dev/null +++ b/src/lib/tls/tls_callbacks.h @@ -0,0 +1,205 @@ +/* +* TLS Callbacks +* (C) 2016 Matthias Gierlings +* 2016 Jack Lloyd +* +* Botan is released under the Simplified BSD License (see license.txt) +*/ + +#ifndef BOTAN_TLS_CALLBACKS_H__ +#define BOTAN_TLS_CALLBACKS_H__ + +#include <botan/tls_session.h> +#include <botan/tls_alert.h> +namespace Botan { + +namespace TLS { + +class Handshake_Message; + +/** +* Encapsulates the callbacks that a TLS channel will make which are due to +* channel specific operations. +*/ +class BOTAN_DLL Callbacks + { + public: + virtual ~Callbacks() {} + + /** + * Mandatory callback: output function + * The channel will call this with data which needs to be sent to the peer + * (eg, over a socket or some other form of IPC). The array will be overwritten + * when the function returns so a copy must be made if the data cannot be + * sent immediately. + * + * @param data the vector of data to send + * + * @param size the number of bytes to send + */ + virtual void tls_emit_data(const uint8_t data[], size_t size) = 0; + + /** + * Mandatory callback: process application data + * Called when application data record is received from the peer. + * Again the array is overwritten immediately after the function returns. + * + * @param seq_no the underlying TLS/DTLS record sequence number + * + * @param data the vector containing the received record + * + * @param size the length of the received record, in bytes + */ + virtual void tls_record_received(u64bit seq_no, const uint8_t data[], size_t size) = 0; + + /** + * Mandary callback: alert received + * Called when an alert is received from the peer + * If fatal, the connection is closing. If not fatal, the connection may + * still be closing (depending on the error and the peer). + * + * @param alert the source of the alert + */ + virtual void tls_alert(Alert alert) = 0; + + /** + * Mandatory callback: session established + * Called when a session is established. Throw an exception to abort + * the connection. + * + * @param session the session descriptor + * + * @return return false to prevent the session from being cached, + * return true to cache the session in the configured session manager + */ + virtual bool tls_session_established(const Session& session) = 0; + + /** + * Optional callback: inspect handshake message + * Throw an exception to abort the handshake. + * + * @param message the handshake message + */ + virtual void tls_inspect_handshake_msg(const Handshake_Message& message) {} + + /** + * Optional callback for server: choose ALPN protocol + * ALPN (RFC 7301) works by the client sending a list of application + * protocols it is willing to negotiate. The server then selects which + * protocol to use, which is not necessarily even on the list that + * the client sent. + * + * @param client_protos the vector of protocols the client is willing to negotiate + * + * @return the protocol selected by the server, which need not be on the + * list that the client sent; if this is the empty string, the server ignores the + * client ALPN extension + */ + virtual std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos) + { + return ""; + } + + /** + * Optional callback: debug logging. (not currently used) + */ + virtual bool tls_log_debug(const char*) { return false; } + }; + +/** +* TLS::Callbacks using std::function for compatability with the old API signatures. +* This type is only provided for backward compatibility. +* New implementations should derive from TLS::Callbacks instead. +*/ +class BOTAN_DLL Compat_Callbacks final : public Callbacks + { + public: + typedef std::function<void (const byte[], size_t)> output_fn; + typedef std::function<void (const byte[], size_t)> data_cb; + typedef std::function<void (Alert, const byte[], size_t)> alert_cb; + typedef std::function<bool (const Session&)> handshake_cb; + typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; + typedef std::function<std::string (std::vector<std::string>)> next_protocol_fn; + + /** + * @param output_fn is called with data for the outbound socket + * + * @param app_data_cb is called when new application data is received + * + * @param alert_cb is called when a TLS alert is received + * + * @param handshake_cb is called when a handshake is completed + */ + BOTAN_DEPRECATED("Use TLS::Callbacks (virtual interface).") + Compat_Callbacks(output_fn out, data_cb app_data_cb, alert_cb alert_cb, + handshake_cb hs_cb, handshake_msg_cb hs_msg_cb = nullptr, + next_protocol_fn next_proto = nullptr) + : m_output_function(out), m_app_data_cb(app_data_cb), + m_alert_cb(std::bind(alert_cb, std::placeholders::_1, nullptr, 0)), + m_hs_cb(hs_cb), m_hs_msg_cb(hs_msg_cb), m_next_proto(next_proto) {} + + BOTAN_DEPRECATED("Use TLS::Callbacks (virtual interface).") + Compat_Callbacks(output_fn out, data_cb app_data_cb, + std::function<void (Alert)> alert_cb, + handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb = nullptr, + next_protocol_fn next_proto = nullptr) + : m_output_function(out), m_app_data_cb(app_data_cb), + m_alert_cb(alert_cb), + m_hs_cb(hs_cb), m_hs_msg_cb(hs_msg_cb), m_next_proto(next_proto) {} + + void tls_emit_data(const byte data[], size_t size) override + { + BOTAN_ASSERT(m_output_function != nullptr, + "Invalid TLS output function callback."); + m_output_function(data, size); + } + + void tls_record_received(u64bit /*seq_no*/, const byte data[], size_t size) override + { + BOTAN_ASSERT(m_app_data_cb != nullptr, + "Invalid TLS app data callback."); + m_app_data_cb(data, size); + } + + void tls_alert(Alert alert) override + { + BOTAN_ASSERT(m_alert_cb != nullptr, + "Invalid TLS alert callback."); + m_alert_cb(alert); + } + + bool tls_session_established(const Session& session) override + { + BOTAN_ASSERT(m_hs_cb != nullptr, + "Invalid TLS handshake callback."); + return m_hs_cb(session); + } + + std::string tls_server_choose_app_protocol(const std::vector<std::string>& client_protos) override + { + if(m_next_proto != nullptr) { return m_next_proto(client_protos); } + return ""; + } + + void tls_inspect_handshake_msg(const Handshake_Message& hmsg) override + { + // The handshake message callback is optional so we can + // not assume it has been set. + if(m_hs_msg_cb != nullptr) { m_hs_msg_cb(hmsg); } + } + + private: + const output_fn m_output_function; + const data_cb m_app_data_cb; + const std::function<void (Alert)> m_alert_cb; + const handshake_cb m_hs_cb; + const handshake_msg_cb m_hs_msg_cb; + const next_protocol_fn m_next_proto; + }; + +} + +} + +#endif diff --git a/src/lib/tls/tls_channel.cpp b/src/lib/tls/tls_channel.cpp index f445eef99..5f882af64 100644 --- a/src/lib/tls/tls_channel.cpp +++ b/src/lib/tls/tls_channel.cpp @@ -1,6 +1,7 @@ /* * TLS Channels -* (C) 2011,2012,2014,2015 Jack Lloyd +* (C) 2011,2012,2014,2015,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -18,32 +19,51 @@ namespace Botan { namespace TLS { -Channel::Channel(output_fn output_fn, - data_cb data_cb, - alert_cb alert_cb, - handshake_cb handshake_cb, - handshake_msg_cb handshake_msg_cb, +size_t TLS::Channel::IO_BUF_DEFAULT_SIZE = 10*1024; + +Channel::Channel(Callbacks& callbacks, Session_Manager& session_manager, RandomNumberGenerator& rng, const Policy& policy, bool is_datagram, size_t reserved_io_buffer_size) : m_is_datagram(is_datagram), - m_data_cb(data_cb), - m_alert_cb(alert_cb), - m_output_fn(output_fn), - m_handshake_cb(handshake_cb), - m_handshake_msg_cb(handshake_msg_cb), + m_callbacks(callbacks), m_session_manager(session_manager), m_policy(policy), m_rng(rng) { + init(reserved_io_buffer_size); + } + +Channel::Channel(output_fn out, + data_cb app_data_cb, + alert_cb alert_cb, + handshake_cb hs_cb, + handshake_msg_cb hs_msg_cb, + Session_Manager& session_manager, + RandomNumberGenerator& rng, + const Policy& policy, + bool is_datagram, + size_t io_buf_sz) : + m_is_datagram(is_datagram), + m_compat_callbacks(new Compat_Callbacks(out, app_data_cb, alert_cb, hs_cb, hs_msg_cb)), + m_callbacks(*m_compat_callbacks.get()), + m_session_manager(session_manager), + m_policy(policy), + m_rng(rng) + { + init(io_buf_sz); + } + +void Channel::init(size_t io_buf_sz) + { /* epoch 0 is plaintext, thus null cipher state */ m_write_cipher_states[0] = nullptr; m_read_cipher_states[0] = nullptr; - m_writebuf.reserve(reserved_io_buffer_size); - m_readbuf.reserve(reserved_io_buffer_size); + m_writebuf.reserve(io_buf_sz); + m_readbuf.reserve(io_buf_sz); } void Channel::reset_state() @@ -263,23 +283,19 @@ size_t Channel::received_data(const byte input[], size_t input_size) { while(!is_closed() && input_size) { - secure_vector<byte> record; + secure_vector<byte> record_data; u64bit record_sequence = 0; Record_Type record_type = NO_RECORD; Protocol_Version record_version; size_t consumed = 0; + Record_Raw_Input raw_input(input, input_size, consumed, m_is_datagram); + Record record(record_data, &record_sequence, &record_version, &record_type); const size_t needed = read_record(m_readbuf, - input, - input_size, - m_is_datagram, - consumed, + raw_input, record, - &record_sequence, - &record_version, - &record_type, m_sequence_numbers.get(), std::bind(&TLS::Channel::read_cipher_state_epoch, this, std::placeholders::_1)); @@ -298,105 +314,21 @@ size_t Channel::received_data(const byte input[], size_t input_size) if(input_size == 0 && needed != 0) return needed; // need more data to complete record - if(record.size() > MAX_PLAINTEXT_SIZE) + if(record_data.size() > MAX_PLAINTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, "TLS plaintext record is larger than allowed maximum"); if(record_type == HANDSHAKE || record_type == CHANGE_CIPHER_SPEC) { - if(!m_pending_state) - { - // No pending handshake, possibly new: - if(record_version.is_datagram_protocol()) - { - if(m_sequence_numbers) - { - /* - * Might be a peer retransmit under epoch - 1 in which - * case we must retransmit last flight - */ - sequence_numbers().read_accept(record_sequence); - - const u16bit epoch = record_sequence >> 48; - - if(epoch == sequence_numbers().current_read_epoch()) - { - create_handshake_state(record_version); - } - else if(epoch == sequence_numbers().current_read_epoch() - 1) - { - BOTAN_ASSERT(m_active_state, "Have active state here"); - m_active_state->handshake_io().add_record(unlock(record), - record_type, - record_sequence); - } - } - else if(record_sequence == 0) - { - create_handshake_state(record_version); - } - } - else - { - create_handshake_state(record_version); - } - } - - // May have been created in above conditional - if(m_pending_state) - { - m_pending_state->handshake_io().add_record(unlock(record), - record_type, - record_sequence); - - while(auto pending = m_pending_state.get()) - { - auto msg = pending->get_next_handshake_msg(); - - if(msg.first == HANDSHAKE_NONE) // no full handshake yet - break; - - process_handshake_msg(active_state(), *pending, - msg.first, msg.second); - } - } + process_handshake_ccs(record_data, record_sequence, record_type, record_version); } else if(record_type == APPLICATION_DATA) { - if(!active_state()) - throw Unexpected_Message("Application data before handshake done"); - - /* - * OpenSSL among others sends empty records in versions - * before TLS v1.1 in order to randomize the IV of the - * following record. Avoid spurious callbacks. - */ - if(record.size() > 0) - m_data_cb(record.data(), record.size()); + process_application_data(record_sequence, record_data); } else if(record_type == ALERT) { - Alert alert_msg(record); - - if(alert_msg.type() == Alert::NO_RENEGOTIATION) - m_pending_state.reset(); - - m_alert_cb(alert_msg, nullptr, 0); - - if(alert_msg.is_fatal()) - { - if(auto active = active_state()) - m_session_manager.remove_entry(active->server_hello()->session_id()); - } - - if(alert_msg.type() == Alert::CLOSE_NOTIFY) - send_warning_alert(Alert::CLOSE_NOTIFY); // reply in kind - - if(alert_msg.type() == Alert::CLOSE_NOTIFY || alert_msg.is_fatal()) - { - reset_state(); - return 0; - } + process_alert(record_data); } else if(record_type != NO_RECORD) throw Unexpected_Message("Unexpected record type " + @@ -428,6 +360,108 @@ size_t Channel::received_data(const byte input[], size_t input_size) } } +void Channel::process_handshake_ccs(const secure_vector<byte>& record, + u64bit record_sequence, + Record_Type record_type, + Protocol_Version record_version) + { + if(!m_pending_state) + { + // No pending handshake, possibly new: + if(record_version.is_datagram_protocol()) + { + if(m_sequence_numbers) + { + /* + * Might be a peer retransmit under epoch - 1 in which + * case we must retransmit last flight + */ + sequence_numbers().read_accept(record_sequence); + + const u16bit epoch = record_sequence >> 48; + + if(epoch == sequence_numbers().current_read_epoch()) + { + create_handshake_state(record_version); + } + else if(epoch == sequence_numbers().current_read_epoch() - 1) + { + BOTAN_ASSERT(m_active_state, "Have active state here"); + m_active_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + } + } + else if(record_sequence == 0) + { + create_handshake_state(record_version); + } + } + else + { + create_handshake_state(record_version); + } + } + + // May have been created in above conditional + if(m_pending_state) + { + m_pending_state->handshake_io().add_record(unlock(record), + record_type, + record_sequence); + + while(auto pending = m_pending_state.get()) + { + auto msg = pending->get_next_handshake_msg(); + + if(msg.first == HANDSHAKE_NONE) // no full handshake yet + break; + + process_handshake_msg(active_state(), *pending, + msg.first, msg.second); + } + } + } + +void Channel::process_application_data(u64bit seq_no, const secure_vector<byte>& record) + { + if(!active_state()) + throw Unexpected_Message("Application data before handshake done"); + + /* + * OpenSSL among others sends empty records in versions + * before TLS v1.1 in order to randomize the IV of the + * following record. Avoid spurious callbacks. + */ + if(record.size() > 0) + callbacks().tls_record_received(seq_no, record.data(), record.size()); + } + +void Channel::process_alert(const secure_vector<byte>& record) + { + Alert alert_msg(record); + + if(alert_msg.type() == Alert::NO_RENEGOTIATION) + m_pending_state.reset(); + + callbacks().tls_alert(alert_msg); + + if(alert_msg.is_fatal()) + { + if(auto active = active_state()) + m_session_manager.remove_entry(active->server_hello()->session_id()); + } + + if(alert_msg.type() == Alert::CLOSE_NOTIFY) + send_warning_alert(Alert::CLOSE_NOTIFY); // reply in kind + + if(alert_msg.type() == Alert::CLOSE_NOTIFY || alert_msg.is_fatal()) + { + reset_state(); + } + } + + void Channel::write_record(Connection_Cipher_State* cipher_state, u16bit epoch, byte record_type, const byte input[], size_t length) { @@ -436,16 +470,16 @@ void Channel::write_record(Connection_Cipher_State* cipher_state, u16bit epoch, Protocol_Version record_version = (m_pending_state) ? (m_pending_state->version()) : (m_active_state->version()); + Record_Message record_message(record_type, 0, input, length); + TLS::write_record(m_writebuf, - record_type, - input, - length, + record_message, record_version, sequence_numbers().next_write_sequence(epoch), cipher_state, m_rng); - m_output_fn(m_writebuf.data(), m_writebuf.size()); + callbacks().tls_emit_data(m_writebuf.data(), m_writebuf.size()); } void Channel::send_record_array(u16bit epoch, byte type, const byte input[], size_t length) diff --git a/src/lib/tls/tls_channel.h b/src/lib/tls/tls_channel.h index e0219c242..073af760f 100644 --- a/src/lib/tls/tls_channel.h +++ b/src/lib/tls/tls_channel.h @@ -1,6 +1,7 @@ /* * TLS Channel * (C) 2011,2012,2014,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -12,6 +13,7 @@ #include <botan/tls_session.h> #include <botan/tls_alert.h> #include <botan/tls_session_manager.h> +#include <botan/tls_callbacks.h> #include <botan/x509cert.h> #include <vector> #include <string> @@ -37,7 +39,20 @@ class BOTAN_DLL Channel typedef std::function<void (Alert, const byte[], size_t)> alert_cb; typedef std::function<bool (const Session&)> handshake_cb; typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; + static size_t IO_BUF_DEFAULT_SIZE; + Channel(Callbacks& callbacks, + Session_Manager& session_manager, + RandomNumberGenerator& rng, + const Policy& policy, + bool is_datagram, + size_t io_buf_sz = IO_BUF_DEFAULT_SIZE); + + /** + * DEPRECATED. This constructor is only provided for backward + * compatibility and should not be used in new implementations. + */ + BOTAN_DEPRECATED("Use TLS::Channel(TLS::Callbacks ...)") Channel(output_fn out, data_cb app_data_cb, alert_cb alert_cb, @@ -47,7 +62,7 @@ class BOTAN_DLL Channel RandomNumberGenerator& rng, const Policy& policy, bool is_datagram, - size_t io_buf_sz = 16*1024); + size_t io_buf_sz = IO_BUF_DEFAULT_SIZE); Channel(const Channel&) = delete; @@ -200,10 +215,12 @@ class BOTAN_DLL Channel const Policy& policy() const { return m_policy; } - bool save_session(const Session& session) const { return m_handshake_cb(session); } + bool save_session(const Session& session) const { return callbacks().tls_session_established(session); } - handshake_msg_cb get_handshake_msg_cb() const { return m_handshake_msg_cb; } + Callbacks& callbacks() const { return m_callbacks; } private: + void init(size_t io_buf_sze); + void send_record(byte record_type, const std::vector<byte>& record); void send_record_under_epoch(u16bit epoch, byte record_type, @@ -227,14 +244,21 @@ class BOTAN_DLL Channel const Handshake_State* pending_state() const { return m_pending_state.get(); } + /* methods to handle incoming traffic through Channel::receive_data. */ + void process_handshake_ccs(const secure_vector<byte>& record, + u64bit record_sequence, + Record_Type record_type, + Protocol_Version record_version); + + void process_application_data(u64bit req_no, const secure_vector<byte>& record); + + void process_alert(const secure_vector<byte>& record); + bool m_is_datagram; /* callbacks */ - data_cb m_data_cb; - alert_cb m_alert_cb; - output_fn m_output_fn; - handshake_cb m_handshake_cb; - handshake_msg_cb m_handshake_msg_cb; + std::unique_ptr<Compat_Callbacks> m_compat_callbacks; + Callbacks& m_callbacks; /* external state */ Session_Manager& m_session_manager; diff --git a/src/lib/tls/tls_client.cpp b/src/lib/tls/tls_client.cpp index 301c77c6b..f77521c03 100644 --- a/src/lib/tls/tls_client.cpp +++ b/src/lib/tls/tls_client.cpp @@ -1,6 +1,7 @@ /* * TLS Client * (C) 2004-2011,2012,2015,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -23,7 +24,7 @@ class Client_Handshake_State : public Handshake_State public: // using Handshake_State::Handshake_State; - Client_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} + Client_Handshake_State(Handshake_IO* io, Callbacks& cb) : Handshake_State(io, cb) {} const Public_Key& get_server_public_Key() const { @@ -42,6 +43,23 @@ class Client_Handshake_State : public Handshake_State /* * TLS Client Constructor */ +Client::Client(Callbacks& callbacks, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& info, + const Protocol_Version& offer_version, + const std::vector<std::string>& next_protos, + size_t io_buf_sz) : + Channel(callbacks, session_manager, rng, policy, offer_version.is_datagram_protocol(), + io_buf_sz), + m_creds(creds), + m_info(info) + { + init(offer_version, next_protos); + } + Client::Client(output_fn output_fn, data_cb proc_cb, alert_cb alert_cb, @@ -59,10 +77,7 @@ Client::Client(output_fn output_fn, m_creds(creds), m_info(info) { - const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); - - Handshake_State& state = create_handshake_state(offer_version); - send_client_hello(state, false, offer_version, srp_identifier, next_protos); + init(offer_version, next_protos); } Client::Client(output_fn output_fn, @@ -82,15 +97,22 @@ Client::Client(output_fn output_fn, m_creds(creds), m_info(info) { + init(offer_version, next_protos); + } + +void Client::init(const Protocol_Version& protocol_version, + const std::vector<std::string>& next_protocols) + { const std::string srp_identifier = m_creds.srp_identifier("tls-client", m_info.hostname()); - Handshake_State& state = create_handshake_state(offer_version); - send_client_hello(state, false, offer_version, srp_identifier, next_protos); + Handshake_State& state = create_handshake_state(protocol_version); + send_client_hello(state, false, protocol_version, + srp_identifier, next_protocols); } Handshake_State* Client::new_handshake_state(Handshake_IO* io) { - return new Client_Handshake_State(io, get_handshake_msg_cb()); + return new Client_Handshake_State(io, callbacks()); } std::vector<X509_Certificate> @@ -145,16 +167,15 @@ void Client::send_client_hello(Handshake_State& state_base, if(!state.client_hello()) // not resuming { + Client_Hello::Settings client_settings(version, m_info.hostname(), srp_identifier); state.client_hello(new Client_Hello( state.handshake_io(), state.hash(), - version, policy(), rng(), secure_renegotiation_data_for_client_hello(), - next_protocols, - m_info.hostname(), - srp_identifier)); + client_settings, + next_protocols)); } secure_renegotiation_check(state.client_hello()); @@ -419,11 +440,9 @@ void Client::process_handshake_msg(const Handshake_State* active_state, "tls-client", m_info.hostname()); - state.client_certs( - new Certificate(state.handshake_io(), - state.hash(), - client_certs) - ); + state.client_certs(new Certificate(state.handshake_io(), + state.hash(), + client_certs)); } state.client_kex( diff --git a/src/lib/tls/tls_client.h b/src/lib/tls/tls_client.h index 45a741878..09af053af 100644 --- a/src/lib/tls/tls_client.h +++ b/src/lib/tls/tls_client.h @@ -1,6 +1,7 @@ /* * TLS Client * (C) 2004-2011 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -22,9 +23,49 @@ namespace TLS { class BOTAN_DLL Client final : public Channel { public: + /** * Set up a new TLS client session * + * @param callbacks contains a set of callback function references + * required by the TLS client. + * + * @param session_manager manages session state + * + * @param creds manages application/user credentials + * + * @param policy specifies other connection policy information + * + * @param rng a random number generator + * + * @param server_info is identifying information about the TLS server + * + * @param offer_version specifies which version we will offer + * to the TLS server. + * + * @param next_protocols specifies protocols to advertise with ALPN + * + * @param reserved_io_buffer_size This many bytes of memory will + * be preallocated for the read and write buffers. Smaller + * values just mean reallocations and copies are more likely. + */ + Client(Callbacks& callbacks, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + const Server_Information& server_info = Server_Information(), + const Protocol_Version& offer_version = Protocol_Version::latest_tls_version(), + const std::vector<std::string>& next_protocols = {}, + size_t reserved_io_buffer_size = TLS::Client::IO_BUF_DEFAULT_SIZE + ); + + /** + * DEPRECATED. This constructor is only provided for backward + * compatibility and should not be used in new code. + * + * Set up a new TLS client session + * * @param output_fn is called with data for the outbound socket * * @param app_data_cb is called when new application data is received @@ -52,7 +93,7 @@ class BOTAN_DLL Client final : public Channel * be preallocated for the read and write buffers. Smaller * values just mean reallocations and copies are more likely. */ - + BOTAN_DEPRECATED("Use TLS::Client(TLS::Callbacks ...)") Client(output_fn out, data_cb app_data_cb, alert_cb alert_cb, @@ -64,9 +105,14 @@ class BOTAN_DLL Client final : public Channel const Server_Information& server_info = Server_Information(), const Protocol_Version& offer_version = Protocol_Version::latest_tls_version(), const std::vector<std::string>& next_protocols = {}, - size_t reserved_io_buffer_size = 16*1024 + size_t reserved_io_buffer_size = TLS::Client::IO_BUF_DEFAULT_SIZE ); + /** + * DEPRECATED. This constructor is only provided for backward + * compatibility and should not be used in new implementations. + */ + BOTAN_DEPRECATED("Use TLS::Client(TLS::Callbacks ...)") Client(output_fn out, data_cb app_data_cb, alert_cb alert_cb, @@ -83,6 +129,9 @@ class BOTAN_DLL Client final : public Channel const std::string& application_protocol() const { return m_application_protocol; } private: + void init(const Protocol_Version& protocol_version, + const std::vector<std::string>& next_protocols); + std::vector<X509_Certificate> get_peer_cert_chain(const Handshake_State& state) const override; diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h index a5aac0020..cfde0067c 100644 --- a/src/lib/tls/tls_extensions.h +++ b/src/lib/tls/tls_extensions.h @@ -1,6 +1,7 @@ /* * TLS Extensions * (C) 2011,2012,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -34,7 +35,6 @@ enum Handshake_Extension_Type { TLSEXT_SRP_IDENTIFIER = 12, TLSEXT_SIGNATURE_ALGORITHMS = 13, TLSEXT_USE_SRTP = 14, - TLSEXT_HEARTBEAT_SUPPORT = 15, TLSEXT_ALPN = 16, TLSEXT_EXTENDED_MASTER_SECRET = 23, diff --git a/src/lib/tls/tls_handshake_msg.h b/src/lib/tls/tls_handshake_msg.h index 7e527abf4..618ae8d76 100644 --- a/src/lib/tls/tls_handshake_msg.h +++ b/src/lib/tls/tls_handshake_msg.h @@ -1,6 +1,7 @@ /* * TLS Handshake Message * (C) 2012 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -16,6 +17,9 @@ namespace Botan { namespace TLS { +class Handshake_IO; +class Handshake_Hash; + /** * TLS Handshake Message Base Class */ diff --git a/src/lib/tls/tls_handshake_state.cpp b/src/lib/tls/tls_handshake_state.cpp index afc32ba87..71cacdabd 100644 --- a/src/lib/tls/tls_handshake_state.cpp +++ b/src/lib/tls/tls_handshake_state.cpp @@ -8,6 +8,7 @@ #include <botan/internal/tls_handshake_state.h> #include <botan/internal/tls_messages.h> #include <botan/internal/tls_record.h> +#include <botan/tls_callbacks.h> namespace Botan { @@ -174,8 +175,8 @@ std::string handshake_mask_to_string(u32bit mask) /* * Initialize the SSL/TLS Handshake State */ -Handshake_State::Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : - m_msg_callback(cb), +Handshake_State::Handshake_State(Handshake_IO* io, Callbacks& cb) : + m_callbacks(cb), m_handshake_io(io), m_version(m_handshake_io->initial_record_version()) { @@ -183,6 +184,11 @@ Handshake_State::Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State::~Handshake_State() {} +void Handshake_State::note_message(const Handshake_Message& msg) + { + m_callbacks.tls_inspect_handshake_msg(msg); + } + void Handshake_State::hello_verify_request(const Hello_Verify_Request& hello_verify) { note_message(hello_verify); diff --git a/src/lib/tls/tls_handshake_state.h b/src/lib/tls/tls_handshake_state.h index 2943a8637..bdec10d14 100644 --- a/src/lib/tls/tls_handshake_state.h +++ b/src/lib/tls/tls_handshake_state.h @@ -24,6 +24,7 @@ class KDF; namespace TLS { +class Callbacks; class Policy; class Hello_Verify_Request; @@ -45,9 +46,7 @@ class Finished; class Handshake_State { public: - typedef std::function<void (const Handshake_Message&)> handshake_msg_cb; - - Handshake_State(Handshake_IO* io, handshake_msg_cb cb); + Handshake_State(Handshake_IO* io, Callbacks& callbacks); virtual ~Handshake_State(); @@ -164,15 +163,10 @@ class Handshake_State const Handshake_Hash& hash() const { return m_handshake_hash; } - void note_message(const Handshake_Message& msg) - { - if(m_msg_callback) - m_msg_callback(msg); - } - + void note_message(const Handshake_Message& msg); private: - handshake_msg_cb m_msg_callback; + Callbacks& m_callbacks; std::unique_ptr<Handshake_IO> m_handshake_io; diff --git a/src/lib/tls/tls_messages.h b/src/lib/tls/tls_messages.h index 3bee89e13..47ff7d3d8 100644 --- a/src/lib/tls/tls_messages.h +++ b/src/lib/tls/tls_messages.h @@ -1,6 +1,7 @@ /* * TLS Messages * (C) 2004-2011,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -61,6 +62,26 @@ class Hello_Verify_Request final : public Handshake_Message class Client_Hello final : public Handshake_Message { public: + class Settings + { + public: + Settings(const Protocol_Version version, + const std::string& hostname = "", + const std::string& srp_identifier = "") + : m_new_session_version(version), + m_hostname(hostname), + m_srp_identifier(srp_identifier) {}; + + const Protocol_Version protocol_version() const { return m_new_session_version; }; + const std::string& hostname() const { return m_hostname; }; + const std::string& srp_identifier() const { return m_srp_identifier; } + + private: + const Protocol_Version m_new_session_version; + const std::string m_hostname; + const std::string m_srp_identifier; + }; + Handshake_Type type() const override { return CLIENT_HELLO; } Protocol_Version version() const { return m_version; } @@ -162,13 +183,11 @@ class Client_Hello final : public Handshake_Message Client_Hello(Handshake_IO& io, Handshake_Hash& hash, - Protocol_Version version, const Policy& policy, RandomNumberGenerator& rng, const std::vector<byte>& reneg_info, - const std::vector<std::string>& next_protocols, - const std::string& hostname = "", - const std::string& srp_identifier = ""); + const Client_Hello::Settings& client_settings, + const std::vector<std::string>& next_protocols); Client_Hello(Handshake_IO& io, Handshake_Hash& hash, @@ -199,6 +218,35 @@ class Client_Hello final : public Handshake_Message class Server_Hello final : public Handshake_Message { public: + class Settings + { + public: + Settings(const std::vector<byte> new_session_id, + Protocol_Version new_session_version, + u16bit ciphersuite, + byte compression, + bool offer_session_ticket) + : m_new_session_id(new_session_id), + m_new_session_version(new_session_version), + m_ciphersuite(ciphersuite), + m_compression(compression), + m_offer_session_ticket(offer_session_ticket) {}; + + const std::vector<byte>& session_id() const { return m_new_session_id; }; + Protocol_Version protocol_version() const { return m_new_session_version; }; + u16bit ciphersuite() const { return m_ciphersuite; }; + byte compression() const { return m_compression; } + bool offer_session_ticket() const { return m_offer_session_ticket; } + + private: + const std::vector<byte> m_new_session_id; + Protocol_Version m_new_session_version; + u16bit m_ciphersuite; + byte m_compression; + bool m_offer_session_ticket; + }; + + Handshake_Type type() const override { return SERVER_HELLO; } Protocol_Version version() const { return m_version; } @@ -262,12 +310,8 @@ class Server_Hello final : public Handshake_Message RandomNumberGenerator& rng, const std::vector<byte>& secure_reneg_info, const Client_Hello& client_hello, - const std::vector<byte>& new_session_id, - Protocol_Version new_session_version, - u16bit ciphersuite, - byte compression, - bool offer_session_ticket, - const std::string& next_protocol); + const Server_Hello::Settings& settings, + const std::string next_protocol); Server_Hello(Handshake_IO& io, Handshake_Hash& hash, diff --git a/src/lib/tls/tls_record.cpp b/src/lib/tls/tls_record.cpp index eacf313a8..e028c43a0 100644 --- a/src/lib/tls/tls_record.cpp +++ b/src/lib/tls/tls_record.cpp @@ -1,6 +1,7 @@ /* * TLS Record Handling * (C) 2012,2013,2014,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -152,7 +153,7 @@ Connection_Cipher_State::format_ad(u64bit msg_sequence, } void write_record(secure_vector<byte>& output, - byte msg_type, const byte msg[], size_t msg_length, + Record_Message msg, Protocol_Version version, u64bit seq, Connection_Cipher_State* cs, @@ -160,7 +161,7 @@ void write_record(secure_vector<byte>& output, { output.clear(); - output.push_back(msg_type); + output.push_back(msg.get_type()); output.push_back(version.major_version()); output.push_back(version.minor_version()); @@ -172,17 +173,17 @@ void write_record(secure_vector<byte>& output, if(!cs) // initial unencrypted handshake records { - output.push_back(get_byte(0, static_cast<u16bit>(msg_length))); - output.push_back(get_byte(1, static_cast<u16bit>(msg_length))); + output.push_back(get_byte<u16bit>(0, static_cast<u16bit>(msg.get_size()))); + output.push_back(get_byte<u16bit>(1, static_cast<u16bit>(msg.get_size()))); - output.insert(output.end(), msg, msg + msg_length); + output.insert(output.end(), msg.get_data(), msg.get_data() + msg.get_size()); return; } if(AEAD_Mode* aead = cs->aead()) { - const size_t ctext_size = aead->output_length(msg_length); + const size_t ctext_size = aead->output_length(msg.get_size()); const std::vector<byte> nonce = cs->aead_nonce(seq); @@ -193,17 +194,16 @@ void write_record(secure_vector<byte>& output, output.push_back(get_byte(0, static_cast<u16bit>(rec_size))); output.push_back(get_byte(1, static_cast<u16bit>(rec_size))); - aead->set_ad(cs->format_ad(seq, msg_type, version, static_cast<u16bit>(msg_length))); + aead->set_ad(cs->format_ad(seq, msg.get_type(), version, static_cast<u16bit>(msg.get_size()))); if(cs->nonce_bytes_from_record() > 0) { output += std::make_pair(&nonce[cs->nonce_bytes_from_handshake()], cs->nonce_bytes_from_record()); } - BOTAN_ASSERT(aead->start(nonce).empty(), "AEAD doesn't return anything from start"); const size_t offset = output.size(); - output += std::make_pair(msg, msg_length); + output += std::make_pair(msg.get_data(), msg.get_size()); aead->finish(output, offset); BOTAN_ASSERT(output.size() == offset + ctext_size, "Expected size"); @@ -213,16 +213,16 @@ void write_record(secure_vector<byte>& output, return; } - cs->mac()->update(cs->format_ad(seq, msg_type, version, static_cast<u16bit>(msg_length))); + cs->mac()->update(cs->format_ad(seq, msg.get_type(), version, static_cast<u16bit>(msg.get_size()))); - cs->mac()->update(msg, msg_length); + cs->mac()->update(msg.get_data(), msg.get_size()); const size_t block_size = cs->block_size(); const size_t iv_size = cs->iv_size(); const size_t mac_size = cs->mac_size(); const size_t buf_size = round_up( - iv_size + msg_length + mac_size + (block_size ? 1 : 0), + iv_size + msg.get_size() + mac_size + (block_size ? 1 : 0), block_size); if(buf_size > MAX_CIPHERTEXT_SIZE) @@ -239,7 +239,7 @@ void write_record(secure_vector<byte>& output, rng.randomize(&output[output.size() - iv_size], iv_size); } - output.insert(output.end(), msg, msg + msg_length); + output.insert(output.end(), msg.get_data(), msg.get_data() + msg.get_size()); output.resize(output.size() + mac_size); cs->mac()->final(&output[output.size() - mac_size]); @@ -247,7 +247,7 @@ void write_record(secure_vector<byte>& output, if(block_size) { const size_t pad_val = - buf_size - (iv_size + msg_length + mac_size + 1); + buf_size - (iv_size + msg.get_size() + mac_size + 1); for(size_t i = 0; i != pad_val + 1; ++i) output.push_back(static_cast<byte>(pad_val)); @@ -461,65 +461,58 @@ void decrypt_record(secure_vector<byte>& output, } size_t read_tls_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, + Record_Raw_Input& raw_input, + Record& rec, Connection_Sequence_Numbers* sequence_numbers, get_cipherstate_fn get_cipherstate) { - consumed = 0; - if(readbuf.size() < TLS_HEADER_SIZE) // header incomplete? { if(size_t needed = fill_buffer_to(readbuf, - input, input_sz, consumed, + raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), TLS_HEADER_SIZE)) return needed; BOTAN_ASSERT_EQUAL(readbuf.size(), TLS_HEADER_SIZE, "Have an entire header"); } - *record_version = Protocol_Version(readbuf[1], readbuf[2]); + *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]); - BOTAN_ASSERT(!record_version->is_datagram_protocol(), "Expected TLS"); + BOTAN_ASSERT(!rec.get_protocol_version()->is_datagram_protocol(), "Expected TLS"); - const size_t record_len = make_u16bit(readbuf[TLS_HEADER_SIZE-2], + const size_t record_size = make_u16bit(readbuf[TLS_HEADER_SIZE-2], readbuf[TLS_HEADER_SIZE-1]); - if(record_len > MAX_CIPHERTEXT_SIZE) + if(record_size > MAX_CIPHERTEXT_SIZE) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Received a record that exceeds maximum size"); - if(record_len == 0) + if(record_size == 0) throw TLS_Exception(Alert::DECODE_ERROR, "Received a completely empty record"); if(size_t needed = fill_buffer_to(readbuf, - input, input_sz, consumed, - TLS_HEADER_SIZE + record_len)) + raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), + TLS_HEADER_SIZE + record_size)) return needed; - BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_len, + BOTAN_ASSERT_EQUAL(static_cast<size_t>(TLS_HEADER_SIZE) + record_size, readbuf.size(), "Have the full record"); - *record_type = static_cast<Record_Type>(readbuf[0]); + *rec.get_type() = static_cast<Record_Type>(readbuf[0]); u16bit epoch = 0; if(sequence_numbers) { - *record_sequence = sequence_numbers->next_read_sequence(); + *rec.get_sequence() = sequence_numbers->next_read_sequence(); epoch = sequence_numbers->current_read_epoch(); } else { // server initial handshake case - *record_sequence = 0; + *rec.get_sequence() = 0; epoch = 0; } @@ -527,7 +520,7 @@ size_t read_tls_record(secure_vector<byte>& readbuf, if(epoch == 0) // Unencrypted initial handshake { - record.assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_len); + rec.get_data().assign(readbuf.begin() + TLS_HEADER_SIZE, readbuf.begin() + TLS_HEADER_SIZE + record_size); readbuf.clear(); return 0; // got a full record } @@ -537,37 +530,30 @@ size_t read_tls_record(secure_vector<byte>& readbuf, BOTAN_ASSERT(cs, "Have cipherstate for this epoch"); - decrypt_record(record, + decrypt_record(rec.get_data(), record_contents, - record_len, - *record_sequence, - *record_version, - *record_type, + record_size, + *rec.get_sequence(), + *rec.get_protocol_version(), + *rec.get_type(), *cs); if(sequence_numbers) - sequence_numbers->read_accept(*record_sequence); + sequence_numbers->read_accept(*rec.get_sequence()); readbuf.clear(); return 0; } size_t read_dtls_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, + Record_Raw_Input& raw_input, + Record& rec, Connection_Sequence_Numbers* sequence_numbers, get_cipherstate_fn get_cipherstate) { - consumed = 0; - if(readbuf.size() < DTLS_HEADER_SIZE) // header incomplete? { - if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE)) + if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE)) { readbuf.clear(); return 0; @@ -576,38 +562,35 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, BOTAN_ASSERT_EQUAL(readbuf.size(), DTLS_HEADER_SIZE, "Have an entire header"); } - *record_version = Protocol_Version(readbuf[1], readbuf[2]); + *rec.get_protocol_version() = Protocol_Version(readbuf[1], readbuf[2]); - BOTAN_ASSERT(record_version->is_datagram_protocol(), "Expected DTLS"); + BOTAN_ASSERT(rec.get_protocol_version()->is_datagram_protocol(), "Expected DTLS"); - const size_t record_len = make_u16bit(readbuf[DTLS_HEADER_SIZE-2], - readbuf[DTLS_HEADER_SIZE-1]); + const size_t record_size = make_u16bit(readbuf[DTLS_HEADER_SIZE-2], + readbuf[DTLS_HEADER_SIZE-1]); - // Invalid packet: - if(record_len == 0 || record_len > MAX_CIPHERTEXT_SIZE) - { - readbuf.clear(); - return 0; - } + if(record_size > MAX_CIPHERTEXT_SIZE) + throw TLS_Exception(Alert::RECORD_OVERFLOW, + "Got message that exceeds maximum size"); - if(fill_buffer_to(readbuf, input, input_sz, consumed, DTLS_HEADER_SIZE + record_len)) + if(fill_buffer_to(readbuf, raw_input.get_data(), raw_input.get_size(), raw_input.get_consumed(), DTLS_HEADER_SIZE + record_size)) { // Truncated packet? readbuf.clear(); return 0; } - BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_len, readbuf.size(), + BOTAN_ASSERT_EQUAL(static_cast<size_t>(DTLS_HEADER_SIZE) + record_size, readbuf.size(), "Have the full record"); - *record_type = static_cast<Record_Type>(readbuf[0]); + *rec.get_type() = static_cast<Record_Type>(readbuf[0]); u16bit epoch = 0; - *record_sequence = load_be<u64bit>(&readbuf[3], 0); - epoch = (*record_sequence >> 48); + *rec.get_sequence() = load_be<u64bit>(&readbuf[3], 0); + epoch = (*rec.get_sequence() >> 48); - if(sequence_numbers && sequence_numbers->already_seen(*record_sequence)) + if(sequence_numbers && sequence_numbers->already_seen(*rec.get_sequence())) { readbuf.clear(); return 0; @@ -617,7 +600,7 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, if(epoch == 0) // Unencrypted initial handshake { - record.assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_len); + rec.get_data().assign(readbuf.begin() + DTLS_HEADER_SIZE, readbuf.begin() + DTLS_HEADER_SIZE + record_size); readbuf.clear(); return 0; // got a full record } @@ -629,23 +612,23 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, BOTAN_ASSERT(cs, "Have cipherstate for this epoch"); - decrypt_record(record, + decrypt_record(rec.get_data(), record_contents, - record_len, - *record_sequence, - *record_version, - *record_type, + record_size, + *rec.get_sequence(), + *rec.get_protocol_version(), + *rec.get_type(), *cs); } catch(std::exception) { readbuf.clear(); - *record_type = NO_RECORD; + *rec.get_type() = NO_RECORD; return 0; } if(sequence_numbers) - sequence_numbers->read_accept(*record_sequence); + sequence_numbers->read_accept(*rec.get_sequence()); readbuf.clear(); return 0; @@ -654,24 +637,16 @@ size_t read_dtls_record(secure_vector<byte>& readbuf, } size_t read_record(secure_vector<byte>& readbuf, - const byte input[], - size_t input_sz, - bool is_datagram, - size_t& consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, + Record_Raw_Input& raw_input, + Record& rec, Connection_Sequence_Numbers* sequence_numbers, get_cipherstate_fn get_cipherstate) { - if(is_datagram) - return read_dtls_record(readbuf, input, input_sz, consumed, - record, record_sequence, record_version, record_type, + if(raw_input.is_datagram()) + return read_dtls_record(readbuf, raw_input, rec, sequence_numbers, get_cipherstate); else - return read_tls_record(readbuf, input, input_sz, consumed, - record, record_sequence, record_version, record_type, + return read_tls_record(readbuf, raw_input, rec, sequence_numbers, get_cipherstate); } diff --git a/src/lib/tls/tls_record.h b/src/lib/tls/tls_record.h index e3b0b9b58..c16b919b5 100644 --- a/src/lib/tls/tls_record.h +++ b/src/lib/tls/tls_record.h @@ -1,6 +1,7 @@ /* * TLS Record Handling * (C) 2004-2012 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -90,6 +91,80 @@ class Connection_Cipher_State size_t m_iv_size = 0; }; +class Record + { + public: + Record(secure_vector<byte>& data, + u64bit* sequence, + Protocol_Version* protocol_version, + Record_Type* type) + : m_data(data), m_sequence(sequence), m_protocol_version(protocol_version), + m_type(type), m_size(data.size()) {}; + + secure_vector<byte>& get_data() { return m_data; } + + Protocol_Version* get_protocol_version() { return m_protocol_version; } + + u64bit* get_sequence() { return m_sequence; } + + Record_Type* get_type() { return m_type; } + + size_t& get_size() { return m_size; } + + private: + secure_vector<byte>& m_data; + u64bit* m_sequence; + Protocol_Version* m_protocol_version; + Record_Type* m_type; + size_t m_size; + }; + +class Record_Message + { + public: + Record_Message(const byte* data, size_t size) + : m_type(0), m_sequence(0), m_data(data), m_size(size) {}; + Record_Message(byte type, u64bit sequence, const byte* data, size_t size) + : m_type(type), m_sequence(sequence), m_data(data), + m_size(size) {}; + + byte& get_type() { return m_type; }; + u64bit& get_sequence() { return m_sequence; }; + const byte* get_data() { return m_data; }; + size_t& get_size() { return m_size; }; + + private: + byte m_type; + u64bit m_sequence; + const byte* m_data; + size_t m_size; +}; + +class Record_Raw_Input + { + public: + Record_Raw_Input(const byte* data, size_t size, size_t& consumed, + bool is_datagram) + : m_data(data), m_size(size), m_consumed(consumed), + m_is_datagram(is_datagram) {}; + + const byte*& get_data() { return m_data; }; + + size_t& get_size() { return m_size; }; + + size_t& get_consumed() { return m_consumed; }; + void set_consumed(size_t consumed) { m_consumed = consumed; } + + bool is_datagram() { return m_is_datagram; }; + + private: + const byte* m_data; + size_t m_size; + size_t& m_consumed; + bool m_is_datagram; + }; + + /** * Create a TLS record * @param write_buffer the output record is placed here @@ -103,7 +178,7 @@ class Connection_Cipher_State * @return number of bytes written to write_buffer */ void write_record(secure_vector<byte>& write_buffer, - byte msg_type, const byte msg[], size_t msg_length, + Record_Message rec_msg, Protocol_Version version, u64bit msg_sequence, Connection_Cipher_State* cipherstate, @@ -117,14 +192,8 @@ typedef std::function<std::shared_ptr<Connection_Cipher_State> (u16bit)> get_cip * @return zero if full message, else number of bytes still needed */ size_t read_record(secure_vector<byte>& read_buffer, - const byte input[], - size_t input_length, - bool is_datagram, - size_t& input_consumed, - secure_vector<byte>& record, - u64bit* record_sequence, - Protocol_Version* record_version, - Record_Type* record_type, + Record_Raw_Input& raw_input, + Record& rec, Connection_Sequence_Numbers* sequence_numbers, get_cipherstate_fn get_cipherstate); diff --git a/src/lib/tls/tls_server.cpp b/src/lib/tls/tls_server.cpp index 41b14ae08..2e546eab1 100644 --- a/src/lib/tls/tls_server.cpp +++ b/src/lib/tls/tls_server.cpp @@ -1,6 +1,7 @@ /* * TLS Server * (C) 2004-2011,2012,2016 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -9,30 +10,41 @@ #include <botan/internal/tls_handshake_state.h> #include <botan/internal/tls_messages.h> #include <botan/internal/stl_util.h> +#include <botan/tls_magic.h> namespace Botan { namespace TLS { -namespace { - class Server_Handshake_State : public Handshake_State { public: - // using Handshake_State::Handshake_State; + Server_Handshake_State(Handshake_IO* io, Callbacks& cb) + : Handshake_State(io, cb) {} + + Private_Key* server_rsa_kex_key() { return m_server_rsa_kex_key; } + void set_server_rsa_kex_key(Private_Key* key) + { m_server_rsa_kex_key = key; } + + bool allow_session_resumption() const + { return m_allow_session_resumption; } + void set_allow_session_resumption(bool allow_session_resumption) + { m_allow_session_resumption = allow_session_resumption; } - Server_Handshake_State(Handshake_IO* io, handshake_msg_cb cb) : Handshake_State(io, cb) {} + private: // Used by the server only, in case of RSA key exchange. Not owned - Private_Key* server_rsa_kex_key = nullptr; + Private_Key* m_server_rsa_kex_key = nullptr; /* * Used by the server to know if resumption should be allowed on * a server-initiated renegotiation */ - bool allow_session_resumption = true; + bool m_allow_session_resumption = true; }; +namespace { + bool check_for_resume(Session& session_info, Session_Manager& session_manager, Credentials_Manager& credentials, @@ -225,6 +237,19 @@ get_server_certs(const std::string& hostname, /* * TLS Server Constructor */ +Server::Server(Callbacks& callbacks, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + bool is_datagram, + size_t io_buf_sz) : + Channel(callbacks, session_manager, rng, policy, + is_datagram, io_buf_sz), + m_creds(creds) + { + } + Server::Server(output_fn output, data_cb data_cb, alert_cb alert_cb, @@ -236,13 +261,15 @@ Server::Server(output_fn output, next_protocol_fn next_proto, bool is_datagram, size_t io_buf_sz) : - Channel(output, data_cb, alert_cb, handshake_cb, Channel::handshake_msg_cb(), - session_manager, rng, policy, is_datagram, io_buf_sz), + Channel(output, data_cb, alert_cb, handshake_cb, + Channel::handshake_msg_cb(), session_manager, + rng, policy, is_datagram, io_buf_sz), m_creds(creds), m_choose_next_protocol(next_proto) { } + Server::Server(output_fn output, data_cb data_cb, alert_cb alert_cb, @@ -263,8 +290,7 @@ Server::Server(output_fn output, Handshake_State* Server::new_handshake_state(Handshake_IO* io) { - std::unique_ptr<Handshake_State> state( - new Server_Handshake_State(io, get_handshake_msg_cb())); + std::unique_ptr<Handshake_State> state(new Server_Handshake_State(io, callbacks())); state->set_expected_next(CLIENT_HELLO); return state.release(); @@ -284,441 +310,516 @@ Server::get_peer_cert_chain(const Handshake_State& state) const void Server::initiate_handshake(Handshake_State& state, bool force_full_renegotiation) { - dynamic_cast<Server_Handshake_State&>(state).allow_session_resumption = - !force_full_renegotiation; + dynamic_cast<Server_Handshake_State&>(state). + set_allow_session_resumption(!force_full_renegotiation); Hello_Request hello_req(state.handshake_io()); } /* -* Process a handshake message +* Process a CLIENT HELLO Message */ -void Server::process_handshake_msg(const Handshake_State* active_state, - Handshake_State& state_base, - Handshake_Type type, - const std::vector<byte>& contents) - { - Server_Handshake_State& state = dynamic_cast<Server_Handshake_State&>(state_base); - - state.confirm_transition_to(type); - - /* - * The change cipher spec message isn't technically a handshake - * message so it's not included in the hash. The finished and - * certificate verify messages are verified based on the current - * state of the hash *before* this message so we delay adding them - * to the hash computation until we've processed them below. - */ - if(type != HANDSHAKE_CCS && type != FINISHED && type != CERTIFICATE_VERIFY) +void Server::process_client_hello_msg(const Handshake_State* active_state, + Server_Handshake_State& pending_state, + const std::vector<byte>& contents) +{ + const bool initial_handshake = !active_state; + + if(!policy().allow_insecure_renegotiation() && + !(initial_handshake || secure_renegotiation_supported())) { - state.hash().update(state.handshake_io().format(contents, type)); + send_warning_alert(Alert::NO_RENEGOTIATION); + return; } - if(type == CLIENT_HELLO) + pending_state.client_hello(new Client_Hello(contents)); + const Protocol_Version client_version = pending_state.client_hello()->version(); + + Protocol_Version negotiated_version; + + const Protocol_Version latest_supported = + policy().latest_supported_version(client_version.is_datagram_protocol()); + + if((initial_handshake && client_version.known_version()) || + (!initial_handshake && client_version == active_state->version())) { - const bool initial_handshake = !active_state; + /* + Common cases: new client hello with some known version, or a + renegotiation using the same version as previously + negotiated. + */ - if(!policy().allow_insecure_renegotiation() && - !(initial_handshake || secure_renegotiation_supported())) + negotiated_version = client_version; + } + else if(!initial_handshake && (client_version != active_state->version())) + { + /* + * If this is a renegotiation, and the client has offered a + * later version than what it initially negotiated, negotiate + * the old version. This matches OpenSSL's behavior. If the + * client is offering a version earlier than what it initially + * negotiated, reject as a probable attack. + */ + if(active_state->version() > client_version) { - send_warning_alert(Alert::NO_RENEGOTIATION); - return; + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Client negotiated " + + active_state->version().to_string() + + " then renegotiated with " + + client_version.to_string()); } + else + negotiated_version = active_state->version(); + } + else + { + /* + New negotiation using a version we don't know. Offer them the + best we currently know and support + */ + negotiated_version = latest_supported; + } - state.client_hello(new Client_Hello(contents)); + if(!policy().acceptable_protocol_version(negotiated_version)) + { + throw TLS_Exception(Alert::PROTOCOL_VERSION, + "Client version " + negotiated_version.to_string() + + " is unacceptable by policy"); + } - const Protocol_Version client_version = state.client_hello()->version(); + if(pending_state.client_hello()->sent_fallback_scsv()) + { + if(latest_supported > client_version) + throw TLS_Exception(Alert::INAPPROPRIATE_FALLBACK, + "Client signalled fallback SCSV, possible attack"); + } - Protocol_Version negotiated_version; + secure_renegotiation_check(pending_state.client_hello()); - const Protocol_Version latest_supported = - policy().latest_supported_version(client_version.is_datagram_protocol()); + pending_state.set_version(negotiated_version); - if((initial_handshake && client_version.known_version()) || - (!initial_handshake && client_version == active_state->version())) - { - /* - Common cases: new client hello with some known version, or a - renegotiation using the same version as previously - negotiated. - */ + Session session_info; + const bool resuming = + pending_state.allow_session_resumption() && + check_for_resume(session_info, + session_manager(), + m_creds, + pending_state.client_hello(), + std::chrono::seconds(policy().session_ticket_lifetime())); - negotiated_version = client_version; - } - else if(!initial_handshake && (client_version != active_state->version())) - { - /* - * If this is a renegotiation, and the client has offered a - * later version than what it initially negotiated, negotiate - * the old version. This matches OpenSSL's behavior. If the - * client is offering a version earlier than what it initially - * negotiated, reject as a probable attack. - */ - if(active_state->version() > client_version) - { - throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Client negotiated " + - active_state->version().to_string() + - " then renegotiated with " + - client_version.to_string()); - } - else - negotiated_version = active_state->version(); - } - else - { - /* - New negotiation using a version we don't know. Offer them the - best we currently know and support - */ - negotiated_version = latest_supported; - } + bool have_session_ticket_key = false; - if(!policy().acceptable_protocol_version(negotiated_version)) - { - throw TLS_Exception(Alert::PROTOCOL_VERSION, - "Client version " + negotiated_version.to_string() + - " is unacceptable by policy"); - } + try + { + have_session_ticket_key = + m_creds.psk("tls-server", "session-ticket", "").length() > 0; + } + catch(...) {} + + m_next_protocol = ""; + if(pending_state.client_hello()->supports_alpn()) + { + m_next_protocol = callbacks().tls_server_choose_app_protocol(pending_state.client_hello()->next_protocols()); - if(state.client_hello()->sent_fallback_scsv()) + // if the callback return was empty, fall back to the (deprecated) std::function + if(m_next_protocol.empty() && m_choose_next_protocol) { - if(latest_supported > client_version) - throw TLS_Exception(Alert::INAPPROPRIATE_FALLBACK, - "Client signalled fallback SCSV, possible attack"); + m_next_protocol = m_choose_next_protocol(pending_state.client_hello()->next_protocols()); } + } - secure_renegotiation_check(state.client_hello()); + if(resuming) + { + this->session_resume(pending_state, have_session_ticket_key, session_info); + } + else // new session + { + this->session_create(pending_state, have_session_ticket_key); + } +} - state.set_version(negotiated_version); +void Server::process_certificate_msg(Server_Handshake_State& pending_state, + const std::vector<byte>& contents) +{ + pending_state.client_certs(new Certificate(contents)); + pending_state.set_expected_next(CLIENT_KEX); +} - Session session_info; - const bool resuming = - state.allow_session_resumption && - check_for_resume(session_info, - session_manager(), - m_creds, - state.client_hello(), - std::chrono::seconds(policy().session_ticket_lifetime())); +void Server::process_client_key_exchange_msg(Server_Handshake_State& pending_state, + const std::vector<byte>& contents) +{ + if(pending_state.received_handshake_msg(CERTIFICATE) && !pending_state.client_certs()->empty()) + pending_state.set_expected_next(CERTIFICATE_VERIFY); + else + pending_state.set_expected_next(HANDSHAKE_CCS); - bool have_session_ticket_key = false; + pending_state.client_kex( + new Client_Key_Exchange(contents, pending_state, + pending_state.server_rsa_kex_key(), + m_creds, policy(), rng()) + ); - try - { - have_session_ticket_key = - m_creds.psk("tls-server", "session-ticket", "").length() > 0; - } - catch(...) {} + pending_state.compute_session_keys(); +} - m_next_protocol = ""; - if(m_choose_next_protocol && state.client_hello()->supports_alpn()) - m_next_protocol = m_choose_next_protocol(state.client_hello()->next_protocols()); +void Server::process_change_cipher_spec_msg(Server_Handshake_State& pending_state) +{ + pending_state.set_expected_next(FINISHED); + change_cipher_spec_reader(SERVER); +} - if(resuming) - { - // Only offer a resuming client a new ticket if they didn't send one this time, - // ie, resumed via server-side resumption. TODO: also send one if expiring soon? - - const bool offer_new_session_ticket = - (state.client_hello()->supports_session_ticket() && - state.client_hello()->session_ticket().empty() && - have_session_ticket_key); - - state.server_hello(new Server_Hello( - state.handshake_io(), - state.hash(), - policy(), - rng(), - secure_renegotiation_data_for_server_hello(), - *state.client_hello(), - session_info, - offer_new_session_ticket, - m_next_protocol - )); - - secure_renegotiation_check(state.server_hello()); - - state.compute_session_keys(session_info.master_secret()); - - if(!save_session(session_info)) - { - session_manager().remove_entry(session_info.session_id()); - - if(state.server_hello()->supports_session_ticket()) // send an empty ticket - { - state.new_session_ticket( - new New_Session_Ticket(state.handshake_io(), - state.hash()) - ); - } - } +void Server::process_certificate_verify_msg(Server_Handshake_State& pending_state, + Handshake_Type type, + const std::vector<byte>& contents) +{ + pending_state.client_verify ( new Certificate_Verify ( contents, pending_state.version() ) ); + + const std::vector<X509_Certificate>& client_certs = + pending_state.client_certs()->cert_chain(); + + const bool sig_valid = + pending_state.client_verify()->verify ( client_certs[0], pending_state, policy() ); + + pending_state.hash().update ( pending_state.handshake_io().format ( contents, type ) ); + + /* + * Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for + * "A handshake cryptographic operation failed, including being + * unable to correctly verify a signature, ..." + */ + if ( !sig_valid ) + throw TLS_Exception ( Alert::DECRYPT_ERROR, "Client cert verify failed" ); + + try + { + m_creds.verify_certificate_chain ( "tls-server", "", client_certs ); + } + catch ( std::exception& e ) + { + throw TLS_Exception ( Alert::BAD_CERTIFICATE, e.what() ); + } + + pending_state.set_expected_next ( HANDSHAKE_CCS ); +} - if(state.server_hello()->supports_session_ticket() && !state.new_session_ticket()) - { - try - { - const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); - - state.new_session_ticket( - new New_Session_Ticket(state.handshake_io(), - state.hash(), - session_info.encrypt(ticket_key, rng()), - policy().session_ticket_lifetime()) - ); - } - catch(...) {} - - if(!state.new_session_ticket()) - { - state.new_session_ticket( - new New_Session_Ticket(state.handshake_io(), state.hash()) - ); - } - } +void Server::process_finished_msg(Server_Handshake_State& pending_state, + Handshake_Type type, + const std::vector<byte>& contents) +{ + pending_state.set_expected_next ( HANDSHAKE_NONE ); - state.handshake_io().send(Change_Cipher_Spec()); + pending_state.client_finished ( new Finished ( contents ) ); - change_cipher_spec_writer(SERVER); + if ( !pending_state.client_finished()->verify ( pending_state, CLIENT ) ) + throw TLS_Exception ( Alert::DECRYPT_ERROR, + "Finished message didn't verify" ); - state.server_finished(new Finished(state.handshake_io(), state, SERVER)); - state.set_expected_next(HANDSHAKE_CCS); - } - else // new session - { - std::map<std::string, std::vector<X509_Certificate> > cert_chains; + if ( !pending_state.server_finished() ) + { + // already sent finished if resuming, so this is a new session - const std::string sni_hostname = state.client_hello()->sni_hostname(); + pending_state.hash().update ( pending_state.handshake_io().format ( contents, type ) ); - cert_chains = get_server_certs(sni_hostname, m_creds); + Session session_info( + pending_state.server_hello()->session_id(), + pending_state.session_keys().master_secret(), + pending_state.server_hello()->version(), + pending_state.server_hello()->ciphersuite(), + pending_state.server_hello()->compression_method(), + SERVER, + pending_state.server_hello()->supports_extended_master_secret(), + get_peer_cert_chain ( pending_state ), + std::vector<byte>(), + Server_Information(pending_state.client_hello()->sni_hostname()), + pending_state.srp_identifier(), + pending_state.server_hello()->srtp_profile() + ); - if(sni_hostname != "" && cert_chains.empty()) + if ( save_session ( session_info ) ) { - cert_chains = get_server_certs("", m_creds); - - /* - * Only send the unrecognized_name alert if we couldn't - * find any certs for the requested name but did find at - * least one cert to use in general. That avoids sending an - * unrecognized_name when a server is configured for purely - * anonymous operation. - */ - if(!cert_chains.empty()) - send_alert(Alert(Alert::UNRECOGNIZED_NAME)); + if ( pending_state.server_hello()->supports_session_ticket() ) + { + try + { + const SymmetricKey ticket_key = m_creds.psk ( "tls-server", "session-ticket", "" ); + + pending_state.new_session_ticket ( + new New_Session_Ticket ( pending_state.handshake_io(), + pending_state.hash(), + session_info.encrypt ( ticket_key, rng() ), + policy().session_ticket_lifetime() ) + ); + } + catch ( ... ) {} + } + else + session_manager().save ( session_info ); } - state.server_hello(new Server_Hello( - state.handshake_io(), - state.hash(), - policy(), - rng(), - secure_renegotiation_data_for_server_hello(), - *state.client_hello(), - make_hello_random(rng(), policy()), // new session ID - state.version(), - choose_ciphersuite(policy(), state.version(), m_creds, cert_chains, state.client_hello()), - choose_compression(policy(), state.client_hello()->compression_methods()), - have_session_ticket_key, - m_next_protocol) + if ( !pending_state.new_session_ticket() && + pending_state.server_hello()->supports_session_ticket() ) + { + pending_state.new_session_ticket ( + new New_Session_Ticket ( pending_state.handshake_io(), pending_state.hash() ) ); + } - secure_renegotiation_check(state.server_hello()); + pending_state.handshake_io().send ( Change_Cipher_Spec() ); - const std::string sig_algo = state.ciphersuite().sig_algo(); - const std::string kex_algo = state.ciphersuite().kex_algo(); + change_cipher_spec_writer ( SERVER ); - if(sig_algo != "") - { - BOTAN_ASSERT(!cert_chains[sig_algo].empty(), - "Attempting to send empty certificate chain"); + pending_state.server_finished ( new Finished ( pending_state.handshake_io(), pending_state, SERVER ) ); + } - state.server_certs(new Certificate(state.handshake_io(), - state.hash(), - cert_chains[sig_algo])); - } + activate_session(); - Private_Key* private_key = nullptr; +} - if(kex_algo == "RSA" || sig_algo != "") - { - private_key = m_creds.private_key_for( - state.server_certs()->cert_chain()[0], - "tls-server", - sni_hostname); +/* +* Process a handshake message +*/ +void Server::process_handshake_msg(const Handshake_State* active_state, + Handshake_State& state_base, + Handshake_Type type, + const std::vector<byte>& contents) + { + Server_Handshake_State& state = dynamic_cast<Server_Handshake_State&>(state_base); + state.confirm_transition_to(type); - if(!private_key) - throw Internal_Error("No private key located for associated server cert"); - } + /* + * The change cipher spec message isn't technically a handshake + * message so it's not included in the hash. The finished and + * certificate verify messages are verified based on the current + * state of the hash *before* this message so we delay adding them + * to the hash computation until we've processed them below. + */ + if(type != HANDSHAKE_CCS && type != FINISHED && type != CERTIFICATE_VERIFY) + { + state.hash().update(state.handshake_io().format(contents, type)); + } - if(kex_algo == "RSA") - { - state.server_rsa_kex_key = private_key; - } - else - { - state.server_kex(new Server_Key_Exchange(state.handshake_io(), - state, policy(), - m_creds, rng(), private_key)); - } + switch(type) + { + case CLIENT_HELLO: + this->process_client_hello_msg(active_state, state, contents); + break; - auto trusted_CAs = m_creds.trusted_certificate_authorities("tls-server", sni_hostname); + case CERTIFICATE: + this->process_certificate_msg(state, contents); + break; - std::vector<X509_DN> client_auth_CAs; + case CLIENT_KEX: + this->process_client_key_exchange_msg(state, contents); + break; + + case CERTIFICATE_VERIFY: + this->process_certificate_verify_msg(state, type, contents); + break; + + case HANDSHAKE_CCS: + this->process_change_cipher_spec_msg(state); + break; + + case FINISHED: + this->process_finished_msg(state, type, contents); + break; + + default: + throw Unexpected_Message("Unknown handshake message received"); + break; + } + } - for(auto store : trusted_CAs) +void Server::session_resume(Server_Handshake_State& pending_state, + bool have_session_ticket_key, + Session& session_info) + { + // Only offer a resuming client a new ticket if they didn't send one this time, + // ie, resumed via server-side resumption. TODO: also send one if expiring soon? + + const bool offer_new_session_ticket = + (pending_state.client_hello()->supports_session_ticket() && + pending_state.client_hello()->session_ticket().empty() && + have_session_ticket_key); + + pending_state.server_hello(new Server_Hello( + pending_state.handshake_io(), + pending_state.hash(), + policy(), + rng(), + secure_renegotiation_data_for_server_hello(), + *pending_state.client_hello(), + session_info, + offer_new_session_ticket, + m_next_protocol + )); + + secure_renegotiation_check(pending_state.server_hello()); + + pending_state.compute_session_keys(session_info.master_secret()); + + if(!save_session(session_info)) + { + session_manager().remove_entry(session_info.session_id()); + + if(pending_state.server_hello()->supports_session_ticket()) // send an empty ticket { - auto subjects = store->all_subjects(); - client_auth_CAs.insert(client_auth_CAs.end(), subjects.begin(), subjects.end()); + pending_state.new_session_ticket( + new New_Session_Ticket(pending_state.handshake_io(), + pending_state.hash()) + ); } + } - if(!client_auth_CAs.empty() && state.ciphersuite().sig_algo() != "") + if(pending_state.server_hello()->supports_session_ticket() && !pending_state.new_session_ticket()) + { + try { - state.cert_req( - new Certificate_Req(state.handshake_io(), state.hash(), - policy(), client_auth_CAs, state.version())); + const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); - state.set_expected_next(CERTIFICATE); + pending_state.new_session_ticket( + new New_Session_Ticket(pending_state.handshake_io(), + pending_state.hash(), + session_info.encrypt(ticket_key, rng()), + policy().session_ticket_lifetime()) + ); } + catch(...) {} - /* - * If the client doesn't have a cert they want to use they are - * allowed to send either an empty cert message or proceed - * directly to the client key exchange, so allow either case. - */ - state.set_expected_next(CLIENT_KEX); - - state.server_hello_done(new Server_Hello_Done(state.handshake_io(), state.hash())); + if(!pending_state.new_session_ticket()) + { + pending_state.new_session_ticket( + new New_Session_Ticket(pending_state.handshake_io(), pending_state.hash()) + ); + } } - } - else if(type == CERTIFICATE) - { - state.client_certs(new Certificate(contents)); - state.set_expected_next(CLIENT_KEX); - } - else if(type == CLIENT_KEX) - { - if(state.received_handshake_msg(CERTIFICATE) && !state.client_certs()->empty()) - state.set_expected_next(CERTIFICATE_VERIFY); - else - state.set_expected_next(HANDSHAKE_CCS); + pending_state.handshake_io().send(Change_Cipher_Spec()); - state.client_kex( - new Client_Key_Exchange(contents, state, - state.server_rsa_kex_key, - m_creds, policy(), rng()) - ); + change_cipher_spec_writer(SERVER); - state.compute_session_keys(); - } - else if(type == CERTIFICATE_VERIFY) - { - state.client_verify(new Certificate_Verify(contents, state.version())); + pending_state.server_finished(new Finished(pending_state.handshake_io(), pending_state, SERVER)); + pending_state.set_expected_next(HANDSHAKE_CCS); + } + +void Server::session_create(Server_Handshake_State& pending_state, + bool have_session_ticket_key) + { + std::map<std::string, std::vector<X509_Certificate> > cert_chains; - const std::vector<X509_Certificate>& client_certs = - state.client_certs()->cert_chain(); + const std::string sni_hostname = pending_state.client_hello()->sni_hostname(); - const bool sig_valid = - state.client_verify()->verify(client_certs[0], state, policy()); + cert_chains = get_server_certs(sni_hostname, m_creds); - state.hash().update(state.handshake_io().format(contents, type)); + if(sni_hostname != "" && cert_chains.empty()) + { + cert_chains = get_server_certs("", m_creds); /* - * Using DECRYPT_ERROR looks weird here, but per RFC 4346 is for - * "A handshake cryptographic operation failed, including being - * unable to correctly verify a signature, ..." + * Only send the unrecognized_name alert if we couldn't + * find any certs for the requested name but did find at + * least one cert to use in general. That avoids sending an + * unrecognized_name when a server is configured for purely + * anonymous operation. */ - if(!sig_valid) - throw TLS_Exception(Alert::DECRYPT_ERROR, "Client cert verify failed"); - - try - { - m_creds.verify_certificate_chain("tls-server", "", client_certs); - } - catch(std::exception& e) - { - throw TLS_Exception(Alert::BAD_CERTIFICATE, e.what()); - } - - state.set_expected_next(HANDSHAKE_CCS); + if(!cert_chains.empty()) + send_alert(Alert(Alert::UNRECOGNIZED_NAME)); } - else if(type == HANDSHAKE_CCS) + + Server_Hello::Settings srv_settings( + make_hello_random(rng(), policy()), // new session ID + pending_state.version(), + choose_ciphersuite(policy(), + pending_state.version(), + m_creds, + cert_chains, + pending_state.client_hello()), + choose_compression(policy(), + pending_state.client_hello()->compression_methods()), + have_session_ticket_key); + + pending_state.server_hello(new Server_Hello( + pending_state.handshake_io(), + pending_state.hash(), + policy(), + rng(), + secure_renegotiation_data_for_server_hello(), + *pending_state.client_hello(), + srv_settings, + m_next_protocol) + ); + + secure_renegotiation_check(pending_state.server_hello()); + + const std::string sig_algo = pending_state.ciphersuite().sig_algo(); + const std::string kex_algo = pending_state.ciphersuite().kex_algo(); + + if(sig_algo != "") { - state.set_expected_next(FINISHED); - change_cipher_spec_reader(SERVER); + BOTAN_ASSERT(!cert_chains[sig_algo].empty(), + "Attempting to send empty certificate chain"); + + pending_state.server_certs(new Certificate(pending_state.handshake_io(), + pending_state.hash(), + cert_chains[sig_algo])); } - else if(type == FINISHED) - { - state.set_expected_next(HANDSHAKE_NONE); - state.client_finished(new Finished(contents)); + Private_Key* private_key = nullptr; - if(!state.client_finished()->verify(state, CLIENT)) - throw TLS_Exception(Alert::DECRYPT_ERROR, - "Finished message didn't verify"); + if(kex_algo == "RSA" || sig_algo != "") + { + private_key = m_creds.private_key_for( + pending_state.server_certs()->cert_chain()[0], + "tls-server", + sni_hostname); - if(!state.server_finished()) - { - // already sent finished if resuming, so this is a new session + if(!private_key) + throw Internal_Error("No private key located for associated server cert"); + } - state.hash().update(state.handshake_io().format(contents, type)); + if(kex_algo == "RSA") + { + pending_state.set_server_rsa_kex_key(private_key); + } + else + { - Session session_info( - state.server_hello()->session_id(), - state.session_keys().master_secret(), - state.server_hello()->version(), - state.server_hello()->ciphersuite(), - state.server_hello()->compression_method(), - SERVER, - state.server_hello()->supports_extended_master_secret(), - get_peer_cert_chain(state), - std::vector<byte>(), - Server_Information(state.client_hello()->sni_hostname()), - state.srp_identifier(), - state.server_hello()->srtp_profile() - ); + pending_state.server_kex(new Server_Key_Exchange(pending_state.handshake_io(), + pending_state, policy(), + m_creds, rng(), private_key)); + } - if(save_session(session_info)) - { - if(state.server_hello()->supports_session_ticket()) - { - try - { - const SymmetricKey ticket_key = m_creds.psk("tls-server", "session-ticket", ""); - - state.new_session_ticket( - new New_Session_Ticket(state.handshake_io(), - state.hash(), - session_info.encrypt(ticket_key, rng()), - policy().session_ticket_lifetime()) - ); - } - catch(...) {} - } - else - session_manager().save(session_info); - } + auto trusted_CAs = m_creds.trusted_certificate_authorities("tls-server", sni_hostname); - if(!state.new_session_ticket() && - state.server_hello()->supports_session_ticket()) - { - state.new_session_ticket( - new New_Session_Ticket(state.handshake_io(), state.hash()) - ); - } + std::vector<X509_DN> client_auth_CAs; - state.handshake_io().send(Change_Cipher_Spec()); + for(auto store : trusted_CAs) + { + auto subjects = store->all_subjects(); + client_auth_CAs.insert(client_auth_CAs.end(), subjects.begin(), subjects.end()); + } - change_cipher_spec_writer(SERVER); + if(!client_auth_CAs.empty() && pending_state.ciphersuite().sig_algo() != "") + { + pending_state.cert_req( + new Certificate_Req(pending_state.handshake_io(), + pending_state.hash(), + policy(), + client_auth_CAs, + pending_state.version())); + + pending_state.set_expected_next(CERTIFICATE); + } - state.server_finished(new Finished(state.handshake_io(), state, SERVER)); - } + /* + * If the client doesn't have a cert they want to use they are + * allowed to send either an empty cert message or proceed + * directly to the client key exchange, so allow either case. + */ + pending_state.set_expected_next(CLIENT_KEX); - activate_session(); - } - else - throw Unexpected_Message("Unknown handshake message received"); + pending_state.server_hello_done(new Server_Hello_Done(pending_state.handshake_io(), pending_state.hash())); } - } } diff --git a/src/lib/tls/tls_server.h b/src/lib/tls/tls_server.h index 5ea2a1318..051eda445 100644 --- a/src/lib/tls/tls_server.h +++ b/src/lib/tls/tls_server.h @@ -1,6 +1,7 @@ /* * TLS Server * (C) 2004-2011 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -12,10 +13,13 @@ #include <botan/credentials_manager.h> #include <vector> + namespace Botan { namespace TLS { +class Server_Handshake_State; + /** * TLS Server */ @@ -26,7 +30,39 @@ class BOTAN_DLL Server final : public Channel /** * Server initialization + * + * @param callbacks contains a set of callback function references + * required by the TLS client. + * + * @param session_manager manages session state + * + * @param creds manages application/user credentials + * + * @param policy specifies other connection policy information + * + * @param rng a random number generator + * + * @param is_datagram set to true if this server should expect DTLS + * connections. Otherwise TLS connections are expected. + * + * @param reserved_io_buffer_size This many bytes of memory will + * be preallocated for the read and write buffers. Smaller + * values just mean reallocations and copies are more likely. */ + Server(Callbacks& callbacks, + Session_Manager& session_manager, + Credentials_Manager& creds, + const Policy& policy, + RandomNumberGenerator& rng, + bool is_datagram = false, + size_t reserved_io_buffer_size = TLS::Server::IO_BUF_DEFAULT_SIZE + ); + + /** + * DEPRECATED. This constructor is only provided for backward + * compatibility and should not be used in new implementations. + */ + BOTAN_DEPRECATED("Use TLS::Server(TLS::Callbacks ...)") Server(output_fn output, data_cb data_cb, alert_cb alert_cb, @@ -37,9 +73,14 @@ class BOTAN_DLL Server final : public Channel RandomNumberGenerator& rng, next_protocol_fn next_proto = next_protocol_fn(), bool is_datagram = false, - size_t reserved_io_buffer_size = 16*1024 + size_t reserved_io_buffer_size = TLS::Server::IO_BUF_DEFAULT_SIZE ); + /** + * DEPRECATED. This constructor is only provided for backward + * compatibility and should not be used in new implementations. + */ + BOTAN_DEPRECATED("Use TLS::Server(TLS::Callbacks ...)") Server(output_fn output, data_cb data_cb, alert_cb alert_cb, @@ -73,12 +114,40 @@ class BOTAN_DLL Server final : public Channel Handshake_Type type, const std::vector<byte>& contents) override; + void process_client_hello_msg(const Handshake_State* active_state, + Server_Handshake_State& pending_state, + const std::vector<byte>& contents); + + void process_certificate_msg(Server_Handshake_State& pending_state, + const std::vector<byte>& contents); + + void process_client_key_exchange_msg(Server_Handshake_State& pending_state, + const std::vector<byte>& contents); + + void process_change_cipher_spec_msg(Server_Handshake_State& pending_state); + + void process_certificate_verify_msg(Server_Handshake_State& pending_state, + Handshake_Type type, + const std::vector<byte>& contents); + + void process_finished_msg(Server_Handshake_State& pending_state, + Handshake_Type type, + const std::vector<byte>& contents); + + void session_resume(Server_Handshake_State& pending_state, + bool have_session_ticket_key, + Session& session_info); + + void session_create(Server_Handshake_State& pending_state, + bool have_session_ticket_key); + Handshake_State* new_handshake_state(Handshake_IO* io) override; Credentials_Manager& m_creds; + std::string m_next_protocol; + // Set by deprecated constructor, Server calls both this fn and Callbacks version next_protocol_fn m_choose_next_protocol; - std::string m_next_protocol; }; } diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index f125bfcb5..f0caa4aef 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -1,5 +1,6 @@ /* * (C) 2014,2015 Jack Lloyd +* 2016 Matthias Gierlings * * Botan is released under the Simplified BSD License (see license.txt) */ @@ -161,7 +162,11 @@ std::function<void (const byte[], size_t)> queue_inserter(std::vector<byte>& q) return [&](const byte buf[], size_t sz) { q.insert(q.end(), buf, buf + sz); }; } -void print_alert(Botan::TLS::Alert, const byte[], size_t) +void print_alert(Botan::TLS::Alert) + { + } + +void alert_cb_with_data(Botan::TLS::Alert, const byte[], size_t) { } @@ -218,164 +223,216 @@ Test::Result test_tls_handshake(Botan::TLS::Protocol_Version offer_version, { std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; - Botan::TLS::Server server(queue_inserter(s2c_traffic), - queue_inserter(server_recv), - print_alert, - handshake_complete, + std::unique_ptr<Botan::TLS::Callbacks> server_cb(new Botan::TLS::Compat_Callbacks( + queue_inserter(s2c_traffic), + queue_inserter(server_recv), + std::function<void (Botan::TLS::Alert, const byte[], size_t)>(alert_cb_with_data), + handshake_complete, + nullptr, + next_protocol_chooser)); + + // TLS::Server object constructed by new constructor using virtual callback interface. + std::unique_ptr<Botan::TLS::Server> server( + new Botan::TLS::Server(*server_cb, server_sessions, creds, policy, rng, - next_protocol_chooser, - false); + false)); - Botan::TLS::Client client(queue_inserter(c2s_traffic), - queue_inserter(client_recv), - print_alert, - handshake_complete, + std::unique_ptr<Botan::TLS::Callbacks> client_cb(new Botan::TLS::Compat_Callbacks( + queue_inserter(c2s_traffic), + queue_inserter(client_recv), + std::function<void (Botan::TLS::Alert, const byte[], size_t)>(alert_cb_with_data), + handshake_complete)); + + // TLS::Client object constructed by new constructor using virtual callback interface. + std::unique_ptr<Botan::TLS::Client> client( + new Botan::TLS::Client(*client_cb, client_sessions, creds, policy, rng, Botan::TLS::Server_Information("server.example.com"), offer_version, - protocols_offered); + protocols_offered)); size_t rounds = 0; - while(true) + // Test TLS using both new and legacy constructors. + for(size_t ctor_sel = 0; ctor_sel < 2; ctor_sel++) { - ++rounds; - - if(rounds > 25) + if(ctor_sel == 1) { - if(r <= 2) - result.test_failure("Still here after many rounds, deadlock?"); - break; + c2s_traffic.clear(); + s2c_traffic.clear(); + server_recv.clear(); + client_recv.clear(); + client_sent.clear(); + server_sent.clear(); + + // TLS::Server object constructed by legacy constructor. + server.reset( + new Botan::TLS::Server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + alert_cb_with_data, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + false)); + + // TLS::Client object constructed by legacy constructor. + client.reset( + new Botan::TLS::Client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + alert_cb_with_data, + handshake_complete, + client_sessions, + creds, + policy, + rng, + Botan::TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered)); } - if(handshake_done && (client.is_closed() || server.is_closed())) - break; - - if(client.is_active() && client_sent.empty()) + while(true) { - // Choose a len between 1 and 511 - const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); - client_sent = unlock(rng.random_vec(c_len)); + ++rounds; - // TODO send in several records - client.send(client_sent); - } + if(rounds > 25) + { + if(r <= 2) + result.test_failure("Still here after many rounds, deadlock?"); + break; + } - if(server.is_active() && server_sent.empty()) - { - result.test_eq("server protocol", server.next_protocol(), "test/3"); + if(handshake_done && (client->is_closed() || server->is_closed())) + break; - const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); - server_sent = unlock(rng.random_vec(s_len)); - server.send(server_sent); - } + if(client->is_active() && client_sent.empty()) + { + // Choose a len between 1 and 511 + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); - const bool corrupt_client_data = (r == 3); - const bool corrupt_server_data = (r == 4); + // TODO send in several records + client->send(client_sent); + } - if(c2s_traffic.size() > 0) - { - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<byte> input; - std::swap(c2s_traffic, input); - - if(corrupt_server_data) + if(server->is_active() && server_sent.empty()) { - input = Test::mutate_vec(input, true); - size_t needed = server.received_data(input.data(), input.size()); - - size_t total_consumed = needed; + result.test_eq("server->protocol", server->next_protocol(), "test/3"); - while(needed > 0 && - result.test_lt("Never requesting more than max protocol len", needed, 18*1024) && - result.test_lt("Total requested is readonable", total_consumed, 128*1024)) - { - input.resize(needed); - Test::rng().randomize(input.data(), input.size()); - needed = server.received_data(input.data(), input.size()); - total_consumed += needed; - } + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + server->send(server_sent); } - else + + const bool corrupt_client_data = (r == 3); + const bool corrupt_server_data = (r == 4); + + if(c2s_traffic.size() > 0) { - size_t needed = server.received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); - } + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_server_data) + { + input = Test::mutate_vec(input, true); + size_t needed = server->received_data(input.data(), input.size()); - continue; - } + size_t total_consumed = needed; - if(s2c_traffic.size() > 0) - { - std::vector<byte> input; - std::swap(s2c_traffic, input); + while(needed > 0 && + result.test_lt("Never requesting more than max protocol len", needed, 18*1024) && + result.test_lt("Total requested is readonable", total_consumed, 128*1024)) + { + input.resize(needed); + Test::rng().randomize(input.data(), input.size()); + needed = server->received_data(input.data(), input.size()); + total_consumed += needed; + } + } + else + { + size_t needed = server->received_data(input.data(), input.size()); + result.test_eq("full packet received", needed, 0); + } + + continue; + } - if(corrupt_client_data) + if(s2c_traffic.size() > 0) { - input = Test::mutate_vec(input, true); - size_t needed = client.received_data(input.data(), input.size()); + std::vector<byte> input; + std::swap(s2c_traffic, input); - size_t total_consumed = 0; + if(corrupt_client_data) + { + input = Test::mutate_vec(input, true); + size_t needed = client->received_data(input.data(), input.size()); - while(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) + size_t total_consumed = 0; + + while(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) + { + input.resize(needed); + Test::rng().randomize(input.data(), input.size()); + needed = client->received_data(input.data(), input.size()); + total_consumed += needed; + } + } + else { - input.resize(needed); - Test::rng().randomize(input.data(), input.size()); - needed = client.received_data(input.data(), input.size()); - total_consumed += needed; + size_t needed = client->received_data(input.data(), input.size()); + result.test_eq("full packet received", needed, 0); } + + continue; } - else + + if(client_recv.size()) { - size_t needed = client.received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); + result.test_eq("client recv", client_recv, server_sent); } - continue; - } - - if(client_recv.size()) - { - result.test_eq("client recv", client_recv, server_sent); - } - - if(server_recv.size()) - { - result.test_eq("server recv", server_recv, client_sent); - } + if(server_recv.size()) + { + result.test_eq("server->recv", server_recv, client_sent); + } - if(r > 2) - { - if(client_recv.size() && server_recv.size()) + if(r > 2) { - result.test_failure("Negotiated in the face of data corruption " + std::to_string(r)); + if(client_recv.size() && server_recv.size()) + { + result.test_failure("Negotiated in the face of data corruption " + std::to_string(r)); + } } - } - if(client.is_closed() && server.is_closed()) - break; + if(client->is_closed() && server->is_closed()) + break; - if(server_recv.size() && client_recv.size()) - { - Botan::SymmetricKey client_key = client.key_material_export("label", "context", 32); - Botan::SymmetricKey server_key = server.key_material_export("label", "context", 32); + if(server_recv.size() && client_recv.size()) + { + Botan::SymmetricKey client_key = client->key_material_export("label", "context", 32); + Botan::SymmetricKey server_key = server->key_material_export("label", "context", 32); - result.test_eq("TLS key material export", client_key.bits_of(), server_key.bits_of()); + result.test_eq("TLS key material export", client_key.bits_of(), server_key.bits_of()); - if(r % 2 == 0) - client.close(); - else - server.close(); + if(r % 2 == 0) + client->close(); + else + server->close(); + } } } } @@ -444,183 +501,234 @@ Test::Result test_dtls_handshake(Botan::TLS::Protocol_Version offer_version, { std::vector<byte> c2s_traffic, s2c_traffic, client_recv, server_recv, client_sent, server_sent; - Botan::TLS::Server server(queue_inserter(s2c_traffic), - queue_inserter(server_recv), - print_alert, - handshake_complete, - server_sessions, - creds, - policy, - rng, - next_protocol_chooser, - true); - - Botan::TLS::Client client(queue_inserter(c2s_traffic), - queue_inserter(client_recv), - print_alert, - handshake_complete, + std::unique_ptr<Botan::TLS::Callbacks> server_cb(new Botan::TLS::Compat_Callbacks( + queue_inserter(s2c_traffic), + queue_inserter(server_recv), + std::function<void (Botan::TLS::Alert)>(print_alert), + handshake_complete, + nullptr, + next_protocol_chooser)); + + std::unique_ptr<Botan::TLS::Callbacks> client_cb(new Botan::TLS::Compat_Callbacks( + queue_inserter(c2s_traffic), + queue_inserter(client_recv), + std::function<void (Botan::TLS::Alert)>(print_alert), + handshake_complete)); + + // TLS::Server object constructed by new constructor using virtual callback interface. + std::unique_ptr<Botan::TLS::Server> server( + new Botan::TLS::Server(*server_cb, + server_sessions, + creds, + policy, + rng, + true)); + + // TLS::Client object constructed by new constructor using virtual callback interface. + std::unique_ptr<Botan::TLS::Client> client( + new Botan::TLS::Client(*client_cb, client_sessions, creds, policy, rng, Botan::TLS::Server_Information("server.example.com"), offer_version, - protocols_offered); + protocols_offered)); size_t rounds = 0; - while(true) + // Test DTLS using both new and legacy constructors. + for(size_t ctor_sel = 0; ctor_sel < 2; ctor_sel++) { - // TODO: client and server should be in different threads - std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2)); - ++rounds; - - if(rounds > 100) + if(ctor_sel == 1) { - result.test_failure("Still here after many rounds"); - break; + c2s_traffic.clear(); + s2c_traffic.clear(); + server_recv.clear(); + client_recv.clear(); + client_sent.clear(); + server_sent.clear(); + // TLS::Server object constructed by legacy constructor. + server.reset( + new Botan::TLS::Server(queue_inserter(s2c_traffic), + queue_inserter(server_recv), + alert_cb_with_data, + handshake_complete, + server_sessions, + creds, + policy, + rng, + next_protocol_chooser, + true)); + + // TLS::Client object constructed by legacy constructor. + client.reset( + new Botan::TLS::Client(queue_inserter(c2s_traffic), + queue_inserter(client_recv), + alert_cb_with_data, + handshake_complete, + client_sessions, + creds, + policy, + rng, + Botan::TLS::Server_Information("server.example.com"), + offer_version, + protocols_offered)); } - if(handshake_done && (client.is_closed() || server.is_closed())) - break; - - if(client.is_active() && client_sent.empty()) + while(true) { - // Choose a len between 1 and 511 and send random chunks: - const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); - client_sent = unlock(rng.random_vec(c_len)); + // TODO: client and server should be in different threads + std::this_thread::sleep_for(std::chrono::milliseconds(rng.next_byte() % 2)); + ++rounds; - // TODO send multiple parts - client.send(client_sent); - } - - if(server.is_active() && server_sent.empty()) - { - result.test_eq("server ALPN", server.next_protocol(), "test/3"); - - const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); - server_sent = unlock(rng.random_vec(s_len)); - server.send(server_sent); - } + if(rounds > 100) + { + result.test_failure("Still here after many rounds"); + break; + } - const bool corrupt_client_data = (r == 3 && rng.next_byte() % 3 <= 1 && rounds < 10); - const bool corrupt_server_data = (r == 4 && rng.next_byte() % 3 <= 1 && rounds < 10); + if(handshake_done && (client->is_closed() || server->is_closed())) + break; - if(c2s_traffic.size() > 0) - { - /* - * Use this as a temp value to hold the queues as otherwise they - * might end up appending more in response to messages during the - * handshake. - */ - std::vector<byte> input; - std::swap(c2s_traffic, input); - - if(corrupt_server_data) + if(client->is_active() && client_sent.empty()) { - try - { - input = Test::mutate_vec(input, true); - size_t needed = server.received_data(input.data(), input.size()); + // Choose a len between 1 and 511 and send random chunks: + const size_t c_len = 1 + rng.next_byte() + rng.next_byte(); + client_sent = unlock(rng.random_vec(c_len)); - if(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) - { - input.resize(needed); - Test::rng().randomize(input.data(), input.size()); - client.received_data(input.data(), input.size()); - } - } - catch(std::exception&) - { - result.test_note("corruption caused server exception"); - } + // TODO send multiple parts + client->send(client_sent); } - else + + if(server->is_active() && server_sent.empty()) { - try - { - size_t needed = server.received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); - } - catch(std::exception& e) - { - result.test_failure("server error", e.what()); - } - } + result.test_eq("server ALPN", server->next_protocol(), "test/3"); - continue; - } + const size_t s_len = 1 + rng.next_byte() + rng.next_byte(); + server_sent = unlock(rng.random_vec(s_len)); + server->send(server_sent); + } - if(s2c_traffic.size() > 0) - { - std::vector<byte> input; - std::swap(s2c_traffic, input); + const bool corrupt_client_data = (r == 3 && rng.next_byte() % 3 <= 1 && rounds < 10); + const bool corrupt_server_data = (r == 4 && rng.next_byte() % 3 <= 1 && rounds < 10); - if(corrupt_client_data) + if(c2s_traffic.size() > 0) { - try + /* + * Use this as a temp value to hold the queues as otherwise they + * might end up appending more in response to messages during the + * handshake. + */ + std::vector<byte> input; + std::swap(c2s_traffic, input); + + if(corrupt_server_data) { - input = Test::mutate_vec(input, true); - size_t needed = client.received_data(input.data(), input.size()); - - if(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) + try { - input.resize(needed); - Test::rng().randomize(input.data(), input.size()); - client.received_data(input.data(), input.size()); + input = Test::mutate_vec(input, true); + size_t needed = server->received_data(input.data(), input.size()); + + if(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) + { + input.resize(needed); + Test::rng().randomize(input.data(), input.size()); + client->received_data(input.data(), input.size()); + } + } + catch(std::exception&) + { + result.test_note("corruption caused server exception"); } } - catch(std::exception&) + else { - result.test_note("corruption caused client exception"); + try + { + size_t needed = server->received_data(input.data(), input.size()); + result.test_eq("full packet received", needed, 0); + } + catch(std::exception& e) + { + result.test_failure("server error", e.what()); + } } + + continue; } - else + + if(s2c_traffic.size() > 0) { - try + std::vector<byte> input; + std::swap(s2c_traffic, input); + + if(corrupt_client_data) { - size_t needed = client.received_data(input.data(), input.size()); - result.test_eq("full packet received", needed, 0); + try + { + input = Test::mutate_vec(input, true); + size_t needed = client->received_data(input.data(), input.size()); + + if(needed > 0 && result.test_lt("Never requesting more than max protocol len", needed, 18*1024)) + { + input.resize(needed); + Test::rng().randomize(input.data(), input.size()); + client->received_data(input.data(), input.size()); + } + } + catch(std::exception&) + { + result.test_note("corruption caused client exception"); + } } - catch(std::exception& e) + else { - result.test_failure("client error", e.what()); + try + { + size_t needed = client->received_data(input.data(), input.size()); + result.test_eq("full packet received", needed, 0); + } + catch(std::exception& e) + { + result.test_failure("client error", e.what()); + } } - } - continue; - } + continue; + } - // If we corrupted a DTLS application message, resend it: - if(client.is_active() && corrupt_client_data && server_recv.empty()) - client.send(client_sent); - if(server.is_active() && corrupt_server_data && client_recv.empty()) - server.send(server_sent); + // If we corrupted a DTLS application message, resend it: + if(client->is_active() && corrupt_client_data && server_recv.empty()) + client->send(client_sent); + if(server->is_active() && corrupt_server_data && client_recv.empty()) + server->send(server_sent); - if(client_recv.size()) - { - result.test_eq("client recv", client_recv, server_sent); - } + if(client_recv.size()) + { + result.test_eq("client recv", client_recv, server_sent); + } - if(server_recv.size()) - { - result.test_eq("server recv", server_recv, client_sent); - } + if(server_recv.size()) + { + result.test_eq("server recv", server_recv, client_sent); + } - if(client.is_closed() && server.is_closed()) - break; + if(client->is_closed() && server->is_closed()) + break; - if(server_recv.size() && client_recv.size()) - { - Botan::SymmetricKey client_key = client.key_material_export("label", "context", 32); - Botan::SymmetricKey server_key = server.key_material_export("label", "context", 32); + if(server_recv.size() && client_recv.size()) + { + Botan::SymmetricKey client_key = client->key_material_export("label", "context", 32); + Botan::SymmetricKey server_key = server->key_material_export("label", "context", 32); - result.test_eq("key material export", client_key.bits_of(), server_key.bits_of()); + result.test_eq("key material export", client_key.bits_of(), server_key.bits_of()); - if(r % 2 == 0) - client.close(); - else - server.close(); + if(r % 2 == 0) + client->close(); + else + server->close(); + } } } } |