aboutsummaryrefslogtreecommitdiffstats
path: root/LibOVR/Src/Net/OVR_Session.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'LibOVR/Src/Net/OVR_Session.cpp')
-rw-r--r--LibOVR/Src/Net/OVR_Session.cpp455
1 files changed, 289 insertions, 166 deletions
diff --git a/LibOVR/Src/Net/OVR_Session.cpp b/LibOVR/Src/Net/OVR_Session.cpp
index 508f0c9..4049c6c 100644
--- a/LibOVR/Src/Net/OVR_Session.cpp
+++ b/LibOVR/Src/Net/OVR_Session.cpp
@@ -26,36 +26,95 @@ limitations under the License.
#include "OVR_Session.h"
#include "OVR_PacketizedTCPSocket.h"
-#include "../Kernel/OVR_Log.h"
-#include "../Service/Service_NetSessionCommon.h"
+#include "Kernel/OVR_Log.h"
+#include "Service/Service_NetSessionCommon.h"
namespace OVR { namespace Net {
+// The SDK version requested by the user.
+SDKVersion RuntimeSDKVersion;
+
+
//-----------------------------------------------------------------------------
// Protocol
-static const char* OfficialHelloString = "OculusVR_Hello";
+static const char* OfficialHelloString = "OculusVR_Hello";
static const char* OfficialAuthorizedString = "OculusVR_Authorized";
-void RPC_C2S_Hello::Generate(Net::BitStream* bs)
+bool RPC_C2S_Hello::Serialize(bool writeToBitstream, Net::BitStream* bs)
+{
+ bs->Serialize(writeToBitstream, HelloString);
+ bs->Serialize(writeToBitstream, MajorVersion);
+ bs->Serialize(writeToBitstream, MinorVersion);
+ if (!bs->Serialize(writeToBitstream, PatchVersion))
+ return false;
+
+ // If an older client is connecting to us,
+ if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121)
+ {
+ // The following was version code was added to RPC version 1.2
+ // without bumping it up to 1.3 and introducing an incompatibility.
+ // We can do this because an older server will not read this additional data.
+ return true;
+ }
+
+ bs->Serialize(writeToBitstream, CodeVersion.ProductVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.MajorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.MinorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.PatchVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.BuildNumber);
+ return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion);
+}
+
+void RPC_C2S_Hello::ClientGenerate(Net::BitStream* bs)
{
RPC_C2S_Hello hello;
- hello.HelloString = OfficialHelloString;
+ hello.HelloString = OfficialHelloString;
hello.MajorVersion = RPCVersion_Major;
hello.MinorVersion = RPCVersion_Minor;
hello.PatchVersion = RPCVersion_Patch;
- hello.Serialize(bs);
+ OVR_ASSERT(OVR::Net::RuntimeSDKVersion.ProductVersion != UINT16_MAX);
+ hello.CodeVersion = OVR::Net::RuntimeSDKVersion; // This should have been set to a value earlier in the first steps of ovr initialization.
+ hello.Serialize(true, bs);
}
-bool RPC_C2S_Hello::Validate()
+bool RPC_C2S_Hello::ServerValidate()
{
+ // Server checks the protocol version
return MajorVersion == RPCVersion_Major &&
MinorVersion <= RPCVersion_Minor &&
HelloString.CompareNoCase(OfficialHelloString) == 0;
}
-void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString)
+bool RPC_S2C_Authorization::Serialize(bool writeToBitstream, Net::BitStream* bs)
+{
+ bs->Serialize(writeToBitstream, AuthString);
+ bs->Serialize(writeToBitstream, MajorVersion);
+ bs->Serialize(writeToBitstream, MinorVersion);
+ if (!bs->Serialize(writeToBitstream, PatchVersion))
+ return false;
+
+ // If an older client is connecting to us,
+ if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121)
+ {
+ // The following was version code was added to RPC version 1.2
+ // without bumping it up to 1.3 and introducing an incompatibility.
+ // We can do this because an older server will not read this additional data.
+ return true;
+ }
+
+ bs->Serialize(writeToBitstream, CodeVersion.ProductVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.MajorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.MinorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.PatchVersion);
+ bs->Serialize(writeToBitstream, CodeVersion.BuildNumber);
+ return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion);
+}
+
+void RPC_S2C_Authorization::ServerGenerate(Net::BitStream* bs, String errorString)
{
RPC_S2C_Authorization auth;
if (errorString.IsEmpty())
@@ -69,16 +128,33 @@ void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString)
auth.MajorVersion = RPCVersion_Major;
auth.MinorVersion = RPCVersion_Minor;
auth.PatchVersion = RPCVersion_Patch;
- auth.Serialize(bs);
+ // Leave CurrentSDKVersion as it is.
+ auth.Serialize(true, bs);
}
-bool RPC_S2C_Authorization::Validate()
+bool RPC_S2C_Authorization::ClientValidate()
{
return AuthString.CompareNoCase(OfficialAuthorizedString) == 0;
}
//-----------------------------------------------------------------------------
+// SingleProcess
+
+static bool SingleProcess = false;
+
+void Session::SetSingleProcess(bool enable)
+{
+ SingleProcess = enable;
+}
+
+bool Session::IsSingleProcess()
+{
+ return SingleProcess;
+}
+
+
+//-----------------------------------------------------------------------------
// Session
void Session::Shutdown()
@@ -111,29 +187,25 @@ void Session::Shutdown()
SessionResult Session::Listen(ListenerDescription* pListenerDescription)
{
- if (pListenerDescription->Transport == TransportType_PacketizedTCP)
- {
- BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription;
- TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr();
+ if (pListenerDescription->Transport == TransportType_PacketizedTCP)
+ {
+ BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription;
+ TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr();
if (tcpSocket->Listen() < 0)
{
return SessionResult_ListenFailure;
}
- Lock::Locker locker(&SocketListenersLock);
+ Lock::Locker locker(&SocketListenersLock);
SocketListeners.PushBack(tcpSocket);
- }
- else if (pListenerDescription->Transport == TransportType_Loopback)
- {
- HasLoopbackListener = true;
- }
+ }
else
{
OVR_ASSERT(false);
}
- return SessionResult_OK;
+ return SessionResult_OK;
}
SessionResult Session::Connect(ConnectParameters *cp)
@@ -153,6 +225,28 @@ SessionResult Session::Connect(ConnectParameters *cp)
return SessionResult_AlreadyConnected;
}
+ // If we are already connected, don't create a duplicate connection
+ if (FullConnections.GetSizeI() > 0)
+ {
+ return SessionResult_AlreadyConnected;
+ }
+
+ // If we are already connecting, don't create a duplicate connection
+ const int count = AllConnections.GetSizeI();
+ for (int i = 0; i < count; ++i)
+ {
+ Connection* arrayItem = AllConnections[i].GetPtr();
+
+ OVR_ASSERT(arrayItem);
+ if (arrayItem) {
+ if (arrayItem->State == Client_ConnectedWait
+ || arrayItem->State == Client_Connecting)
+ {
+ return SessionResult_ConnectInProgress;
+ }
+ }
+ }
+
TCPSocketBase* tcpSock = (TCPSocketBase*)cp2->BoundSocketToConnectWith.GetPtr();
int ret = tcpSock->Connect(&cp2->RemoteAddress);
@@ -174,7 +268,6 @@ SessionResult Session::Connect(ConnectParameters *cp)
c->SetState(Client_Connecting);
AllConnections.PushBack(c);
-
}
if (cp2->Blocking)
@@ -182,11 +275,12 @@ SessionResult Session::Connect(ConnectParameters *cp)
c->WaitOnConnecting();
}
- if (c->State == State_Connected)
+ EConnectionState state = c->State;
+ if (state == State_Connected)
{
return SessionResult_OK;
}
- else if (c->State == Client_Connecting)
+ else if (state == Client_Connecting)
{
return SessionResult_ConnectInProgress;
}
@@ -195,49 +289,33 @@ SessionResult Session::Connect(ConnectParameters *cp)
return SessionResult_ConnectFailure;
}
}
- else if (cp->Transport == TransportType_Loopback)
- {
- if (HasLoopbackListener)
- {
- Ptr<Connection> c = AllocConnection(cp->Transport);
- if (!c)
- {
- return SessionResult_ConnectFailure;
- }
-
- c->Transport = cp->Transport;
- c->SetState(State_Connected);
-
- {
- Lock::Locker locker(&ConnectionsLock);
- AllConnections.PushBack(c);
- }
-
- invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, c);
- }
- else
- {
- OVR_ASSERT(false);
- }
- }
else
{
OVR_ASSERT(false);
}
- return SessionResult_OK;
+ return SessionResult_OK;
}
+static Session* SingleProcessServer = nullptr;
+
SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp)
{
- Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket();
+ if (Session::IsSingleProcess())
+ {
+ // Do not actually listen on a socket.
+ SingleProcessServer = this;
+ return SessionResult_OK;
+ }
+
+ Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket();
if (listenSocket->Bind(bbp) == INVALID_SOCKET)
{
return SessionResult_BindFailure;
}
- BerkleyListenerDescription bld;
- bld.BoundSocketToListenWith = listenSocket.GetPtr();
+ BerkleyListenerDescription bld;
+ bld.BoundSocketToListenWith = listenSocket.GetPtr();
bld.Transport = TransportType_PacketizedTCP;
return Listen(&bld);
@@ -245,16 +323,46 @@ SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp)
SessionResult Session::ConnectPTCP(OVR::Net::BerkleyBindParameters* bbp, SockAddr* remoteAddress, bool blocking)
{
+ if (Session::IsSingleProcess())
+ {
+ OVR_ASSERT(SingleProcessServer); // ListenPTCP() must be called before ConnectPTCP()
+
+ SingleProcessServer->SingleTargetSession = this;
+ SingleTargetSession = SingleProcessServer;
+
+ Ptr<PacketizedTCPSocket> s = *new PacketizedTCPSocket;
+ SockAddr sa;
+ sa.Set("::1", 10101, SOCK_STREAM);
+
+ Ptr<Connection> newConnection = AllocConnection(TransportType_PacketizedTCP);
+ if (!newConnection)
+ {
+ return SessionResult_ConnectFailure;
+ }
+
+ PacketizedTCPConnection* c = (PacketizedTCPConnection*)newConnection.GetPtr();
+ c->pSocket = s;
+ c->Address = &sa;
+ c->Transport = TransportType_PacketizedTCP;
+ c->SetState(Client_Connecting);
+ AllConnections.PushBack(c);
+
+ SingleTargetSession->TCP_OnAccept(s, &sa, INVALID_SOCKET);
+ TCP_OnConnected(s);
+
+ return SessionResult_OK;
+ }
+
ConnectParametersBerkleySocket cp(NULL, remoteAddress, blocking, TransportType_PacketizedTCP);
Ptr<PacketizedTCPSocket> connectSocket = *new PacketizedTCPSocket();
- cp.BoundSocketToConnectWith = connectSocket.GetPtr();
+ cp.BoundSocketToConnectWith = connectSocket.GetPtr();
if (connectSocket->Bind(bbp) == INVALID_SOCKET)
{
return SessionResult_BindFailure;
}
- return Connect(&cp);
+ return Connect(&cp);
}
Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address)
@@ -280,47 +388,26 @@ Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address
int Session::Send(SendParameters *payload)
{
- if (payload->pConnection->Transport == TransportType_Loopback)
- {
- Lock::Locker locker(&SessionListenersLock);
-
- const int count = SessionListeners.GetSizeI();
- for (int i = 0; i < count; ++i)
- {
- SessionListener* sl = SessionListeners[i];
-
- // FIXME: This looks like it needs to be reviewed at some point..
- ReceivePayload rp;
- rp.Bytes = payload->Bytes;
- rp.pConnection = payload->pConnection;
- rp.pData = (uint8_t*)payload->pData; // FIXME
- ListenerReceiveResult lrr = LRR_CONTINUE;
- sl->OnReceive(&rp, &lrr);
- if (lrr == LRR_RETURN)
- {
- return payload->Bytes;
- }
- else if (lrr == LRR_BREAK)
- {
- break;
- }
- }
-
- return payload->Bytes;
- }
- else if (payload->pConnection->Transport == TransportType_PacketizedTCP)
- {
- PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr();
-
- return conn->pSocket->Send(payload->pData, payload->Bytes);
- }
- else
+ if (payload->pConnection->Transport == TransportType_PacketizedTCP)
{
- OVR_ASSERT(false);
+ if (Session::IsSingleProcess())
+ {
+ OVR_ASSERT(SingleTargetSession->AllConnections.GetSizeI() > 0);
+ PacketizedTCPConnection* conn = (PacketizedTCPConnection*)SingleTargetSession->AllConnections[0].GetPtr();
+ SingleTargetSession->TCP_OnRecv(conn->pSocket, (uint8_t*)payload->pData, payload->Bytes);
+ return payload->Bytes;
+ }
+ else
+ {
+ PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr();
+ return conn->pSocket->Send(payload->pData, payload->Bytes);
+ }
}
+ OVR_ASSERT(false); // Should not reach here
return 0;
}
+
void Session::Broadcast(BroadcastParameters *payload)
{
SendParameters sp;
@@ -338,21 +425,29 @@ void Session::Broadcast(BroadcastParameters *payload)
}
}
}
-// DO NOT CALL Poll() FROM MULTIPLE THREADS due to allBlockingTcpSockets being a member
+
+// DO NOT CALL Poll() FROM MULTIPLE THREADS due to AllBlockingTcpSockets being a member
void Session::Poll(bool listeners)
{
- allBlockingTcpSockets.Clear();
+ if (Net::Session::IsSingleProcess())
+ {
+ // Spend a lot of time sleeping in single process mode
+ Thread::MSleep(100);
+ return;
+ }
- if (listeners)
- {
- Lock::Locker locker(&SocketListenersLock);
+ AllBlockingTcpSockets.Clear();
+
+ if (listeners)
+ {
+ Lock::Locker locker(&SocketListenersLock);
const int listenerCount = SocketListeners.GetSizeI();
for (int i = 0; i < listenerCount; ++i)
- {
- allBlockingTcpSockets.PushBack(SocketListeners[i]);
- }
- }
+ {
+ AllBlockingTcpSockets.PushBack(SocketListeners[i]);
+ }
+ }
{
Lock::Locker locker(&ConnectionsLock);
@@ -366,7 +461,7 @@ void Session::Poll(bool listeners)
{
PacketizedTCPConnection* ptcp = (PacketizedTCPConnection*)arrayItem;
- allBlockingTcpSockets.PushBack(ptcp->pSocket);
+ AllBlockingTcpSockets.PushBack(ptcp->pSocket);
}
else
{
@@ -375,15 +470,15 @@ void Session::Poll(bool listeners)
}
}
- const int count = allBlockingTcpSockets.GetSizeI();
- if (count > 0)
- {
+ const int count = AllBlockingTcpSockets.GetSizeI();
+ if (count > 0)
+ {
TCPSocketPollState state;
// Add all the sockets for polling,
for (int i = 0; i < count; ++i)
{
- Net::TCPSocket* sock = allBlockingTcpSockets[i].GetPtr();
+ Net::TCPSocket* sock = AllBlockingTcpSockets[i].GetPtr();
// If socket handle is invalid,
if (sock->GetSocketHandle() == INVALID_SOCKET)
@@ -399,20 +494,20 @@ void Session::Poll(bool listeners)
}
// If polling returns with an event,
- if (state.Poll(allBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), allBlockingTcpSockets[0]->GetBlockingTimeoutSec()))
+ if (state.Poll(AllBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), AllBlockingTcpSockets[0]->GetBlockingTimeoutSec()))
{
// Handle any events for each socket
for (int i = 0; i < count; ++i)
{
- state.HandleEvent(allBlockingTcpSockets[i], this);
+ state.HandleEvent(AllBlockingTcpSockets[i], this);
}
}
- }
+ }
}
void Session::AddSessionListener(SessionListener* se)
{
- Lock::Locker locker(&SessionListenersLock);
+ Lock::Locker locker(&SessionListenersLock);
const int count = SessionListeners.GetSizeI();
for (int i = 0; i < count; ++i)
@@ -425,36 +520,35 @@ void Session::AddSessionListener(SessionListener* se)
}
SessionListeners.PushBack(se);
- se->OnAddedToSession(this);
+ se->OnAddedToSession(this);
}
void Session::RemoveSessionListener(SessionListener* se)
{
- Lock::Locker locker(&SessionListenersLock);
+ Lock::Locker locker(&SessionListenersLock);
const int count = SessionListeners.GetSizeI();
- for (int i = 0; i < count; ++i)
- {
+ for (int i = 0; i < count; ++i)
+ {
if (SessionListeners[i] == se)
- {
+ {
se->OnRemovedFromSession(this);
SessionListeners.RemoveAtUnordered(i);
break;
- }
- }
+ }
+ }
}
-SInt32 Session::GetActiveSocketsCount()
+
+int Session::GetActiveSocketsCount()
{
- Lock::Locker locker1(&SocketListenersLock);
- Lock::Locker locker2(&ConnectionsLock);
- return SocketListeners.GetSize() + AllConnections.GetSize()>0;
+ return SocketListeners.GetSizeI() + AllConnections.GetSizeI();
}
+
Ptr<Connection> Session::AllocConnection(TransportType transport)
{
switch (transport)
{
- case TransportType_Loopback: return *new Connection();
case TransportType_TCP: return *new TCPConnection();
case TransportType_PacketizedTCP: return *new PacketizedTCPConnection();
default:
@@ -511,14 +605,14 @@ int Session::invokeSessionListeners(ReceivePayload* rp)
void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
{
- // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock().
- // Lock::Locker locker(&ConnectionsLock);
+ // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock().
+ // Lock::Locker locker(&ConnectionsLock);
// Look for the connection in the full connection list first
int connIndex;
- ConnectionsLock.DoLock();
+ ConnectionsLock.DoLock();
Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, pSocket, &connIndex);
- ConnectionsLock.Unlock();
+ ConnectionsLock.Unlock();
if (conn)
{
if (conn->State == State_Connected)
@@ -537,8 +631,8 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
BitStream bsIn((char*)pData, bytesRead, false);
RPC_S2C_Authorization auth;
- if (!auth.Deserialize(&bsIn) ||
- !auth.Validate())
+ if (!auth.Serialize(false, &bsIn) ||
+ !auth.ClientValidate())
{
LogError("{ERR-001} [Session] REJECTED: OVRService did not authorize us: %s", auth.AuthString.ToCStr());
@@ -551,16 +645,18 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
conn->RemoteMajorVersion = auth.MajorVersion;
conn->RemoteMinorVersion = auth.MinorVersion;
conn->RemotePatchVersion = auth.PatchVersion;
+ conn->RemoteCodeVersion = auth.CodeVersion;
// Mark as connected
conn->SetState(State_Connected);
- ConnectionsLock.DoLock();
- int connIndex2;
- if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
- {
- FullConnections.PushBack(conn);
- }
- ConnectionsLock.Unlock();
+ ConnectionsLock.DoLock();
+ int connIndex2;
+ if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
+ {
+ FullConnections.PushBack(conn);
+ HaveFullConnections.store(true, std::memory_order_relaxed);
+ }
+ ConnectionsLock.Unlock();
invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, conn);
}
}
@@ -570,41 +666,59 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
BitStream bsIn((char*)pData, bytesRead, false);
RPC_C2S_Hello hello;
- if (!hello.Deserialize(&bsIn) ||
- !hello.Validate())
+ if (!hello.Serialize(false, &bsIn) ||
+ !hello.ServerValidate())
{
- LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d (my version=%d.%d.%d)",
- hello.MajorVersion, hello.MinorVersion, hello.PatchVersion,
- RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch);
+ LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d, feature version %d (my version=%d.%d.%d, feature version %d)",
+ hello.MajorVersion, hello.MinorVersion, hello.PatchVersion, hello.CodeVersion.FeatureVersion,
+ RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch, OVR_FEATURE_VERSION);
conn->SetState(State_Zombie);
// Send auth response
BitStream bsOut;
- RPC_S2C_Authorization::Generate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date.");
- conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
+ RPC_S2C_Authorization::ServerGenerate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date.");
+
+ SendParameters sp;
+ sp.Bytes = bsOut.GetNumberOfBytesUsed();
+ sp.pData = bsOut.GetData();
+ sp.pConnection = conn;
+ Send(&sp);
}
else
{
+ if (hello.CodeVersion.FeatureVersion != OVR_FEATURE_VERSION)
+ {
+ LogError("[Session] WARNING: Rift application is using a different feature version than the server (server version = %d, app version = %d)",
+ OVR_FEATURE_VERSION, hello.CodeVersion.FeatureVersion);
+ }
+
// Read remote version
conn->RemoteMajorVersion = hello.MajorVersion;
conn->RemoteMinorVersion = hello.MinorVersion;
conn->RemotePatchVersion = hello.PatchVersion;
+ conn->RemoteCodeVersion = hello.CodeVersion;
// Send auth response
BitStream bsOut;
- RPC_S2C_Authorization::Generate(&bsOut);
- conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
+ RPC_S2C_Authorization::ServerGenerate(&bsOut);
+
+ SendParameters sp;
+ sp.Bytes = bsOut.GetNumberOfBytesUsed();
+ sp.pData = bsOut.GetData();
+ sp.pConnection = conn;
+ Send(&sp);
// Mark as connected
conn->SetState(State_Connected);
- ConnectionsLock.DoLock();
- int connIndex2;
- if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
- {
- FullConnections.PushBack(conn);
- }
- ConnectionsLock.Unlock();
+ ConnectionsLock.DoLock();
+ int connIndex2;
+ if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
+ {
+ FullConnections.PushBack(conn);
+ HaveFullConnections.store(true, std::memory_order_relaxed);
+ }
+ ConnectionsLock.Unlock();
invokeSessionEvent(&SessionListener::OnNewIncomingConnection, conn);
}
@@ -618,10 +732,10 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
void Session::TCP_OnClosed(TCPSocket* s)
{
- Lock::Locker locker(&ConnectionsLock);
+ Lock::Locker locker(&ConnectionsLock);
// If found in the full connection list,
- int connIndex;
+ int connIndex = 0;
Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, s, &connIndex);
if (conn)
{
@@ -631,6 +745,10 @@ void Session::TCP_OnClosed(TCPSocket* s)
if (findConnectionBySocket(FullConnections, s, &connIndex))
{
FullConnections.RemoveAtUnordered(connIndex);
+ if (FullConnections.GetSizeI() < 1)
+ {
+ HaveFullConnections.store(false, std::memory_order_relaxed);
+ }
}
// Generate an appropriate event for the current state
@@ -659,23 +777,23 @@ void Session::TCP_OnClosed(TCPSocket* s)
void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock)
{
OVR_UNUSED(pListener);
- OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP);
+ Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false);
+ OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP);
- Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false);
// If pSockAddr is not localhost, then close newSock
- if (pSockAddr->IsLocalhost()==false)
+ if (!pSockAddr->IsLocalhost())
{
newSocket->Close();
return;
}
- if (newSocket)
- {
- Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP);
- Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr();
- c->pSocket = newSocket;
- c->Address = *pSockAddr;
+ if (newSocket)
+ {
+ Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP);
+ Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr();
+ c->pSocket = newSocket;
+ c->Address = *pSockAddr;
c->State = Server_ConnectedWait;
{
@@ -684,7 +802,7 @@ void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHand
}
// Server does not send the first packet. It waits for the client to send its version
- }
+ }
}
void Session::TCP_OnConnected(TCPSocket *s)
@@ -697,13 +815,18 @@ void Session::TCP_OnConnected(TCPSocket *s)
{
OVR_ASSERT(conn->State == Client_Connecting);
+ // Just update state but do not generate any notifications yet
+ conn->SetState(Client_ConnectedWait);
+
// Send hello message
BitStream bsOut;
- RPC_C2S_Hello::Generate(&bsOut);
- conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
+ RPC_C2S_Hello::ClientGenerate(&bsOut);
- // Just update state but do not generate any notifications yet
- conn->State = Client_ConnectedWait;
+ SendParameters sp;
+ sp.Bytes = bsOut.GetNumberOfBytesUsed();
+ sp.pData = bsOut.GetData();
+ sp.pConnection = conn;
+ Send(&sp);
}
}