aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/tls/tls_channel.cpp2
-rw-r--r--src/tls/tls_channel.h2
-rw-r--r--src/tls/tls_client.cpp9
-rw-r--r--src/tls/tls_client.h6
-rw-r--r--src/tls/tls_server.cpp9
-rw-r--r--src/tls/tls_server.h10
6 files changed, 27 insertions, 11 deletions
diff --git a/src/tls/tls_channel.cpp b/src/tls/tls_channel.cpp
index 1b4cb407e..61dff6b03 100644
--- a/src/tls/tls_channel.cpp
+++ b/src/tls/tls_channel.cpp
@@ -177,7 +177,7 @@ void Channel::read_handshake(byte rec_type,
if(rec_type == HANDSHAKE)
{
if(!m_state)
- m_state = new Handshake_State(new Stream_Handshake_Reader);
+ m_state = new_handshake_state();
m_state->handshake_reader()->add_input(&rec_buf[0], rec_buf.size());
}
diff --git a/src/tls/tls_channel.h b/src/tls/tls_channel.h
index db78cdc69..c2193b282 100644
--- a/src/tls/tls_channel.h
+++ b/src/tls/tls_channel.h
@@ -107,6 +107,8 @@ class BOTAN_DLL Channel
virtual void alert_notify(const Alert& alert) = 0;
+ virtual class Handshake_State* new_handshake_state() const = 0;
+
class Secure_Renegotiation_State
{
public:
diff --git a/src/tls/tls_client.cpp b/src/tls/tls_client.cpp
index db967cee3..2915fc68d 100644
--- a/src/tls/tls_client.cpp
+++ b/src/tls/tls_client.cpp
@@ -35,7 +35,7 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
{
m_writer.set_version(Protocol_Version::SSL_V3);
- m_state = new Handshake_State(new Stream_Handshake_Reader);
+ m_state = new_handshake_state();
m_state->set_expected_next(SERVER_HELLO);
m_state->client_npn_cb = next_protocol;
@@ -82,6 +82,11 @@ Client::Client(std::function<void (const byte[], size_t)> output_fn,
m_secure_renegotiation.update(m_state->client_hello);
}
+Handshake_State* Client::new_handshake_state() const
+ {
+ return new Handshake_State(new Stream_Handshake_Reader);
+ }
+
/*
* Send a new client hello to renegotiate
*/
@@ -91,7 +96,7 @@ void Client::renegotiate(bool force_full_renegotiation)
return; // currently in active handshake
delete m_state;
- m_state = new Handshake_State(new Stream_Handshake_Reader);
+ m_state = new_handshake_state();
m_state->set_expected_next(SERVER_HELLO);
diff --git a/src/tls/tls_client.h b/src/tls/tls_client.h
index fc08ca796..cd9da78b9 100644
--- a/src/tls/tls_client.h
+++ b/src/tls/tls_client.h
@@ -52,12 +52,14 @@ class BOTAN_DLL Client : public Channel
std::function<std::string (std::vector<std::string>)> next_protocol =
std::function<std::string (std::vector<std::string>)>());
- void renegotiate(bool force_full_renegotiation = false);
+ void renegotiate(bool force_full_renegotiation = false) override;
private:
void process_handshake_msg(Handshake_Type type,
const std::vector<byte>& contents) override;
- void alert_notify(const Alert& alert);
+ void alert_notify(const Alert& alert) override;
+
+ Handshake_State* new_handshake_state() const override;
const Policy& m_policy;
RandomNumberGenerator& m_rng;
diff --git a/src/tls/tls_server.cpp b/src/tls/tls_server.cpp
index d2a51fabd..e89ec7b4a 100644
--- a/src/tls/tls_server.cpp
+++ b/src/tls/tls_server.cpp
@@ -200,6 +200,11 @@ Server::Server(std::function<void (const byte[], size_t)> output_fn,
{
}
+Handshake_State* Server::new_handshake_state() const
+ {
+ return new Handshake_State(new Stream_Handshake_Reader);
+ }
+
/*
* Send a hello request to the client
*/
@@ -208,7 +213,7 @@ void Server::renegotiate(bool force_full_renegotiation)
if(m_state)
return; // currently in handshake
- m_state = new Handshake_State(new Stream_Handshake_Reader);
+ m_state = new_handshake_state();
m_state->allow_session_resumption = !force_full_renegotiation;
m_state->set_expected_next(CLIENT_HELLO);
@@ -235,7 +240,7 @@ void Server::read_handshake(byte rec_type,
{
if(rec_type == HANDSHAKE && !m_state)
{
- m_state = new Handshake_State(new Stream_Handshake_Reader);
+ m_state = new_handshake_state();
m_state->set_expected_next(CLIENT_HELLO);
}
diff --git a/src/tls/tls_server.h b/src/tls/tls_server.h
index 89e27fa92..b2143b182 100644
--- a/src/tls/tls_server.h
+++ b/src/tls/tls_server.h
@@ -35,7 +35,7 @@ class BOTAN_DLL Server : public Channel
const std::vector<std::string>& protocols =
std::vector<std::string>());
- void renegotiate(bool force_full_renegotiation = false);
+ void renegotiate(bool force_full_renegotiation = false) override;
/**
* Return the server name indicator, if sent by the client
@@ -50,11 +50,13 @@ class BOTAN_DLL Server : public Channel
{ return m_next_protocol; }
private:
- void read_handshake(byte, const std::vector<byte>&);
+ void read_handshake(byte, const std::vector<byte>&) override;
- void process_handshake_msg(Handshake_Type, const std::vector<byte>&);
+ void process_handshake_msg(Handshake_Type, const std::vector<byte>&) override;
- void alert_notify(const Alert& alert);
+ void alert_notify(const Alert& alert) override;
+
+ Handshake_State* new_handshake_state() const override;
const Policy& m_policy;
RandomNumberGenerator& m_rng;