diff options
author | lloyd <[email protected]> | 2012-01-19 17:07:03 +0000 |
---|---|---|
committer | lloyd <[email protected]> | 2012-01-19 17:07:03 +0000 |
commit | 239241568d4d3ff14d2d1994e5829f3d548f2078 (patch) | |
tree | a21fe21d7c229f00ae06859dbe1768ead68e6d13 | |
parent | 30104a60568b392886c1d717a7ca006378552e4d (diff) |
Remove Handshake_Message::deserialize which was an unnecessary hook.
Instead deserialize directly in the constructors that are passed the
raw message data. This makes it easier to pass contextual information
needed for decoding (eg, version numbers) where necessary.
-rw-r--r-- | src/tls/c_hello.cpp | 14 | ||||
-rw-r--r-- | src/tls/c_kex.cpp | 22 | ||||
-rw-r--r-- | src/tls/cert_req.cpp | 90 | ||||
-rw-r--r-- | src/tls/cert_ver.cpp | 18 | ||||
-rw-r--r-- | src/tls/finished.cpp | 2 | ||||
-rw-r--r-- | src/tls/next_protocol.cpp | 18 | ||||
-rw-r--r-- | src/tls/s_hello.cpp | 84 | ||||
-rw-r--r-- | src/tls/s_kex.cpp | 2 | ||||
-rw-r--r-- | src/tls/tls_messages.h | 30 | ||||
-rw-r--r-- | src/tls/tls_reader.h | 5 |
10 files changed, 132 insertions, 153 deletions
diff --git a/src/tls/c_hello.cpp b/src/tls/c_hello.cpp index 6c4964fb1..60f6de487 100644 --- a/src/tls/c_hello.cpp +++ b/src/tls/c_hello.cpp @@ -56,20 +56,20 @@ Hello_Request::Hello_Request(Record_Writer& writer) } /* -* Serialize a Hello Request message +* Deserialize a Hello Request message */ -MemoryVector<byte> Hello_Request::serialize() const +Hello_Request::Hello_Request(const MemoryRegion<byte>& buf) { - return MemoryVector<byte>(); + if(buf.size()) + throw Decoding_Error("Hello_Request: Must be empty, and is not"); } /* -* Deserialize a Hello Request message +* Serialize a Hello Request message */ -void Hello_Request::deserialize(const MemoryRegion<byte>& buf) +MemoryVector<byte> Hello_Request::serialize() const { - if(buf.size()) - throw Decoding_Error("Hello_Request: Must be empty, and is not"); + return MemoryVector<byte>(); } /* diff --git a/src/tls/c_kex.cpp b/src/tls/c_kex.cpp index f95f74931..3d79116ca 100644 --- a/src/tls/c_kex.cpp +++ b/src/tls/c_kex.cpp @@ -90,7 +90,13 @@ Client_Key_Exchange::Client_Key_Exchange(const MemoryRegion<byte>& contents, if(using_version == SSL_V3 && (suite.kex_type() == TLS_ALGO_KEYEXCH_NOKEX)) include_length = false; - deserialize(contents); + if(include_length) + { + TLS_Data_Reader reader(contents); + key_material = reader.get_range<byte>(2, 0, 65535); + } + else + key_material = contents; } /* @@ -109,20 +115,6 @@ MemoryVector<byte> Client_Key_Exchange::serialize() const } /* -* Deserialize a Client Key Exchange message -*/ -void Client_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) - { - if(include_length) - { - TLS_Data_Reader reader(buf); - key_material = reader.get_range<byte>(2, 0, 65535); - } - else - key_material = buf; - } - -/* * Return the pre_master_secret */ SecureVector<byte> diff --git a/src/tls/cert_req.cpp b/src/tls/cert_req.cpp index 0168e4b7d..bdb25057c 100644 --- a/src/tls/cert_req.cpp +++ b/src/tls/cert_req.cpp @@ -37,29 +37,9 @@ Certificate_Req::Certificate_Req(Record_Writer& writer, } /** -* Serialize a Certificate Request message -*/ -MemoryVector<byte> Certificate_Req::serialize() const - { - MemoryVector<byte> buf; - - append_tls_length_value(buf, types, 1); - - for(size_t i = 0; i != names.size(); ++i) - { - DER_Encoder encoder; - encoder.encode(names[i]); - - append_tls_length_value(buf, encoder.get_contents(), 2); - } - - return buf; - } - -/** * Deserialize a Certificate Request message */ -void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) +Certificate_Req::Certificate_Req(const MemoryRegion<byte>& buf) { if(buf.size() < 4) throw Decoding_Error("Certificate_Req: Bad certificate request"); @@ -96,43 +76,40 @@ void Certificate_Req::deserialize(const MemoryRegion<byte>& buf) } /** -* Create a new Certificate message +* Serialize a Certificate Request message */ -Certificate::Certificate(Record_Writer& writer, - TLS_Handshake_Hash& hash, - const std::vector<X509_Certificate>& cert_list) +MemoryVector<byte> Certificate_Req::serialize() const { - certs = cert_list; - send(writer, hash); - } + MemoryVector<byte> buf; -/** -* Serialize a Certificate message -*/ -MemoryVector<byte> Certificate::serialize() const - { - MemoryVector<byte> buf(3); + append_tls_length_value(buf, types, 1); - for(size_t i = 0; i != certs.size(); ++i) + for(size_t i = 0; i != names.size(); ++i) { - MemoryVector<byte> raw_cert = certs[i].BER_encode(); - const size_t cert_size = raw_cert.size(); - for(size_t i = 0; i != 3; ++i) - buf.push_back(get_byte<u32bit>(i+1, cert_size)); - buf += raw_cert; - } + DER_Encoder encoder; + encoder.encode(names[i]); - const size_t buf_size = buf.size() - 3; - for(size_t i = 0; i != 3; ++i) - buf[i] = get_byte<u32bit>(i+1, buf_size); + append_tls_length_value(buf, encoder.get_contents(), 2); + } return buf; } /** +* Create a new Certificate message +*/ +Certificate::Certificate(Record_Writer& writer, + TLS_Handshake_Hash& hash, + const std::vector<X509_Certificate>& cert_list) + { + certs = cert_list; + send(writer, hash); + } + +/** * Deserialize a Certificate message */ -void Certificate::deserialize(const MemoryRegion<byte>& buf) +Certificate::Certificate(const MemoryRegion<byte>& buf) { if(buf.size() < 3) throw Decoding_Error("Certificate: Message malformed"); @@ -163,4 +140,27 @@ void Certificate::deserialize(const MemoryRegion<byte>& buf) } } +/** +* Serialize a Certificate message +*/ +MemoryVector<byte> Certificate::serialize() const + { + MemoryVector<byte> buf(3); + + for(size_t i = 0; i != certs.size(); ++i) + { + MemoryVector<byte> raw_cert = certs[i].BER_encode(); + const size_t cert_size = raw_cert.size(); + for(size_t i = 0; i != 3; ++i) + buf.push_back(get_byte<u32bit>(i+1, cert_size)); + buf += raw_cert; + } + + const size_t buf_size = buf.size() - 3; + for(size_t i = 0; i != 3; ++i) + buf[i] = get_byte<u32bit>(i+1, buf_size); + + return buf; + } + } diff --git a/src/tls/cert_ver.cpp b/src/tls/cert_ver.cpp index 81d529e88..77d9fe74b 100644 --- a/src/tls/cert_ver.cpp +++ b/src/tls/cert_ver.cpp @@ -54,6 +54,15 @@ Certificate_Verify::Certificate_Verify(Record_Writer& writer, } /* +* Deserialize a Certificate Verify message +*/ +Certificate_Verify::Certificate_Verify(const MemoryRegion<byte>& buf) + { + TLS_Data_Reader reader(buf); + signature = reader.get_range<byte>(2, 0, 65535); + } + +/* * Serialize a Certificate Verify message */ MemoryVector<byte> Certificate_Verify::serialize() const @@ -69,15 +78,6 @@ MemoryVector<byte> Certificate_Verify::serialize() const } /* -* Deserialize a Certificate Verify message -*/ -void Certificate_Verify::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - signature = reader.get_range<byte>(2, 0, 65535); - } - -/* * Verify a Certificate Verify message */ bool Certificate_Verify::verify(const X509_Certificate& cert, diff --git a/src/tls/finished.cpp b/src/tls/finished.cpp index 836512f81..baa663798 100644 --- a/src/tls/finished.cpp +++ b/src/tls/finished.cpp @@ -81,7 +81,7 @@ MemoryVector<byte> Finished::serialize() const /* * Deserialize a Finished message */ -void Finished::deserialize(const MemoryRegion<byte>& buf) +Finished::Finished(const MemoryRegion<byte>& buf) { verification_data = buf; } diff --git a/src/tls/next_protocol.cpp b/src/tls/next_protocol.cpp index 2d2e2e599..a0d4278f1 100644 --- a/src/tls/next_protocol.cpp +++ b/src/tls/next_protocol.cpp @@ -19,6 +19,15 @@ Next_Protocol::Next_Protocol(Record_Writer& writer, send(writer, hash); } +Next_Protocol::Next_Protocol(const MemoryRegion<byte>& buf) + { + TLS_Data_Reader reader(buf); + + m_protocol = reader.get_string(1, 0, 255); + + reader.get_range_vector<byte>(1, 0, 255); // padding, ignored + } + MemoryVector<byte> Next_Protocol::serialize() const { MemoryVector<byte> buf; @@ -38,13 +47,4 @@ MemoryVector<byte> Next_Protocol::serialize() const return buf; } -void Next_Protocol::deserialize(const MemoryRegion<byte>& buf) - { - TLS_Data_Reader reader(buf); - - m_protocol = reader.get_string(1, 0, 255); - - reader.get_range_vector<byte>(1, 0, 255); // padding, ignored - } - } diff --git a/src/tls/s_hello.cpp b/src/tls/s_hello.cpp index 90e18ae90..fa185599d 100644 --- a/src/tls/s_hello.cpp +++ b/src/tls/s_hello.cpp @@ -90,43 +90,9 @@ Server_Hello::Server_Hello(Record_Writer& writer, } /* -* Serialize a Server Hello message -*/ -MemoryVector<byte> Server_Hello::serialize() const - { - MemoryVector<byte> buf; - - buf.push_back(static_cast<byte>(s_version >> 8)); - buf.push_back(static_cast<byte>(s_version )); - buf += s_random; - - append_tls_length_value(buf, m_session_id, 1); - - buf.push_back(get_byte(0, suite)); - buf.push_back(get_byte(1, suite)); - - buf.push_back(comp_method); - - TLS_Extensions extensions; - - if(m_secure_renegotiation) - extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); - - if(m_fragment_size != 0) - extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); - - if(m_next_protocol) - extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); - - buf += extensions.serialize(); - - return buf; - } - -/* * Deserialize a Server Hello message */ -void Server_Hello::deserialize(const MemoryRegion<byte>& buf) +Server_Hello::Server_Hello(const MemoryRegion<byte>& buf) { m_secure_renegotiation = false; m_next_protocol = false; @@ -173,6 +139,40 @@ void Server_Hello::deserialize(const MemoryRegion<byte>& buf) } /* +* Serialize a Server Hello message +*/ +MemoryVector<byte> Server_Hello::serialize() const + { + MemoryVector<byte> buf; + + buf.push_back(static_cast<byte>(s_version >> 8)); + buf.push_back(static_cast<byte>(s_version )); + buf += s_random; + + append_tls_length_value(buf, m_session_id, 1); + + buf.push_back(get_byte(0, suite)); + buf.push_back(get_byte(1, suite)); + + buf.push_back(comp_method); + + TLS_Extensions extensions; + + if(m_secure_renegotiation) + extensions.push_back(new Renegotation_Extension(m_renegotiation_info)); + + if(m_fragment_size != 0) + extensions.push_back(new Maximum_Fragment_Length(m_fragment_size)); + + if(m_next_protocol) + extensions.push_back(new Next_Protocol_Notification(m_next_protocols)); + + buf += extensions.serialize(); + + return buf; + } + +/* * Create a new Server Hello Done message */ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, @@ -182,20 +182,20 @@ Server_Hello_Done::Server_Hello_Done(Record_Writer& writer, } /* -* Serialize a Server Hello Done message +* Deserialize a Server Hello Done message */ -MemoryVector<byte> Server_Hello_Done::serialize() const +Server_Hello_Done::Server_Hello_Done(const MemoryRegion<byte>& buf) { - return MemoryVector<byte>(); + if(buf.size()) + throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); } /* -* Deserialize a Server Hello Done message +* Serialize a Server Hello Done message */ -void Server_Hello_Done::deserialize(const MemoryRegion<byte>& buf) +MemoryVector<byte> Server_Hello_Done::serialize() const { - if(buf.size()) - throw Decoding_Error("Server_Hello_Done: Must be empty, and is not"); + return MemoryVector<byte>(); } } diff --git a/src/tls/s_kex.cpp b/src/tls/s_kex.cpp index 69531c7c4..7008c89de 100644 --- a/src/tls/s_kex.cpp +++ b/src/tls/s_kex.cpp @@ -73,7 +73,7 @@ MemoryVector<byte> Server_Key_Exchange::serialize_params() const /** * Deserialize a Server Key Exchange message */ -void Server_Key_Exchange::deserialize(const MemoryRegion<byte>& buf) +Server_Key_Exchange::Server_Key_Exchange(const MemoryRegion<byte>& buf) { if(buf.size() < 6) throw Decoding_Error("Server_Key_Exchange: Packet corrupted"); diff --git a/src/tls/tls_messages.h b/src/tls/tls_messages.h index 0b43545dc..d3735972e 100644 --- a/src/tls/tls_messages.h +++ b/src/tls/tls_messages.h @@ -37,7 +37,6 @@ class Handshake_Message private: Handshake_Message& operator=(const Handshake_Message&) { return (*this); } virtual MemoryVector<byte> serialize() const = 0; - virtual void deserialize(const MemoryRegion<byte>&) = 0; }; MemoryVector<byte> make_hello_random(RandomNumberGenerator& rng); @@ -179,10 +178,9 @@ class Server_Hello : public Handshake_Message const std::vector<std::string>& next_protocols, RandomNumberGenerator& rng); - Server_Hello(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); Version_Code s_version; MemoryVector<byte> m_session_id, s_random; @@ -224,7 +222,6 @@ class Client_Key_Exchange : public Handshake_Message Version_Code using_version); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); SecureVector<byte> key_material, pre_master; bool include_length; @@ -246,10 +243,10 @@ class Certificate : public Handshake_Message TLS_Handshake_Hash& hash, const std::vector<X509_Certificate>& certs); - Certificate(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); + std::vector<X509_Certificate> certs; }; @@ -270,10 +267,9 @@ class Certificate_Req : public Handshake_Message const std::vector<Certificate_Type>& types = std::vector<Certificate_Type>()); - Certificate_Req(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Req(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::vector<X509_DN> names; std::vector<Certificate_Type> types; @@ -300,10 +296,9 @@ class Certificate_Verify : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* key); - Certificate_Verify(const MemoryRegion<byte>& buf) { deserialize(buf); } + Certificate_Verify(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); MemoryVector<byte> signature; }; @@ -326,10 +321,9 @@ class Finished : public Handshake_Message TLS_Handshake_State* state, Connection_Side side); - Finished(const MemoryRegion<byte>& buf) { deserialize(buf); } + Finished(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); Connection_Side side; MemoryVector<byte> verification_data; @@ -344,10 +338,9 @@ class Hello_Request : public Handshake_Message Handshake_Type type() const { return HELLO_REQUEST; } Hello_Request(Record_Writer& writer); - Hello_Request(const MemoryRegion<byte>& buf) { deserialize(buf); } + Hello_Request(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -367,11 +360,10 @@ class Server_Key_Exchange : public Handshake_Message RandomNumberGenerator& rng, const Private_Key* priv_key); - Server_Key_Exchange(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Key_Exchange(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; MemoryVector<byte> serialize_params() const; - void deserialize(const MemoryRegion<byte>&); std::vector<BigInt> params; MemoryVector<byte> signature; @@ -386,10 +378,9 @@ class Server_Hello_Done : public Handshake_Message Handshake_Type type() const { return SERVER_HELLO_DONE; } Server_Hello_Done(Record_Writer& writer, TLS_Handshake_Hash& hash); - Server_Hello_Done(const MemoryRegion<byte>& buf) { deserialize(buf); } + Server_Hello_Done(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); }; /** @@ -406,10 +397,9 @@ class Next_Protocol : public Handshake_Message TLS_Handshake_Hash& hash, const std::string& protocol); - Next_Protocol(const MemoryRegion<byte>& buf) { deserialize(buf); } + Next_Protocol(const MemoryRegion<byte>& buf); private: MemoryVector<byte> serialize() const; - void deserialize(const MemoryRegion<byte>&); std::string m_protocol; }; diff --git a/src/tls/tls_reader.h b/src/tls/tls_reader.h index ef36912d3..6a0bcd5b1 100644 --- a/src/tls/tls_reader.h +++ b/src/tls/tls_reader.h @@ -26,13 +26,10 @@ class TLS_Data_Reader TLS_Data_Reader(const MemoryRegion<byte>& buf_in) : buf(buf_in), offset(0) {} - ~TLS_Data_Reader() + void assert_done() const { if(has_remaining()) - { - abort(); throw Decoding_Error("Extra bytes at end of message"); - } } size_t remaining_bytes() const |