diff options
Diffstat (limited to 'src/tls/tls_channel.cpp')
-rw-r--r-- | src/tls/tls_channel.cpp | 41 |
1 files changed, 20 insertions, 21 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp index 382671afe..9fb41c9f6 100644 --- a/src/tls/tls_channel.cpp +++ b/src/tls/tls_channel.cpp @@ -183,7 +183,7 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) { while(buf_size) { - byte rec_type = CONNECTION_CLOSED; + byte rec_type = NO_RECORD; std::vector<byte> record; u64bit record_number = 0; Protocol_Version record_version; @@ -214,15 +214,14 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) if(buf_size == 0 && needed != 0) return needed; // need more data to complete record + if(rec_type == NO_RECORD) + continue; + if(record.size() > m_max_fragment) throw TLS_Exception(Alert::RECORD_OVERFLOW, "Plaintext record is too large"); - if(rec_type == CONNECTION_CLOSED) - { - continue; - } - else if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) + if(rec_type == HANDSHAKE || rec_type == CHANGE_CIPHER_SPEC) { if(!m_pending_state) { @@ -247,6 +246,9 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } else if(rec_type == HEARTBEAT && peer_supports_heartbeats()) { + if(!m_active_state) + throw Unexpected_Message("Heartbeat sent before handshake done"); + Heartbeat_Message heartbeat(record); const std::vector<byte>& payload = heartbeat.payload(); @@ -269,20 +271,16 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } else if(rec_type == APPLICATION_DATA) { - if(m_active_state) - { - /* - * OpenSSL among others sends empty records in versions - * before TLS v1.1 in order to randomize the IV of the - * following record. Avoid spurious callbacks. - */ - if(record.size() > 0) - m_proc_fn(&record[0], record.size(), Alert()); - } - else - { + if(!m_active_state) throw Unexpected_Message("Application data before handshake done"); - } + + /* + * OpenSSL among others sends empty records in versions + * before TLS v1.1 in order to randomize the IV of the + * following record. Avoid spurious callbacks. + */ + if(record.size() > 0) + m_proc_fn(&record[0], record.size(), Alert()); } else if(rec_type == ALERT) { @@ -317,8 +315,9 @@ size_t Channel::received_data(const byte buf[], size_t buf_size) } } else - throw Unexpected_Message("Unknown TLS message type " + - std::to_string(rec_type) + " received"); + throw Unexpected_Message("Unknown record type " + + std::to_string(readbuf[0]) + + " from counterparty"); } return 0; // on a record boundary |