diff --git a/src/Network/Connection.cpp b/src/Network/Connection.cpp --- a/src/Network/Connection.cpp +++ b/src/Network/Connection.cpp @@ -1,752 +1,756 @@ #include #include #include #include #include #include #include #include #include #include #include #include #include namespace Echo { #ifdef ECHO_BIG_ENDIAN bool Connection::mPlatformBigEndian=true; #else bool Connection::mPlatformBigEndian=false; #endif const u32 Connection::mUnreasonableDataSize=0x00A00000; //10MiB for a packet HUGE really. Connection::Connection(NetworkManager& manager) : mNetworkManager(manager) { mState=States::DISCONNECTED; mQueuedData=0; mCanSend=false; mAutoAttemptReconnect = false; mQueueDataPacketsIfNotConnected = true; mDiscardDataPacketQueueOnDisconnect = false; SetAutoAttemptReconnectTime(Seconds(5.)); mReconnectTimer.AddTimeoutFunction(bind(&Connection::Connect,this),"Connection::Connect"); mOwner=0; mHeaderSent=false; // This is updated when a PacketTypes::REMOTE_DETAILS is received, which is queued to send first when a connection is established. mIsRemoteBigEndian = false; mTempBuffer = nullptr; mTempBufferSize = 0; DataPacketHeader header; mHeaderPacket=shared_ptr(new DataPacket(0,header.GetHeaderDataSizeInBytes())); RegisterPacketCallback(PacketTypes::LABELLED_PACKET,bind(&Connection::ProcessLabelledPacket,this,_1,_2)); RegisterPacketCallback(PacketTypes::REMOTE_DETAILS, bind(&Connection::OnRemoteDetails,this,_1,_2)); SetTempBufferSize(manager.GetNewConnectionBufferSize()); } Connection::~Connection() { } void Connection::SetTempBufferSize(Size sizeInBytes) { ECHO_ASSERT_NOT_NULL(sizeInBytes); mTempBuffer.reset(new u8[sizeInBytes]); mTempBufferSize = sizeInBytes; } void Connection::SetOwner(ConnectionOwner* receiver) { mOwner=receiver; } void Connection::SetState(State state) { mState=state; //Disable sending if we aren't connected. if(mState!=States::CONNECTED) { mCanSend = false; }else { SendHostDetails(); Write(true); } } shared_ptr Connection::NewDataPacket() { return mNetworkManager.NewDataPacket(); } shared_ptr Connection::NewDataPacket(u32 packetTypeID, u32 size) { shared_ptr packet = mNetworkManager.NewDataPacket(); packet->Configure(packetTypeID,size); return packet; } shared_ptr Connection::NewDataPacket(std::string label, u32 size) { shared_ptr packet = mNetworkManager.NewDataPacket(); packet->Configure(label,size); return packet; } void Connection::_notifyOwner(shared_ptr packet) { if(mOwner) { mOwner->ReceivedPacket(shared_from_this(), packet); } } void Connection::_dropped() { SetState(States::DISCONNECTED); if(mOwner) { mOwner->ConnectionDrop(shared_from_this()); } shared_ptr connection = shared_from_this(); //Make a copy of the callbacks because they might want to remove from mDisconnectCallbacks. std::map< std::string, std::vector > disconnectCallbacks = mDisconnectCallbacks; typedef std::pair< const std::string, std::vector > IdentifierDisconnectPair; BOOST_FOREACH(IdentifierDisconnectPair& idp, disconnectCallbacks) { BOOST_FOREACH(DisconnectCallback& callback, idp.second) { callback(connection); } } if(mDiscardDataPacketQueueOnDisconnect) { ScopedLock lock(mQueuedPacketsMutex); mQueuedPackets.clear(); }else { ScopedLock lock(mQueuedPacketsMutex); // Make sure the head packet is reset for resending if(!mQueuedPackets.empty()) { mQueuedPackets.front().first->mReceived = mQueuedPackets.front().first->mSize; } } //Attempt to reconnect if(mAutoAttemptReconnect) { if(!mNetworkManager.HasTask(mReconnectTimer)) { mNetworkManager.AddTask(mReconnectTimer); } mReconnectTimer.Reset(); } } void Connection::_established() { if(mOwner) { mOwner->ConnectionEstablished(shared_from_this()); } shared_ptr connection = shared_from_this(); + + //Make a copy of the callbacks because they might want to remove from mDisconnectCallbacks. + std::map< std::string, std::vector > connectCallbacks = mConnectCallbacks; + typedef std::pair< const std::string, std::vector > IdentifierConnectPair; - BOOST_FOREACH(IdentifierConnectPair& icp, mConnectCallbacks) + BOOST_FOREACH(IdentifierConnectPair& icp, connectCallbacks) { BOOST_FOREACH(ConnectCallback& callback, icp.second) { callback(connection); } } } bool Connection::Connect() { if(!IsConnected()) { mReconnectTimer.Pause(); mNetworkManager.RemoveTask(mReconnectTimer); mManualDisconnect=false; //Reset this to allow auto connect to work. return _Connect(); } return false; } bool Connection::Disconnect() { if(IsConnected()) { mManualDisconnect=true; //This was a deliberate disconnect, this stops allow auto connecting. return _Disconnect(); } return false; } int Connection::Receive() { assert(mHeaderPacket); s32 bytesReceived=0; //If there isn't a current packet we need to extract header data bytesReceived=_recv(mTempBuffer.get(),mTempBufferSize,0); if(!_HandleRecvError(bytesReceived)) { if(bytesReceived<0) { return bytesReceived; } //Revert to the default error code. return -1; } mNetworkManager.ReportReceivedData(bytesReceived); //ECHO_LOG_DEBUG("0x" << std::hex << this << std::dec << "Recv: " << mTempBuffer << ":" << bytesReceived); u8* bufferStart=mTempBuffer.get(); while(bytesReceived>0) { //ECHO_LOG_DEBUG("bufferStart: " << bufferStart << ":" << bytesReceived); u32 headerBytes=0; //ECHO_LOG_DEBUG(bytesReceived << " bytes"); if(!(mHeaderPacket->HasReceivedAllData())) //We have a valid header { //ECHO_LOG_DEBUG("I'll try and make a header for you"); //Add the received data to the header packet u32 remainingBytes=mHeaderPacket->GetRemainingDataSize(); if(remainingBytes>(u32)bytesReceived) headerBytes=bytesReceived; else headerBytes=remainingBytes; //ECHO_LOG_DEBUG("Appending " << headerBytes << " bytes"); mHeaderPacket->AppendData(bufferStart,headerBytes); if(mHeaderPacket->HasReceivedAllData()) { //ECHO_LOG_DEBUG("Found enough data for a full header..."; DataPacketHeader header; if(!header.BuildFromPacketData(*mHeaderPacket)) { ECHO_LOG_ERROR("Failed to build packet header from incoming data"); mHeaderPacket->mReceived=0; Disconnect(); return GetNumReceviedPackets(); } //Determine if the packet size is reasonable. //If it is then get rid of that incoming data if(header.GetDataLength()<=mUnreasonableDataSize) { mCurrentPacket=NewDataPacket(); mCurrentPacket->Configure(header); }else { ECHO_LOG_ERROR("DataPacket length was too large: " << header.GetDataLength() << " when maximum is " << mUnreasonableDataSize); mHeaderPacket->mReceived=0; Disconnect(); return GetNumReceviedPackets(); } } bytesReceived-=headerBytes; //For next part bufferStart+=headerBytes; } if(mCurrentPacket) { ///ECHO_LOG_DEBUG("Constructing Packet..."); if(bytesReceived>0) { //ECHO_LOG_DEBUG("Appending " << headerBytes << " bytes"); u32 bytesAppended=0; bytesAppended=mCurrentPacket->AppendData(bufferStart, bytesReceived); u32 remaining=bytesReceived-bytesAppended; if(remaining>0) { bufferStart+=bytesAppended; //There is another packet waiting } bytesReceived-=bytesAppended; } if(mCurrentPacket->HasReceivedAllData()) { //ECHO_LOG_DEBUG("Packet Constructed!"); //ECHO_LOG_DEBUG("Notifying Owner..."); ProcessReceivedPacket(mCurrentPacket); mCurrentPacket.reset(); mHeaderPacket->mReceived=0; } } } return GetNumReceviedPackets(); } void Connection::Write(bool reenable) { // reenable is true when the NetworkSystem changes the connections state appropriately // or calls Write(true) explicitly to update the write state. if(!reenable) { // If we can't send or acquire a lock, we need to bail. The latter means something else is writing. if(!mCanSend || !mQueuedPacketsMutex.AttemptLock()) { return; } }else { mQueuedPacketsMutex.Lock(); } BOOST_SCOPE_EXIT(&mQueuedPacketsMutex) { mQueuedPacketsMutex.Unlock(); } BOOST_SCOPE_EXIT_END mCanSend=false; //ECHO_LOG_DEBUG(mQueuedPackets.size()); //if(!mSendingData) s32 totalDataSent=0; while(!(mQueuedPackets.empty())) { shared_ptr packet=mQueuedPackets.front().first; if(packet) { //ECHO_LOG_DEBUG(packet->GetData()); //ECHO_LOG_DEBUG("Connection::Write() - mCanSend==true"); //ECHO_LOG_DEBUG("Connection::Write():Packet:0x" << std::hex << packet->GetPacketTypeID() << std::dec); if(!mHeaderSent) { DataPacketHeader header; header.BuildForPacket(*packet); //ECHO_LOG_DEBUG("0x" << std::hex << this << std::dec << ": Send Header"); s32 bytesSent=_send((const u8*)header.GetHeaderData(),header.GetHeaderDataSizeInBytes(),0); //ECHO_LOG_DEBUG("0x" << std::hex << socketNum << std::dec); if(!_HandleWriteError(bytesSent)) { //ECHO_LOG_DEBUG("mCanSend=false; for header"); return; } totalDataSent+=bytesSent; mNetworkManager.ReportSentData(bytesSent); //ECHO_LOG_DEBUG("Sent: Header"); mHeaderSent=true; //See if this is just a control packet - If the data packet is not the same as the header if(packet->SendHeaderOnly()) { packet.reset(); mHeaderSent=false; if(mQueuedPackets.front().second) { Disconnect(); mQueuedPackets.pop_front(); return; } mQueuedPackets.pop_front(); continue; } } u8* data=&(packet->mData[packet->mSize-packet->mReceived]); //ECHO_LOG_DEBUG("0x" << std::hex << this << std::dec << ": Send Data"); s32 bytesSent=_send((const u8*)data,packet->mReceived,0); if(_HandleWriteError(bytesSent)) { packet->mReceived-=bytesSent; //ECHO_LOG_DEBUG("Sent: " << packet->mSize-packet->mReceived << " of " << packet->mSize << " bytes"); if(packet->mReceived==0) //All our data was sent { packet.reset(); mHeaderSent=false; if(mQueuedPackets.front().second) { Disconnect(); mQueuedPackets.pop_front(); return; } mQueuedPackets.pop_front(); } totalDataSent+=bytesSent; mNetworkManager.ReportSentData(bytesSent); }else { //ECHO_LOG_DEBUG("mCanSend=false; for packet"); return; } } } mCanSend=true; } void Connection::SendDataPacket(shared_ptr packet, PacketCallback responseCallback, bool prioritise, bool disconnectAfterSend, bool isResponsePacket) { // In the CONNECTING or CONNECTED states we should treat it as connected. // If we're not connected, should we silently discard? if(mState==States::DISCONNECTED && !mQueueDataPacketsIfNotConnected) { return; } //ECHO_LOG_DEBUG("SendDataPacket: " << mQueuedPackets.size()); if(packet) { if(packet->mReceived!=packet->mSize) { packet->mReceived=packet->mSize; } if(!isResponsePacket) { // All packets are sent through this method. packet->mPacketID = mNextPacketID++; } mQueuedPacketsMutex.Lock(); if(prioritise) { mQueuedPackets.push_front( std::pair< shared_ptr, bool >(packet, disconnectAfterSend) ); }else { mQueuedPackets.push_back( std::pair< shared_ptr, bool >(packet, disconnectAfterSend) ); } if(responseCallback) { mResponseCallbacks[packet->mPacketID] = responseCallback; } mQueuedPacketsMutex.Unlock(); //ECHO_LOG_DEBUG("SendDataPacket2: " << mQueuedPackets.size()); Write(false); } } void Connection::SendData(const u8* data, u32 dataSize, u32 packetTypeID, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(packetTypeID,dataSize); //Copy the data packet->AppendData(data,dataSize); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendMessage(const std::string& message, u32 packetTypeID, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(message); packet->SetPacketTypeID(packetTypeID); packet->AppendString(message); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendControlPacket(u32 packetTypeID, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(packetTypeID,0); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendLabelledPacket(const std::string& label, const u8* data, u32 dataSize, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(PacketTypes::LABELLED_PACKET,label.length()+DataPacket::NUMBYTES_FOR_STRING_HEADER+dataSize); packet->AppendString(label); packet->AppendData(data,dataSize); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendLabelledPacket(const std::string& label, const std::string& content, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(label,content); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendLabelledPacket(const std::string& label, const std::vector& content, PacketCallback responseCallback, bool prioritise) { shared_ptr packet=NewDataPacket(); packet->Configure(label,content); SendDataPacket(packet,responseCallback,prioritise); } void Connection::SendDataPacketResponse(shared_ptr packetRespondingTo, shared_ptr packet, bool prioritise, bool disconnectAfterSend) { packet->SetPacketID(packetRespondingTo->GetPacketID()); if(packet->IsLabelledPacket()) { packet->SetPacketTypeID(PacketTypes::LABELLED_RESPONSE_PACKET); }else { packet->SetPacketTypeID(PacketTypes::RESPONSE_PACKET); } SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendDataResponse(shared_ptr packetRespondingTo, const u8* data, u32 dataSize, bool prioritise, bool disconnectAfterSend) { shared_ptr packet=NewDataPacket(); packet->Configure(PacketTypes::RESPONSE_PACKET,dataSize); //Copy the data packet->AppendData(data,dataSize); packet->SetPacketID(packetRespondingTo->GetPacketID()); SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendMessageResponse(shared_ptr packetRespondingTo, const std::string& message, bool prioritise, bool disconnectAfterSend) { shared_ptr packet=NewDataPacket(); packet->Configure(message); packet->SetPacketID(packetRespondingTo->GetPacketID()); packet->SetPacketTypeID(PacketTypes::RESPONSE_PACKET); SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendLabelledPacketResponse(shared_ptr packetRespondingTo, const std::string& label, const u8* data, u32 dataSize, bool prioritise, bool disconnectAfterSend) { shared_ptr packet=NewDataPacket(); packet->Configure(PacketTypes::LABELLED_RESPONSE_PACKET,label.length()+DataPacket::NUMBYTES_FOR_STRING_HEADER+dataSize); packet->AppendString(label); packet->AppendData(data,dataSize); packet->SetPacketID(packetRespondingTo->GetPacketID()); SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendLabelledPacketResponse(shared_ptr packetRespondingTo, const std::string& label, const std::string& content, bool prioritise, bool disconnectAfterSend) { shared_ptr packet=NewDataPacket(); packet->Configure(label,content); packet->SetPacketID(packetRespondingTo->GetPacketID()); packet->SetPacketTypeID(PacketTypes::LABELLED_RESPONSE_PACKET); SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendLabelledPacketResponse(shared_ptr packetRespondingTo, const std::string& label, const std::vector& content, bool prioritise, bool disconnectAfterSend) { shared_ptr packet=NewDataPacket(); packet->Configure(label,content); packet->SetPacketID(packetRespondingTo->GetPacketID()); packet->SetPacketTypeID(PacketTypes::LABELLED_RESPONSE_PACKET); SendDataPacket(packet,PacketCallback(),prioritise,disconnectAfterSend,true); } void Connection::SendHostDetails() { #ifdef ECHO_LITTLE_ENDIAN std::string hostDetails="1.0:little"; #elif defined(ECHO_BIG_ENDIAN) std::string hostDetails="1.0:big"; #else #error "Unable to determine host endianess" #endif ECHO_LOG_DEBUG("Sending host details " << hostDetails); SendMessage(hostDetails,PacketTypes::REMOTE_DETAILS, PacketCallback(), true); } void Connection::OnRemoteDetails(shared_ptr connection, shared_ptr dataPacket) { std::string remoteDetails; if(!dataPacket->GetStringFromDataPacket(remoteDetails)) { connection->Disconnect(); return; } //Remote details needs to be "protocolVersion:endian" // protocolVersion - 1.0 // endian - "big" or "little" ECHO_LOG_DEBUG("Remote details: " << remoteDetails); std::vector parameters; Utils::String::Split(remoteDetails,":",parameters); if(parameters.size()<2) { ECHO_LOG_WARNING("Received remote details without enough parameters - " << remoteDetails << ". Diconnecting."); connection->Disconnect(); return; } if(parameters[0]!="1.0") { ECHO_LOG_WARNING("Received remote details specifying an unsupported protocol version - " << parameters[0] << ". Diconnecting."); connection->Disconnect(); return; } bool bigEndian = true; if(parameters[1]=="little") { bigEndian = false; }else if(parameters[1]!="big") { ECHO_LOG_WARNING("Received remote details specifying an unsupported endian format - " << parameters[1] << ". Diconnecting."); connection->Disconnect(); return; } connection->SetRemoteBigEndian(bigEndian); } bool Connection::NotifyAnyReceivedPackets() { //Only process as many packets as we received in the last frame to avoid getting stuck //in here if we have a constant stream of packets. //Don't worry about locking the mutex, if we're off by one (less) then the remaining //packet will be in the next frame, and probably should have been anyway. mReceviedPacketsMutex.Lock(); std::list< shared_ptr > packetsToProcess = std::move(mReceviedPackets); mReceviedPacketsMutex.Unlock(); while(!packetsToProcess.empty()) { shared_ptr packet=packetsToProcess.front(); packetsToProcess.pop_front(); if(packet->IsResponsePacket()) { std::map< u32, PacketCallback >::iterator rit=mResponseCallbacks.find(packet->GetPacketID()); if(rit!=mResponseCallbacks.end()) { rit->second(shared_from_this(),packet); mResponseCallbacks.erase(rit); } }else { std::map< u32, std::vector >::iterator it=mPacketCallbacks.find(packet->GetPacketTypeID()); if(it!=mPacketCallbacks.end()) { BOOST_FOREACH(PacketCallback& callback, it->second) { callback(shared_from_this(),packet); } } } _notifyOwner(packet); } return !mReceviedPackets.empty(); } void Connection::ProcessReceivedPacket(shared_ptr packet) { mReceviedPacketsMutex.Lock(); mReceviedPackets.push_back(packet); mReceviedPacketsMutex.Unlock(); } void Connection::ProcessLabelledPacket(shared_ptr connection, shared_ptr packet) { Size dataOffset = 0; std::string label = packet->GetLabel(&dataOffset); if(label.empty()) { ECHO_LOG_WARNING("Packet with ID LABELLED_PACKET is invalid or does not contain a label."); return; } std::map< std::string, std::vector >::iterator it=mLabelledPacketCallbacks.find(label); if(it!=mLabelledPacketCallbacks.end()) { u8* dataStart = &(packet->GetData()[dataOffset]); Size numberOfBytes = (packet->GetDataSize()-dataOffset); BOOST_FOREACH(LabelledPacketCallback& callback, it->second) { callback(connection,packet,dataStart,numberOfBytes); } } } std::string Connection::GetFriendlyIdentifier() { return mConnectionDetails.ToString(); } std::string Connection::GetLocalFriendlyIdentifier() { return mLocalConnectionDetails.ToString(); } void Connection::RegisterLabelledPacketCallback(const std::string& label, LabelledPacketCallback callback) { if(callback) { mLabelledPacketCallbacks[label].push_back(callback); } } void Connection::RegisterPacketCallback(u32 packetTypeID, PacketCallback callback) { if(callback) { mPacketCallbacks[packetTypeID].push_back(callback); } } void Connection::RegisterConnectCallback(const std::string& identifier, ConnectCallback callback) { if(callback) { mConnectCallbacks[identifier].push_back(callback); } } void Connection::RegisterDisconnectCallback(const std::string& identifier, DisconnectCallback callback) { if(callback) { mDisconnectCallbacks[identifier].push_back(callback); } } void Connection::ClearLabelledPacketCallbacks(const std::string& label) { mLabelledPacketCallbacks.erase(label); } void Connection::ClearAllLabelledPacketCallbacks() { mLabelledPacketCallbacks.clear(); } void Connection::ClearPacketIDCallbacks(u32 packetTypeID) { mPacketCallbacks.erase(packetTypeID); } void Connection::ClearAllPacketIDCallbacks() { mPacketCallbacks.clear(); } void Connection::ClearAllPacketCallbacks() { ClearAllPacketIDCallbacks(); ClearAllLabelledPacketCallbacks(); } void Connection::ClearConnectCallbacks(const std::string& identifier) { mConnectCallbacks.erase(identifier); } void Connection::ClearAllConnectCallbacks() { mConnectCallbacks.clear(); } void Connection::ClearDisconnectCallbacks(const std::string& identifier) { mDisconnectCallbacks.erase(identifier); } void Connection::ClearAllDisconnectCallbacks() { mDisconnectCallbacks.clear(); } void Connection::SetAutoAttemptReconnectTime(Seconds seconds) { mReconnectTimer.SetTimeout(seconds); } }