diff options
author | Jack Lloyd <[email protected]> | 2019-05-24 14:17:55 -0400 |
---|---|---|
committer | Jack Lloyd <[email protected]> | 2019-05-24 14:17:55 -0400 |
commit | bb8c2b138898fb49b36157779f3b2a05dd5bba90 (patch) | |
tree | 710d7f2fe5711cebf3da8ee78bb2f6905ddf9c25 /src | |
parent | a6f271f638a20a619be8e840001ff83112506c40 (diff) |
Let TLS serialization know which side we are sending as
Since this matters for some extensions
Diffstat (limited to 'src')
-rw-r--r-- | src/lib/tls/msg_cert_req.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/msg_client_hello.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/msg_server_hello.cpp | 2 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.cpp | 71 | ||||
-rw-r--r-- | src/lib/tls/tls_extensions.h | 38 | ||||
-rw-r--r-- | src/tests/unit_tls.cpp | 2 |
6 files changed, 57 insertions, 60 deletions
diff --git a/src/lib/tls/msg_cert_req.cpp b/src/lib/tls/msg_cert_req.cpp index b6fc3825b..9e8a4d803 100644 --- a/src/lib/tls/msg_cert_req.cpp +++ b/src/lib/tls/msg_cert_req.cpp @@ -134,7 +134,7 @@ std::vector<uint8_t> Certificate_Req::serialize() const append_tls_length_value(buf, cert_types, 1); if(m_schemes.size() > 0) - buf += Signature_Algorithms(m_schemes).serialize(); + buf += Signature_Algorithms(m_schemes).serialize(Connection_Side::SERVER); std::vector<uint8_t> encoded_names; diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 657fe01b2..f83df44f1 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -251,7 +251,7 @@ std::vector<uint8_t> Client_Hello::serialize() const * renegotiating with a modern server) */ - buf += m_extensions.serialize(); + buf += m_extensions.serialize(Connection_Side::CLIENT); return buf; } diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index f3ef45f50..f24ddeb07 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -180,7 +180,7 @@ std::vector<uint8_t> Server_Hello::serialize() const buf.push_back(m_comp_method); - buf += m_extensions.serialize(); + buf += m_extensions.serialize(Connection_Side::SERVER); return buf; } diff --git a/src/lib/tls/tls_extensions.cpp b/src/lib/tls/tls_extensions.cpp index 54f006c74..49f996228 100644 --- a/src/lib/tls/tls_extensions.cpp +++ b/src/lib/tls/tls_extensions.cpp @@ -97,7 +97,7 @@ void Extensions::deserialize(TLS_Data_Reader& reader, Connection_Side from) } } -std::vector<uint8_t> Extensions::serialize() const +std::vector<uint8_t> Extensions::serialize(Connection_Side whoami) const { std::vector<uint8_t> buf(2); // 2 bytes for length field @@ -108,7 +108,7 @@ std::vector<uint8_t> Extensions::serialize() const const uint16_t extn_code = static_cast<uint16_t>(extn.second->type()); - std::vector<uint8_t> extn_val = extn.second->serialize(); + const std::vector<uint8_t> extn_val = extn.second->serialize(whoami); buf.push_back(get_byte(0, extn_code)); buf.push_back(get_byte(1, extn_code)); @@ -156,7 +156,7 @@ Unknown_Extension::Unknown_Extension(Handshake_Extension_Type type, { } -std::vector<uint8_t> Unknown_Extension::serialize() const +std::vector<uint8_t> Unknown_Extension::serialize(Connection_Side /*whoami*/) const { throw Invalid_State("Cannot encode an unknown TLS extension"); } @@ -193,7 +193,7 @@ Server_Name_Indicator::Server_Name_Indicator(TLS_Data_Reader& reader, } } -std::vector<uint8_t> Server_Name_Indicator::serialize() const +std::vector<uint8_t> Server_Name_Indicator::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf; @@ -222,7 +222,7 @@ SRP_Identifier::SRP_Identifier(TLS_Data_Reader& reader, throw Decoding_Error("Bad encoding for SRP identifier extension"); } -std::vector<uint8_t> SRP_Identifier::serialize() const +std::vector<uint8_t> SRP_Identifier::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf; @@ -241,7 +241,7 @@ Renegotiation_Extension::Renegotiation_Extension(TLS_Data_Reader& reader, throw Decoding_Error("Bad encoding for secure renegotiation extn"); } -std::vector<uint8_t> Renegotiation_Extension::serialize() const +std::vector<uint8_t> Renegotiation_Extension::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf; append_tls_length_value(buf, m_reneg_data, 1); @@ -286,7 +286,7 @@ const std::string& Application_Layer_Protocol_Notification::single_protocol() co return m_protocols[0]; } -std::vector<uint8_t> Application_Layer_Protocol_Notification::serialize() const +std::vector<uint8_t> Application_Layer_Protocol_Notification::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf(2); @@ -333,7 +333,7 @@ std::vector<Group_Params> Supported_Groups::dh_groups() const return dh; } -std::vector<uint8_t> Supported_Groups::serialize() const +std::vector<uint8_t> Supported_Groups::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf(2); @@ -372,7 +372,7 @@ Supported_Groups::Supported_Groups(TLS_Data_Reader& reader, } } -std::vector<uint8_t> Supported_Point_Formats::serialize() const +std::vector<uint8_t> Supported_Point_Formats::serialize(Connection_Side /*whoami*/) const { // if this extension is sent, it MUST include uncompressed (RFC 4492, section 5.1) if(m_prefers_compressed) @@ -414,7 +414,7 @@ Supported_Point_Formats::Supported_Point_Formats(TLS_Data_Reader& reader, } } -std::vector<uint8_t> Signature_Algorithms::serialize() const +std::vector<uint8_t> Signature_Algorithms::serialize(Connection_Side /*whoami*/) const { BOTAN_ASSERT(m_schemes.size() < 256, "Too many signature schemes"); @@ -470,7 +470,7 @@ SRTP_Protection_Profiles::SRTP_Protection_Profiles(TLS_Data_Reader& reader, throw Decoding_Error("Unhandled non-empty MKI for SRTP protection extension"); } -std::vector<uint8_t> SRTP_Protection_Profiles::serialize() const +std::vector<uint8_t> SRTP_Protection_Profiles::serialize(Connection_Side /*whoami*/) const { std::vector<uint8_t> buf; @@ -496,7 +496,7 @@ Extended_Master_Secret::Extended_Master_Secret(TLS_Data_Reader&, throw Decoding_Error("Invalid extended_master_secret extension"); } -std::vector<uint8_t> Extended_Master_Secret::serialize() const +std::vector<uint8_t> Extended_Master_Secret::serialize(Connection_Side /*whoami*/) const { return std::vector<uint8_t>(); } @@ -508,16 +508,16 @@ Encrypt_then_MAC::Encrypt_then_MAC(TLS_Data_Reader&, throw Decoding_Error("Invalid encrypt_then_mac extension"); } -std::vector<uint8_t> Encrypt_then_MAC::serialize() const +std::vector<uint8_t> Encrypt_then_MAC::serialize(Connection_Side /*whoami*/) const { return std::vector<uint8_t>(); } -std::vector<uint8_t> Certificate_Status_Request::serialize() const +std::vector<uint8_t> Certificate_Status_Request::serialize(Connection_Side whoami) const { std::vector<uint8_t> buf; - if(m_server_side) + if(whoami == Connection_Side::SERVER) return buf; // server reply is empty /* @@ -541,18 +541,22 @@ std::vector<uint8_t> Certificate_Status_Request::serialize() const Certificate_Status_Request::Certificate_Status_Request(TLS_Data_Reader& reader, uint16_t extension_size, - Connection_Side from) : - m_server_side(from == SERVER) + Connection_Side from) { - if(extension_size > 0) + if(from == Connection_Side::SERVER) + { + if(extension_size != 0) + throw Decoding_Error("Server sent non-empty Certificate_Status_Request extension"); + } + else if(extension_size > 0) { const uint8_t type = reader.get_byte(); if(type == 1) { - size_t len_resp_id_list = reader.get_uint16_t(); + const size_t len_resp_id_list = reader.get_uint16_t(); m_ocsp_names = reader.get_fixed<uint8_t>(len_resp_id_list); - size_t len_requ_ext = reader.get_uint16_t(); - m_extension_bytes = reader.get_fixed<uint8_t>(len_requ_ext ); + const size_t len_requ_ext = reader.get_uint16_t(); + m_extension_bytes = reader.get_fixed<uint8_t>(len_requ_ext); } else { @@ -564,22 +568,15 @@ Certificate_Status_Request::Certificate_Status_Request(TLS_Data_Reader& reader, Certificate_Status_Request::Certificate_Status_Request(const std::vector<uint8_t>& ocsp_responder_ids, const std::vector<std::vector<uint8_t>>& ocsp_key_ids) : m_ocsp_names(ocsp_responder_ids), - m_ocsp_keys(ocsp_key_ids), - m_server_side(false) - { - - } - -Certificate_Status_Request::Certificate_Status_Request() : m_server_side(true) + m_ocsp_keys(ocsp_key_ids) { - } -std::vector<uint8_t> Supported_Versions::serialize() const +std::vector<uint8_t> Supported_Versions::serialize(Connection_Side whoami) const { std::vector<uint8_t> buf; - if(m_server_side) + if(whoami == Connection_Side::SERVER) { BOTAN_ASSERT_NOMSG(m_versions.size() == 1); buf.push_back(m_versions[0].major_version()); @@ -587,6 +584,7 @@ std::vector<uint8_t> Supported_Versions::serialize() const } else { + BOTAN_ASSERT_NOMSG(m_versions.size() >= 1); const uint8_t len = static_cast<uint8_t>(m_versions.size() * 2); buf.push_back(len); @@ -601,8 +599,7 @@ std::vector<uint8_t> Supported_Versions::serialize() const return buf; } -Supported_Versions::Supported_Versions(Protocol_Version offer, const Policy& policy) : - m_server_side(false) +Supported_Versions::Supported_Versions(Protocol_Version offer, const Policy& policy) { if(offer.is_datagram_protocol()) { @@ -628,17 +625,19 @@ Supported_Versions::Supported_Versions(TLS_Data_Reader& reader, { if(from == Connection_Side::SERVER) { - m_server_side = true; + if(extension_size != 2) + throw Decoding_Error("Server sent invalid supported_versions extension"); m_versions.push_back(Protocol_Version(reader.get_uint16_t())); } else { - m_server_side = false; - auto versions = reader.get_range<uint16_t>(1, 1, 127); for(auto v : versions) m_versions.push_back(Protocol_Version(v)); + + if(extension_size != 1+2*versions.size()) + throw Decoding_Error("Client sent invalid supported_versions extension"); } } diff --git a/src/lib/tls/tls_extensions.h b/src/lib/tls/tls_extensions.h index e015f7b82..7dda6aaa0 100644 --- a/src/lib/tls/tls_extensions.h +++ b/src/lib/tls/tls_extensions.h @@ -64,7 +64,7 @@ class BOTAN_UNSTABLE_API Extension /** * @return serialized binary for the extension */ - virtual std::vector<uint8_t> serialize() const = 0; + virtual std::vector<uint8_t> serialize(Connection_Side whoami) const = 0; /** * @return if we should encode this extension or not @@ -93,7 +93,7 @@ class BOTAN_UNSTABLE_API Server_Name_Indicator final : public Extension std::string host_name() const { return m_sni_host_name; } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_sni_host_name.empty(); } private: @@ -120,7 +120,7 @@ class BOTAN_UNSTABLE_API SRP_Identifier final : public Extension std::string identifier() const { return m_srp_identifier; } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_srp_identifier.empty(); } private: @@ -150,7 +150,7 @@ class BOTAN_UNSTABLE_API Renegotiation_Extension final : public Extension const std::vector<uint8_t>& renegotiation_info() const { return m_reneg_data; } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return false; } // always send this private: @@ -186,7 +186,7 @@ class BOTAN_UNSTABLE_API Application_Layer_Protocol_Notification final : public Application_Layer_Protocol_Notification(TLS_Data_Reader& reader, uint16_t extension_size); - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_protocols.empty(); } private: @@ -225,7 +225,7 @@ class BOTAN_UNSTABLE_API Session_Ticket final : public Extension */ Session_Ticket(TLS_Data_Reader& reader, uint16_t extension_size); - std::vector<uint8_t> serialize() const override { return m_ticket; } + std::vector<uint8_t> serialize(Connection_Side) const override { return m_ticket; } bool empty() const override { return false; } private: @@ -247,7 +247,7 @@ class BOTAN_UNSTABLE_API Supported_Groups final : public Extension std::vector<Group_Params> ec_groups() const; std::vector<Group_Params> dh_groups() const; - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; explicit Supported_Groups(const std::vector<Group_Params>& groups); @@ -279,7 +279,7 @@ class BOTAN_UNSTABLE_API Supported_Point_Formats final : public Extension Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; explicit Supported_Point_Formats(bool prefer_compressed) : m_prefers_compressed(prefer_compressed) {} @@ -308,7 +308,7 @@ class BOTAN_UNSTABLE_API Signature_Algorithms final : public Extension const std::vector<Signature_Scheme>& supported_schemes() const { return m_schemes; } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_schemes.empty(); } @@ -334,7 +334,7 @@ class BOTAN_UNSTABLE_API SRTP_Protection_Profiles final : public Extension const std::vector<uint16_t>& profiles() const { return m_pp; } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_pp.empty(); } @@ -358,7 +358,7 @@ class BOTAN_UNSTABLE_API Extended_Master_Secret final : public Extension Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return false; } @@ -378,7 +378,7 @@ class BOTAN_UNSTABLE_API Encrypt_then_MAC final : public Extension Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return false; } @@ -398,7 +398,7 @@ class BOTAN_UNSTABLE_API Certificate_Status_Request final : public Extension Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return false; } @@ -413,7 +413,7 @@ class BOTAN_UNSTABLE_API Certificate_Status_Request final : public Extension } // Server generated version: empty - Certificate_Status_Request(); + Certificate_Status_Request() {} // Client version, both lists can be empty Certificate_Status_Request(const std::vector<uint8_t>& ocsp_responder_ids, @@ -426,7 +426,6 @@ class BOTAN_UNSTABLE_API Certificate_Status_Request final : public Extension std::vector<uint8_t> m_ocsp_names; std::vector<std::vector<uint8_t>> m_ocsp_keys; // is this field really needed std::vector<uint8_t> m_extension_bytes; - bool m_server_side; }; /** @@ -440,13 +439,13 @@ class BOTAN_UNSTABLE_API Supported_Versions final : public Extension Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override; + std::vector<uint8_t> serialize(Connection_Side whoami) const override; bool empty() const override { return m_versions.empty(); } Supported_Versions(Protocol_Version version, const Policy& policy); - Supported_Versions(Protocol_Version version) : m_server_side(true) + Supported_Versions(Protocol_Version version) { m_versions.push_back(version); } @@ -460,7 +459,6 @@ class BOTAN_UNSTABLE_API Supported_Versions final : public Extension const std::vector<Protocol_Version> versions() const { return m_versions; } private: std::vector<Protocol_Version> m_versions; - bool m_server_side; }; /** @@ -473,7 +471,7 @@ class BOTAN_UNSTABLE_API Unknown_Extension final : public Extension TLS_Data_Reader& reader, uint16_t extension_size); - std::vector<uint8_t> serialize() const override; // always fails + std::vector<uint8_t> serialize(Connection_Side whoami) const override; // always fails const std::vector<uint8_t>& value() { return m_value; } @@ -521,7 +519,7 @@ class BOTAN_UNSTABLE_API Extensions final return nullptr; } - std::vector<uint8_t> serialize() const; + std::vector<uint8_t> serialize(Connection_Side whoami) const; void deserialize(TLS_Data_Reader& reader, Connection_Side from); diff --git a/src/tests/unit_tls.cpp b/src/tests/unit_tls.cpp index e42cd6a4a..34c25e9ef 100644 --- a/src/tests/unit_tls.cpp +++ b/src/tests/unit_tls.cpp @@ -335,7 +335,7 @@ class TLS_Handshake_Test final Botan::TLS::Handshake_Extension_Type type() const override { return static_type(); } - std::vector<uint8_t> serialize() const override { return m_buf; } + std::vector<uint8_t> serialize(Botan::TLS::Connection_Side) const override { return m_buf; } const std::vector<uint8_t>& value() const { return m_buf; } |