diff options
Diffstat (limited to 'LibOVR/Src/Net/OVR_Session.cpp')
-rw-r--r-- | LibOVR/Src/Net/OVR_Session.cpp | 455 |
1 files changed, 289 insertions, 166 deletions
diff --git a/LibOVR/Src/Net/OVR_Session.cpp b/LibOVR/Src/Net/OVR_Session.cpp index 508f0c9..4049c6c 100644 --- a/LibOVR/Src/Net/OVR_Session.cpp +++ b/LibOVR/Src/Net/OVR_Session.cpp @@ -26,36 +26,95 @@ limitations under the License. #include "OVR_Session.h" #include "OVR_PacketizedTCPSocket.h" -#include "../Kernel/OVR_Log.h" -#include "../Service/Service_NetSessionCommon.h" +#include "Kernel/OVR_Log.h" +#include "Service/Service_NetSessionCommon.h" namespace OVR { namespace Net { +// The SDK version requested by the user. +SDKVersion RuntimeSDKVersion; + + //----------------------------------------------------------------------------- // Protocol -static const char* OfficialHelloString = "OculusVR_Hello"; +static const char* OfficialHelloString = "OculusVR_Hello"; static const char* OfficialAuthorizedString = "OculusVR_Authorized"; -void RPC_C2S_Hello::Generate(Net::BitStream* bs) +bool RPC_C2S_Hello::Serialize(bool writeToBitstream, Net::BitStream* bs) +{ + bs->Serialize(writeToBitstream, HelloString); + bs->Serialize(writeToBitstream, MajorVersion); + bs->Serialize(writeToBitstream, MinorVersion); + if (!bs->Serialize(writeToBitstream, PatchVersion)) + return false; + + // If an older client is connecting to us, + if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121) + { + // The following was version code was added to RPC version 1.2 + // without bumping it up to 1.3 and introducing an incompatibility. + // We can do this because an older server will not read this additional data. + return true; + } + + bs->Serialize(writeToBitstream, CodeVersion.ProductVersion); + bs->Serialize(writeToBitstream, CodeVersion.MajorVersion); + bs->Serialize(writeToBitstream, CodeVersion.MinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.PatchVersion); + bs->Serialize(writeToBitstream, CodeVersion.BuildNumber); + return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion); +} + +void RPC_C2S_Hello::ClientGenerate(Net::BitStream* bs) { RPC_C2S_Hello hello; - hello.HelloString = OfficialHelloString; + hello.HelloString = OfficialHelloString; hello.MajorVersion = RPCVersion_Major; hello.MinorVersion = RPCVersion_Minor; hello.PatchVersion = RPCVersion_Patch; - hello.Serialize(bs); + OVR_ASSERT(OVR::Net::RuntimeSDKVersion.ProductVersion != UINT16_MAX); + hello.CodeVersion = OVR::Net::RuntimeSDKVersion; // This should have been set to a value earlier in the first steps of ovr initialization. + hello.Serialize(true, bs); } -bool RPC_C2S_Hello::Validate() +bool RPC_C2S_Hello::ServerValidate() { + // Server checks the protocol version return MajorVersion == RPCVersion_Major && MinorVersion <= RPCVersion_Minor && HelloString.CompareNoCase(OfficialHelloString) == 0; } -void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString) +bool RPC_S2C_Authorization::Serialize(bool writeToBitstream, Net::BitStream* bs) +{ + bs->Serialize(writeToBitstream, AuthString); + bs->Serialize(writeToBitstream, MajorVersion); + bs->Serialize(writeToBitstream, MinorVersion); + if (!bs->Serialize(writeToBitstream, PatchVersion)) + return false; + + // If an older client is connecting to us, + if (!writeToBitstream && (MajorVersion * 100) + (MinorVersion * 10) + PatchVersion < 121) + { + // The following was version code was added to RPC version 1.2 + // without bumping it up to 1.3 and introducing an incompatibility. + // We can do this because an older server will not read this additional data. + return true; + } + + bs->Serialize(writeToBitstream, CodeVersion.ProductVersion); + bs->Serialize(writeToBitstream, CodeVersion.MajorVersion); + bs->Serialize(writeToBitstream, CodeVersion.MinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.RequestedMinorVersion); + bs->Serialize(writeToBitstream, CodeVersion.PatchVersion); + bs->Serialize(writeToBitstream, CodeVersion.BuildNumber); + return bs->Serialize(writeToBitstream, CodeVersion.FeatureVersion); +} + +void RPC_S2C_Authorization::ServerGenerate(Net::BitStream* bs, String errorString) { RPC_S2C_Authorization auth; if (errorString.IsEmpty()) @@ -69,16 +128,33 @@ void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString) auth.MajorVersion = RPCVersion_Major; auth.MinorVersion = RPCVersion_Minor; auth.PatchVersion = RPCVersion_Patch; - auth.Serialize(bs); + // Leave CurrentSDKVersion as it is. + auth.Serialize(true, bs); } -bool RPC_S2C_Authorization::Validate() +bool RPC_S2C_Authorization::ClientValidate() { return AuthString.CompareNoCase(OfficialAuthorizedString) == 0; } //----------------------------------------------------------------------------- +// SingleProcess + +static bool SingleProcess = false; + +void Session::SetSingleProcess(bool enable) +{ + SingleProcess = enable; +} + +bool Session::IsSingleProcess() +{ + return SingleProcess; +} + + +//----------------------------------------------------------------------------- // Session void Session::Shutdown() @@ -111,29 +187,25 @@ void Session::Shutdown() SessionResult Session::Listen(ListenerDescription* pListenerDescription) { - if (pListenerDescription->Transport == TransportType_PacketizedTCP) - { - BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription; - TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr(); + if (pListenerDescription->Transport == TransportType_PacketizedTCP) + { + BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription; + TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr(); if (tcpSocket->Listen() < 0) { return SessionResult_ListenFailure; } - Lock::Locker locker(&SocketListenersLock); + Lock::Locker locker(&SocketListenersLock); SocketListeners.PushBack(tcpSocket); - } - else if (pListenerDescription->Transport == TransportType_Loopback) - { - HasLoopbackListener = true; - } + } else { OVR_ASSERT(false); } - return SessionResult_OK; + return SessionResult_OK; } SessionResult Session::Connect(ConnectParameters *cp) @@ -153,6 +225,28 @@ SessionResult Session::Connect(ConnectParameters *cp) return SessionResult_AlreadyConnected; } + // If we are already connected, don't create a duplicate connection + if (FullConnections.GetSizeI() > 0) + { + return SessionResult_AlreadyConnected; + } + + // If we are already connecting, don't create a duplicate connection + const int count = AllConnections.GetSizeI(); + for (int i = 0; i < count; ++i) + { + Connection* arrayItem = AllConnections[i].GetPtr(); + + OVR_ASSERT(arrayItem); + if (arrayItem) { + if (arrayItem->State == Client_ConnectedWait + || arrayItem->State == Client_Connecting) + { + return SessionResult_ConnectInProgress; + } + } + } + TCPSocketBase* tcpSock = (TCPSocketBase*)cp2->BoundSocketToConnectWith.GetPtr(); int ret = tcpSock->Connect(&cp2->RemoteAddress); @@ -174,7 +268,6 @@ SessionResult Session::Connect(ConnectParameters *cp) c->SetState(Client_Connecting); AllConnections.PushBack(c); - } if (cp2->Blocking) @@ -182,11 +275,12 @@ SessionResult Session::Connect(ConnectParameters *cp) c->WaitOnConnecting(); } - if (c->State == State_Connected) + EConnectionState state = c->State; + if (state == State_Connected) { return SessionResult_OK; } - else if (c->State == Client_Connecting) + else if (state == Client_Connecting) { return SessionResult_ConnectInProgress; } @@ -195,49 +289,33 @@ SessionResult Session::Connect(ConnectParameters *cp) return SessionResult_ConnectFailure; } } - else if (cp->Transport == TransportType_Loopback) - { - if (HasLoopbackListener) - { - Ptr<Connection> c = AllocConnection(cp->Transport); - if (!c) - { - return SessionResult_ConnectFailure; - } - - c->Transport = cp->Transport; - c->SetState(State_Connected); - - { - Lock::Locker locker(&ConnectionsLock); - AllConnections.PushBack(c); - } - - invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, c); - } - else - { - OVR_ASSERT(false); - } - } else { OVR_ASSERT(false); } - return SessionResult_OK; + return SessionResult_OK; } +static Session* SingleProcessServer = nullptr; + SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp) { - Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket(); + if (Session::IsSingleProcess()) + { + // Do not actually listen on a socket. + SingleProcessServer = this; + return SessionResult_OK; + } + + Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket(); if (listenSocket->Bind(bbp) == INVALID_SOCKET) { return SessionResult_BindFailure; } - BerkleyListenerDescription bld; - bld.BoundSocketToListenWith = listenSocket.GetPtr(); + BerkleyListenerDescription bld; + bld.BoundSocketToListenWith = listenSocket.GetPtr(); bld.Transport = TransportType_PacketizedTCP; return Listen(&bld); @@ -245,16 +323,46 @@ SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp) SessionResult Session::ConnectPTCP(OVR::Net::BerkleyBindParameters* bbp, SockAddr* remoteAddress, bool blocking) { + if (Session::IsSingleProcess()) + { + OVR_ASSERT(SingleProcessServer); // ListenPTCP() must be called before ConnectPTCP() + + SingleProcessServer->SingleTargetSession = this; + SingleTargetSession = SingleProcessServer; + + Ptr<PacketizedTCPSocket> s = *new PacketizedTCPSocket; + SockAddr sa; + sa.Set("::1", 10101, SOCK_STREAM); + + Ptr<Connection> newConnection = AllocConnection(TransportType_PacketizedTCP); + if (!newConnection) + { + return SessionResult_ConnectFailure; + } + + PacketizedTCPConnection* c = (PacketizedTCPConnection*)newConnection.GetPtr(); + c->pSocket = s; + c->Address = &sa; + c->Transport = TransportType_PacketizedTCP; + c->SetState(Client_Connecting); + AllConnections.PushBack(c); + + SingleTargetSession->TCP_OnAccept(s, &sa, INVALID_SOCKET); + TCP_OnConnected(s); + + return SessionResult_OK; + } + ConnectParametersBerkleySocket cp(NULL, remoteAddress, blocking, TransportType_PacketizedTCP); Ptr<PacketizedTCPSocket> connectSocket = *new PacketizedTCPSocket(); - cp.BoundSocketToConnectWith = connectSocket.GetPtr(); + cp.BoundSocketToConnectWith = connectSocket.GetPtr(); if (connectSocket->Bind(bbp) == INVALID_SOCKET) { return SessionResult_BindFailure; } - return Connect(&cp); + return Connect(&cp); } Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address) @@ -280,47 +388,26 @@ Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address int Session::Send(SendParameters *payload) { - if (payload->pConnection->Transport == TransportType_Loopback) - { - Lock::Locker locker(&SessionListenersLock); - - const int count = SessionListeners.GetSizeI(); - for (int i = 0; i < count; ++i) - { - SessionListener* sl = SessionListeners[i]; - - // FIXME: This looks like it needs to be reviewed at some point.. - ReceivePayload rp; - rp.Bytes = payload->Bytes; - rp.pConnection = payload->pConnection; - rp.pData = (uint8_t*)payload->pData; // FIXME - ListenerReceiveResult lrr = LRR_CONTINUE; - sl->OnReceive(&rp, &lrr); - if (lrr == LRR_RETURN) - { - return payload->Bytes; - } - else if (lrr == LRR_BREAK) - { - break; - } - } - - return payload->Bytes; - } - else if (payload->pConnection->Transport == TransportType_PacketizedTCP) - { - PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr(); - - return conn->pSocket->Send(payload->pData, payload->Bytes); - } - else + if (payload->pConnection->Transport == TransportType_PacketizedTCP) { - OVR_ASSERT(false); + if (Session::IsSingleProcess()) + { + OVR_ASSERT(SingleTargetSession->AllConnections.GetSizeI() > 0); + PacketizedTCPConnection* conn = (PacketizedTCPConnection*)SingleTargetSession->AllConnections[0].GetPtr(); + SingleTargetSession->TCP_OnRecv(conn->pSocket, (uint8_t*)payload->pData, payload->Bytes); + return payload->Bytes; + } + else + { + PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr(); + return conn->pSocket->Send(payload->pData, payload->Bytes); + } } + OVR_ASSERT(false); // Should not reach here return 0; } + void Session::Broadcast(BroadcastParameters *payload) { SendParameters sp; @@ -338,21 +425,29 @@ void Session::Broadcast(BroadcastParameters *payload) } } } -// DO NOT CALL Poll() FROM MULTIPLE THREADS due to allBlockingTcpSockets being a member + +// DO NOT CALL Poll() FROM MULTIPLE THREADS due to AllBlockingTcpSockets being a member void Session::Poll(bool listeners) { - allBlockingTcpSockets.Clear(); + if (Net::Session::IsSingleProcess()) + { + // Spend a lot of time sleeping in single process mode + Thread::MSleep(100); + return; + } - if (listeners) - { - Lock::Locker locker(&SocketListenersLock); + AllBlockingTcpSockets.Clear(); + + if (listeners) + { + Lock::Locker locker(&SocketListenersLock); const int listenerCount = SocketListeners.GetSizeI(); for (int i = 0; i < listenerCount; ++i) - { - allBlockingTcpSockets.PushBack(SocketListeners[i]); - } - } + { + AllBlockingTcpSockets.PushBack(SocketListeners[i]); + } + } { Lock::Locker locker(&ConnectionsLock); @@ -366,7 +461,7 @@ void Session::Poll(bool listeners) { PacketizedTCPConnection* ptcp = (PacketizedTCPConnection*)arrayItem; - allBlockingTcpSockets.PushBack(ptcp->pSocket); + AllBlockingTcpSockets.PushBack(ptcp->pSocket); } else { @@ -375,15 +470,15 @@ void Session::Poll(bool listeners) } } - const int count = allBlockingTcpSockets.GetSizeI(); - if (count > 0) - { + const int count = AllBlockingTcpSockets.GetSizeI(); + if (count > 0) + { TCPSocketPollState state; // Add all the sockets for polling, for (int i = 0; i < count; ++i) { - Net::TCPSocket* sock = allBlockingTcpSockets[i].GetPtr(); + Net::TCPSocket* sock = AllBlockingTcpSockets[i].GetPtr(); // If socket handle is invalid, if (sock->GetSocketHandle() == INVALID_SOCKET) @@ -399,20 +494,20 @@ void Session::Poll(bool listeners) } // If polling returns with an event, - if (state.Poll(allBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), allBlockingTcpSockets[0]->GetBlockingTimeoutSec())) + if (state.Poll(AllBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), AllBlockingTcpSockets[0]->GetBlockingTimeoutSec())) { // Handle any events for each socket for (int i = 0; i < count; ++i) { - state.HandleEvent(allBlockingTcpSockets[i], this); + state.HandleEvent(AllBlockingTcpSockets[i], this); } } - } + } } void Session::AddSessionListener(SessionListener* se) { - Lock::Locker locker(&SessionListenersLock); + Lock::Locker locker(&SessionListenersLock); const int count = SessionListeners.GetSizeI(); for (int i = 0; i < count; ++i) @@ -425,36 +520,35 @@ void Session::AddSessionListener(SessionListener* se) } SessionListeners.PushBack(se); - se->OnAddedToSession(this); + se->OnAddedToSession(this); } void Session::RemoveSessionListener(SessionListener* se) { - Lock::Locker locker(&SessionListenersLock); + Lock::Locker locker(&SessionListenersLock); const int count = SessionListeners.GetSizeI(); - for (int i = 0; i < count; ++i) - { + for (int i = 0; i < count; ++i) + { if (SessionListeners[i] == se) - { + { se->OnRemovedFromSession(this); SessionListeners.RemoveAtUnordered(i); break; - } - } + } + } } -SInt32 Session::GetActiveSocketsCount() + +int Session::GetActiveSocketsCount() { - Lock::Locker locker1(&SocketListenersLock); - Lock::Locker locker2(&ConnectionsLock); - return SocketListeners.GetSize() + AllConnections.GetSize()>0; + return SocketListeners.GetSizeI() + AllConnections.GetSizeI(); } + Ptr<Connection> Session::AllocConnection(TransportType transport) { switch (transport) { - case TransportType_Loopback: return *new Connection(); case TransportType_TCP: return *new TCPConnection(); case TransportType_PacketizedTCP: return *new PacketizedTCPConnection(); default: @@ -511,14 +605,14 @@ int Session::invokeSessionListeners(ReceivePayload* rp) void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) { - // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock(). - // Lock::Locker locker(&ConnectionsLock); + // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock(). + // Lock::Locker locker(&ConnectionsLock); // Look for the connection in the full connection list first int connIndex; - ConnectionsLock.DoLock(); + ConnectionsLock.DoLock(); Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, pSocket, &connIndex); - ConnectionsLock.Unlock(); + ConnectionsLock.Unlock(); if (conn) { if (conn->State == State_Connected) @@ -537,8 +631,8 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) BitStream bsIn((char*)pData, bytesRead, false); RPC_S2C_Authorization auth; - if (!auth.Deserialize(&bsIn) || - !auth.Validate()) + if (!auth.Serialize(false, &bsIn) || + !auth.ClientValidate()) { LogError("{ERR-001} [Session] REJECTED: OVRService did not authorize us: %s", auth.AuthString.ToCStr()); @@ -551,16 +645,18 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) conn->RemoteMajorVersion = auth.MajorVersion; conn->RemoteMinorVersion = auth.MinorVersion; conn->RemotePatchVersion = auth.PatchVersion; + conn->RemoteCodeVersion = auth.CodeVersion; // Mark as connected conn->SetState(State_Connected); - ConnectionsLock.DoLock(); - int connIndex2; - if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) - { - FullConnections.PushBack(conn); - } - ConnectionsLock.Unlock(); + ConnectionsLock.DoLock(); + int connIndex2; + if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) + { + FullConnections.PushBack(conn); + HaveFullConnections.store(true, std::memory_order_relaxed); + } + ConnectionsLock.Unlock(); invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, conn); } } @@ -570,41 +666,59 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) BitStream bsIn((char*)pData, bytesRead, false); RPC_C2S_Hello hello; - if (!hello.Deserialize(&bsIn) || - !hello.Validate()) + if (!hello.Serialize(false, &bsIn) || + !hello.ServerValidate()) { - LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d (my version=%d.%d.%d)", - hello.MajorVersion, hello.MinorVersion, hello.PatchVersion, - RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch); + LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d, feature version %d (my version=%d.%d.%d, feature version %d)", + hello.MajorVersion, hello.MinorVersion, hello.PatchVersion, hello.CodeVersion.FeatureVersion, + RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch, OVR_FEATURE_VERSION); conn->SetState(State_Zombie); // Send auth response BitStream bsOut; - RPC_S2C_Authorization::Generate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date."); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_S2C_Authorization::ServerGenerate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date."); + + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); } else { + if (hello.CodeVersion.FeatureVersion != OVR_FEATURE_VERSION) + { + LogError("[Session] WARNING: Rift application is using a different feature version than the server (server version = %d, app version = %d)", + OVR_FEATURE_VERSION, hello.CodeVersion.FeatureVersion); + } + // Read remote version conn->RemoteMajorVersion = hello.MajorVersion; conn->RemoteMinorVersion = hello.MinorVersion; conn->RemotePatchVersion = hello.PatchVersion; + conn->RemoteCodeVersion = hello.CodeVersion; // Send auth response BitStream bsOut; - RPC_S2C_Authorization::Generate(&bsOut); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_S2C_Authorization::ServerGenerate(&bsOut); + + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); // Mark as connected conn->SetState(State_Connected); - ConnectionsLock.DoLock(); - int connIndex2; - if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) - { - FullConnections.PushBack(conn); - } - ConnectionsLock.Unlock(); + ConnectionsLock.DoLock(); + int connIndex2; + if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL) + { + FullConnections.PushBack(conn); + HaveFullConnections.store(true, std::memory_order_relaxed); + } + ConnectionsLock.Unlock(); invokeSessionEvent(&SessionListener::OnNewIncomingConnection, conn); } @@ -618,10 +732,10 @@ void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead) void Session::TCP_OnClosed(TCPSocket* s) { - Lock::Locker locker(&ConnectionsLock); + Lock::Locker locker(&ConnectionsLock); // If found in the full connection list, - int connIndex; + int connIndex = 0; Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, s, &connIndex); if (conn) { @@ -631,6 +745,10 @@ void Session::TCP_OnClosed(TCPSocket* s) if (findConnectionBySocket(FullConnections, s, &connIndex)) { FullConnections.RemoveAtUnordered(connIndex); + if (FullConnections.GetSizeI() < 1) + { + HaveFullConnections.store(false, std::memory_order_relaxed); + } } // Generate an appropriate event for the current state @@ -659,23 +777,23 @@ void Session::TCP_OnClosed(TCPSocket* s) void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock) { OVR_UNUSED(pListener); - OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP); + Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false); + OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP); - Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false); // If pSockAddr is not localhost, then close newSock - if (pSockAddr->IsLocalhost()==false) + if (!pSockAddr->IsLocalhost()) { newSocket->Close(); return; } - if (newSocket) - { - Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP); - Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr(); - c->pSocket = newSocket; - c->Address = *pSockAddr; + if (newSocket) + { + Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP); + Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr(); + c->pSocket = newSocket; + c->Address = *pSockAddr; c->State = Server_ConnectedWait; { @@ -684,7 +802,7 @@ void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHand } // Server does not send the first packet. It waits for the client to send its version - } + } } void Session::TCP_OnConnected(TCPSocket *s) @@ -697,13 +815,18 @@ void Session::TCP_OnConnected(TCPSocket *s) { OVR_ASSERT(conn->State == Client_Connecting); + // Just update state but do not generate any notifications yet + conn->SetState(Client_ConnectedWait); + // Send hello message BitStream bsOut; - RPC_C2S_Hello::Generate(&bsOut); - conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed()); + RPC_C2S_Hello::ClientGenerate(&bsOut); - // Just update state but do not generate any notifications yet - conn->State = Client_ConnectedWait; + SendParameters sp; + sp.Bytes = bsOut.GetNumberOfBytesUsed(); + sp.pData = bsOut.GetData(); + sp.pConnection = conn; + Send(&sp); } } |