From 239241568d4d3ff14d2d1994e5829f3d548f2078 Mon Sep 17 00:00:00 2001 From: lloyd Date: Thu, 19 Jan 2012 17:07:03 +0000 Subject: Remove Handshake_Message::deserialize which was an unnecessary hook. Instead deserialize directly in the constructors that are passed the raw message data. This makes it easier to pass contextual information needed for decoding (eg, version numbers) where necessary. --- src/tls/c_hello.cpp | 14 ++++---- src/tls/c_kex.cpp | 22 ++++-------- src/tls/cert_req.cpp | 90 +++++++++++++++++++++++------------------------ src/tls/cert_ver.cpp | 18 +++++----- src/tls/finished.cpp | 2 +- src/tls/next_protocol.cpp | 18 +++++----- src/tls/s_hello.cpp | 84 +++++++++++++++++++++---------------------- src/tls/s_kex.cpp | 2 +- src/tls/tls_messages.h | 30 ++++++---------- src/tls/tls_reader.h | 5 +-- 10 files changed, 132 insertions(+), 153 deletions(-) diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 6c4964fb1..60f6de487 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -56,20 +56,20 @@ Hello_Request::Hello_Request(Record_Writer& writer) } /* -* Serialize a Hello Request message +* Deserialize a Hello Request message */ -MemoryVector Hello_Request::serialize() const +Hello_Request::Hello_Request(const MemoryRegion& buf) { - return MemoryVector(); + if(buf.size()) + throw Decoding_Error("Hello_Request: Must be empty, and is not"); } /* -* Deserialize a Hello Request message +* Serialize a Hello Request message */ -void Hello_Request::deserialize(const MemoryRegion& buf) +MemoryVector Hello_Request::serialize() const { - if(buf.size()) - throw Decoding_Error("Hello_Request: Must be empty, and is not"); + return MemoryVector(); } /* diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index f95f74931..3d79116ca 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -90,7 +90,13 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion& contents, if(using_version == SSL_V3 && (suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX)) include_length = false; - deserialize(contents); + if(include_length) + { + TLS_Data_Reader reader(contents); + key_material = reader.get_range(2, 0, 65535); + } + else + key_material = contents; } /* @@ -108,20 +114,6 @@ MemoryVector Client_Key_Exchange::serialize() const return key_material; } -/* -* Deserialize a Client Key Exchange message -*/ -void Client_Key_Exchange::deserialize(const MemoryRegion& buf) - { - if(include_length) - { - TLS_Data_Reader reader(buf); - key_material = reader.get_range(2, 0, 65535); - } - else - key_material = buf; - } - /* * Return the pre_master_secret */ diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 0168e4b7d..bdb25057c 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -36,30 +36,10 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, send(writer, hash); } -/** -* Serialize a Certificate Request message -*/ -MemoryVector Certificate_Req::serialize() const - { - MemoryVector buf; - - append_tls_length_value(buf, types, 1); - - for(size_t i = 0; i != names.size(); ++i) - { - DER_Encoder encoder; - encoder.encode(names[i]); - - append_tls_length_value(buf, encoder.get_contents(), 2); - } - - return buf; - } - /** * Deserialize a Certificate Request message */ -void Certificate_Req::deserialize(const MemoryRegion& buf) +Certificate_Req::Certificate_Req(const MemoryRegion& buf) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); @@ -96,43 +76,40 @@ void Certificate_Req::deserialize(const MemoryRegion& buf) } /** -* Create a new Certificate message +* Serialize a Certificate Request message */ -Certificate::Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const std::vector& cert_list) +MemoryVector Certificate_Req::serialize() const { - certs = cert_list; - send(writer, hash); - } + MemoryVector buf; -/** -* Serialize a Certificate message -*/ -MemoryVector Certificate::serialize() const - { - MemoryVector buf(3); + append_tls_length_value(buf, types, 1); - for(size_t i = 0; i != certs.size(); ++i) + for(size_t i = 0; i != names.size(); ++i) { - MemoryVector raw_cert = certs[i].BER_encode(); - const size_t cert_size = raw_cert.size(); - for(size_t i = 0; i != 3; ++i) - buf.push_back(get_byte(i+1, cert_size)); - buf += raw_cert; - } + DER_Encoder encoder; + encoder.encode(names[i]); - const size_t buf_size = buf.size() - 3; - for(size_t i = 0; i != 3; ++i) - buf[i] = get_byte(i+1, buf_size); + append_tls_length_value(buf, encoder.get_contents(), 2); + } return buf; } +/** +* Create a new Certificate message +*/ +Certificate::Certificate(Record_Writer& writer, + TLS_Handshake_Hash& hash, + const std::vector& cert_list) + { + certs = cert_list; + send(writer, hash); + } + /** * Deserialize a Certificate message */ -void Certificate::deserialize(const MemoryRegion& buf) +Certificate::Certificate(const MemoryRegion& buf) { if(buf.size() < 3) throw Decoding_Error("Certificate: Message malformed"); @@ -163,4 +140,27 @@ void Certificate::deserialize(const MemoryRegion& buf) } } +/** +* Serialize a Certificate message +*/ +MemoryVector Certificate::serialize() const + { + MemoryVector buf(3); + + for(size_t i = 0; i != certs.size(); ++i) + { + MemoryVector raw_cert = certs[i].BER_encode(); + const size_t cert_size = raw_cert.size(); + for(size_t i = 0; i != 3; ++i) + buf.push_back(get_byte(i+1, cert_size)); + buf += raw_cert; + } + + const size_t buf_size = buf.size() - 3; + for(size_t i = 0; i != 3; ++i) + buf[i] = get_byte(i+1, buf_size); + + return buf; + } + } diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 81d529e88..77d9fe74b 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -53,6 +53,15 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, send(writer, state->hash); } +/* +* Deserialize a Certificate Verify message +*/ +Certificate_Verify::Certificate_Verify(const MemoryRegion& buf) + { + TLS_Data_Reader reader(buf); + signature = reader.get_range(2, 0, 65535); + } + /* * Serialize a Certificate Verify message */ @@ -68,15 +77,6 @@ MemoryVector Certificate_Verify::serialize() const return buf; } -/* -* Deserialize a Certificate Verify message -*/ -void Certificate_Verify::deserialize(const MemoryRegion& buf) - { - TLS_Data_Reader reader(buf); - signature = reader.get_range(2, 0, 65535); - } - /* * Verify a Certificate Verify message */ diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index 836512f81..baa663798 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -81,7 +81,7 @@ MemoryVector Finished::serialize() const /* * Deserialize a Finished message */ -void Finished::deserialize(const MemoryRegion& buf) +Finished::Finished(const MemoryRegion& buf) { verification_data = buf; } diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index 2d2e2e599..a0d4278f1 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -19,6 +19,15 @@ Next_Protocol::Next_Protocol(Record_Writer& writer, send(writer, hash); } +Next_Protocol::Next_Protocol(const MemoryRegion& buf) + { + TLS_Data_Reader reader(buf); + + m_protocol = reader.get_string(1, 0, 255); + + reader.get_range_vector(1, 0, 255); // padding, ignored + } + MemoryVector Next_Protocol::serialize() const { MemoryVector buf; @@ -38,13 +47,4 @@ MemoryVector Next_Protocol::serialize() const return buf; } -void Next_Protocol::deserialize(const MemoryRegion& buf) - { - TLS_Data_Reader reader(buf); - - m_protocol = reader.get_string(1, 0, 255); - - reader.get_range_vector(1, 0, 255); // padding, ignored - } - } diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 90e18ae90..fa185599d 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -89,44 +89,10 @@ Server_Hello::Server_Hello(Record_Writer& writer, send(writer, hash); } -/* -* Serialize a Server Hello message -*/ -MemoryVector Server_Hello::serialize() const - { - MemoryVector buf; - - buf.push_back(static_cast(s_version >> 8)); - buf.push_back(static_cast(s_version )); - buf += s_random; - - append_tls_length_value(buf, m_session_id, 1); - - buf.push_back(get_byte(0, suite)); - buf.push_back(get_byte(1, suite)); - - buf.push_back(comp_method); - - TLS_Extensions extensions; - - if(m_secure_renegotiation) - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - - if(m_fragment_size != 0) - extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); - - if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); - - buf += extensions.serialize(); - - return buf; - } - /* * Deserialize a Server Hello message */ -void Server_Hello::deserialize(const MemoryRegion& buf) +Server_Hello::Server_Hello(const MemoryRegion& buf) { m_secure_renegotiation = false; m_next_protocol = false; @@ -172,6 +138,40 @@ void Server_Hello::deserialize(const MemoryRegion& buf) } } +/* +* Serialize a Server Hello message +*/ +MemoryVector Server_Hello::serialize() const + { + MemoryVector buf; + + buf.push_back(static_cast(s_version >> 8)); + buf.push_back(static_cast(s_version )); + buf += s_random; + + append_tls_length_value(buf, m_session_id, 1); + + buf.push_back(get_byte(0, suite)); + buf.push_back(get_byte(1, suite)); + + buf.push_back(comp_method); + + TLS_Extensions extensions; + + if(m_secure_renegotiation) + extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + + if(m_fragment_size != 0) + extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); + + if(m_next_protocol) + extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); + + buf += extensions.serialize(); + + return buf; + } + /* * Create a new Server Hello Done message */ @@ -182,20 +182,20 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, } /* -* Serialize a Server Hello Done message +* Deserialize a Server Hello Done message */ -MemoryVector Server_Hello_Done::serialize() const +Server_Hello_Done::Server_Hello_Done(const MemoryRegion& buf) { - return MemoryVector(); + if(buf.size()) + throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); } /* -* Deserialize a Server Hello Done message +* Serialize a Server Hello Done message */ -void Server_Hello_Done::deserialize(const MemoryRegion& buf) +MemoryVector Server_Hello_Done::serialize() const { - if(buf.size()) - throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); + return MemoryVector(); } } diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 69531c7c4..7008c89de 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -73,7 +73,7 @@ MemoryVector Server_Key_Exchange::serialize_params() const /** * Deserialize a Server Key Exchange message */ -void Server_Key_Exchange::deserialize(const MemoryRegion& buf) +Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion& buf) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 0b43545dc..d3735972e 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -37,7 +37,6 @@ class Handshake_Message private: Handshake_Message& operator=(const Handshake_Message&) { return (*this); } virtual MemoryVector serialize() const = 0; - virtual void deserialize(const MemoryRegion&) = 0; }; MemoryVector make_hello_random(RandomNumberGenerator& rng); @@ -179,10 +178,9 @@ class Server_Hello : public Handshake_Message const std::vector& next_protocols, RandomNumberGenerator& rng); - Server_Hello(const MemoryRegion& buf) { deserialize(buf); } + Server_Hello(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); Version_Code s_version; MemoryVector m_session_id, s_random; @@ -224,7 +222,6 @@ class Client_Key_Exchange : public Handshake_Message Version_Code using_version); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); SecureVector key_material, pre_master; bool include_length; @@ -246,10 +243,10 @@ class Certificate : public Handshake_Message TLS_Handshake_Hash& hash, const std::vector& certs); - Certificate(const MemoryRegion& buf) { deserialize(buf); } + Certificate(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); + std::vector certs; }; @@ -270,10 +267,9 @@ class Certificate_Req : public Handshake_Message const std::vector& types = std::vector()); - Certificate_Req(const MemoryRegion& buf) { deserialize(buf); } + Certificate_Req(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); std::vector names; std::vector types; @@ -300,10 +296,9 @@ class Certificate_Verify : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion& buf) { deserialize(buf); } + Certificate_Verify(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); MemoryVector signature; }; @@ -326,10 +321,9 @@ class Finished : public Handshake_Message TLS_Handshake_State* state, Connection_Side side); - Finished(const MemoryRegion& buf) { deserialize(buf); } + Finished(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); Connection_Side side; MemoryVector verification_data; @@ -344,10 +338,9 @@ class Hello_Request : public Handshake_Message Handshake_Type type() const { return HELLO_REQUEST; } Hello_Request(Record_Writer& writer); - Hello_Request(const MemoryRegion& buf) { deserialize(buf); } + Hello_Request(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); }; /** @@ -367,11 +360,10 @@ class Server_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* priv_key); - Server_Key_Exchange(const MemoryRegion& buf) { deserialize(buf); } + Server_Key_Exchange(const MemoryRegion& buf); private: MemoryVector serialize() const; MemoryVector serialize_params() const; - void deserialize(const MemoryRegion&); std::vector params; MemoryVector signature; @@ -386,10 +378,9 @@ class Server_Hello_Done : public Handshake_Message Handshake_Type type() const { return SERVER_HELLO_DONE; } Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash); - Server_Hello_Done(const MemoryRegion& buf) { deserialize(buf); } + Server_Hello_Done(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); }; /** @@ -406,10 +397,9 @@ class Next_Protocol : public Handshake_Message TLS_Handshake_Hash& hash, const std::string& protocol); - Next_Protocol(const MemoryRegion& buf) { deserialize(buf); } + Next_Protocol(const MemoryRegion& buf); private: MemoryVector serialize() const; - void deserialize(const MemoryRegion&); std::string m_protocol; }; diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index ef36912d3..6a0bcd5b1 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -26,13 +26,10 @@ class TLS_Data_Reader TLS_Data_Reader(const MemoryRegion& buf_in) : buf(buf_in), offset(0) {} - ~TLS_Data_Reader() + void assert_done() const { if(has_remaining()) - { - abort(); throw Decoding_Error("Extra bytes at end of message"); - } } size_t remaining_bytes() const -- cgit v1.2.3