/*++

Copyright (c) 2001  Microsoft Corporation

Module Name:

    transport.cxx

Abstract:

    transport

Author:

    Larry Zhu (LZhu)                      January 1, 2002  Created

Environment:

    User Mode

Revision History:

--*/

#include "precomp.hxx"
#pragma hdrstop

#include "sockcomm.h"
#include "transport.hxx"

//
// TLS slot that holds each thread's message counter; TLS_OUT_OF_INDEXES
// until some other code assigns a real index to it.
//

ULONG g_MessageNumTlsIndex = TLS_OUT_OF_INDEXES;

//
// Upper bound on the number of bytes of each message that gets hex-dumped
// by WriteMessage/ReadMessage.
//

ULONG g_MsgHeaderLen = kMsgHeaderLen;

//
// ServerInit
//
// Creates a TCP socket, binds it to Port on the wildcard address and puts
// it into the listening state with a backlog of 1.
//
// Port           - local port to bind to (host byte order)
// pszDescription - optional prefix for the "listening on port" log line
// pSocketListen  - receives the listening socket on success and
//                  INVALID_SOCKET on failure; the caller owns the socket
//
// Returns S_OK on success, otherwise the winsock error converted through
// HResultFromWin32.
//

HRESULT
ServerInit(
    IN USHORT Port,
    IN OPTIONAL PCSTR pszDescription,
    OUT SOCKET* pSocketListen
    )
{
    THResult hRetval = S_OK;

    SOCKADDR_IN sin = {0};
    SOCKET sockListen = INVALID_SOCKET;
    int nRes = SOCKET_ERROR;

    *pSocketListen = INVALID_SOCKET;

    //
    // create listening socket
    //

    sockListen = socket(PF_INET, SOCK_STREAM, 0);

    hRetval DBGCHK = (INVALID_SOCKET != sockListen) ? S_OK : HResultFromWin32(WSAGetLastError());

    //
    // bind to local port
    //

    if (SUCCEEDED(hRetval))
    {
        sin.sin_family = AF_INET;
        sin.sin_addr.s_addr = 0;
        sin.sin_port = htons(Port);

        nRes = bind(sockListen, (PSOCKADDR) &sin, sizeof(sin));

        //
        // NOTE(review): DBGCFG1 presumably tells the debug checker that
        // WSAEADDRINUSE is an expected failure here -- confirm against the
        // macro definition
        //

        DBGCFG1(hRetval, HRESULT_FROM_WIN32(WSAEADDRINUSE));

        hRetval DBGCHK = (SOCKET_ERROR != nRes) ? S_OK : HResultFromWin32(WSAGetLastError());

        if (FAILED(hRetval) && (WSAEADDRINUSE == HRESULT_CODE(hRetval)))
        {
            DebugPrintf(SSPI_ERROR, "ServerInit port %d(%#x) in use, failed to bind\n", Port, Port);
        }
        else if (FAILED(hRetval))
        {
            DebugPrintf(SSPI_ERROR, "ServerInit binding to port %d failed with %#x\n", Port, HRESULT_CODE(hRetval));
        }
    }

    //
    // listen for client
    //

    if (SUCCEEDED(hRetval))
    {
        DebugPrintf(SSPI_LOG, "%s%slistening on port %d(%#x)\n",
            pszDescription ? pszDescription : "",
            pszDescription ? " " : "",
            Port, Port);

        nRes = listen(sockListen, 1);

        hRetval DBGCHK = (SOCKET_ERROR != nRes) ? S_OK : HResultFromWin32(WSAGetLastError());
    }

    //
    // hand ownership of the socket to the caller on success
    //

    if (SUCCEEDED(hRetval))
    {
        *pSocketListen = sockListen;
        sockListen = INVALID_SOCKET;
    }

    //
    // on failure close whatever was created; the close result is only
    // captured for the debug checker and does not affect the return value
    //

    THResult hr;

    if (sockListen != INVALID_SOCKET)
    {
        hr DBGCHK = closesocket(sockListen) == ERROR_SUCCESS ? S_OK : HResultFromWin32(WSAGetLastError());
    }

    return hRetval;
}

//
// ClientConnect
//
// Resolves pszServer (or the local host name when pszServer is NULL) to an
// IPv4 address and opens a TCP connection to it on Port.
//
// pszServer        - dotted-decimal address or host name; NULL means the
//                    name returned by gethostname
// Port             - remote port (host byte order)
// pSocketConnected - receives the connected socket on success and
//                    INVALID_SOCKET on failure; the caller owns the socket
//
// Returns S_OK on success, otherwise the winsock error converted through
// HResultFromWin32, or E_FAIL if the resolved address is not 4 bytes long.
//

HRESULT
ClientConnect(
    IN OPTIONAL PCSTR pszServer,
    IN USHORT Port,
    OUT SOCKET* pSocketConnected
    )
{
    THResult hRetval = S_OK;

    SOCKET sockServer = INVALID_SOCKET;
    ULONG ulAddress = INADDR_NONE;
    struct hostent *pHost = NULL;
    SOCKADDR_IN sin = {0};
    CHAR szServer[DNS_MAX_NAME_LENGTH + 1] = {0};

    //
    // never leave the out parameter uninitialized on a failure path
    //

    *pSocketConnected = INVALID_SOCKET;

    //
    // lookup the address for the server name
    //

    if (!pszServer)
    {
        //
        // szServer is zero-filled and one byte is held back, so the name
        // is always NUL-terminated
        //

        hRetval DBGCHK = gethostname(szServer, sizeof(szServer) - 1) == ERROR_SUCCESS ? S_OK : HResultFromWin32(WSAGetLastError());

        if (SUCCEEDED(hRetval))
        {
            pszServer = szServer;
        }
    }

    if (SUCCEEDED(hRetval))
    {
        ulAddress = inet_addr(pszServer);

        //
        // not a dotted-decimal literal, fall back to a name lookup
        //

        if (INADDR_NONE == ulAddress)
        {
            pHost = gethostbyname(pszServer);

            hRetval DBGCHK = pHost ? S_OK : HResultFromWin32(WSAGetLastError());

            //
            // the address is copied into a 4-byte local, so refuse any
            // entry whose length does not match exactly
            //

            if (SUCCEEDED(hRetval))
            {
                hRetval DBGCHK = (sizeof(ulAddress) == (ULONG) pHost->h_length) ? S_OK : E_FAIL;
            }

            if (SUCCEEDED(hRetval))
            {
                RtlCopyMemory((CHAR *)&ulAddress, pHost->h_addr, sizeof(ulAddress));
            }
        }
    }

    //
    // create the socket
    //

    if (SUCCEEDED(hRetval))
    {
        sockServer = socket(PF_INET, SOCK_STREAM, 0);

        //
        // socket() reports failure by returning INVALID_SOCKET
        //

        hRetval DBGCHK = (INVALID_SOCKET != sockServer) ? S_OK : HResultFromWin32(WSAGetLastError());
    }

    if (SUCCEEDED(hRetval))
    {
        sin.sin_family = AF_INET;
        sin.sin_addr.s_addr = ulAddress;
        sin.sin_port = htons(Port);

        //
        // connect to remote endpoint
        //

        hRetval DBGCHK = connect(sockServer, (PSOCKADDR) &sin, sizeof(sin)) == ERROR_SUCCESS ? S_OK : HResultFromWin32(WSAGetLastError());
    }

    //
    // hand ownership of the socket to the caller on success
    //

    if (SUCCEEDED(hRetval))
    {
        *pSocketConnected = sockServer;
        sockServer = INVALID_SOCKET;
    }

    //
    // on failure close whatever was created; the close result is only
    // captured for the debug checker and does not affect the return value
    //

    THResult hr;

    if (INVALID_SOCKET != sockServer)
    {
        hr DBGCHK = closesocket(sockServer) == ERROR_SUCCESS ? S_OK : HResultFromWin32(WSAGetLastError());
    }

    return hRetval;
}

//
// WriteMessage
//
// Sends cbBuf bytes from pBuf on socket s, then reads back a ULONG from the
// peer and requires it to equal cbBuf. On success the calling thread's
// message counter is incremented and the leading bytes of the message are
// hex-dumped.
//
// s     - socket to write to
// cbBuf - number of bytes to send
// pBuf  - data to send
//
// Returns S_OK on success; E_FAIL if the acknowledged byte count does not
// match; otherwise the failure from the TLS lookup or the socket helpers.
//

HRESULT
WriteMessage(
    IN SOCKET s,
    IN ULONG cbBuf,
    IN VOID* pBuf
    )
{
    THResult hRetval = S_OK;

    ULONG cbWritten = 0;
    ULONG cbRead = 0;
    ULONG* pMessageNum = NULL;

    hRetval DBGCHK = GetPerThreadpMessageNum(g_MessageNumTlsIndex, &pMessageNum);

    //
    // NOTE(review): SendMsg is treated as a BOOL here but ReadMessage
    // assigns its result straight to an HRESULT -- one of the two call
    // sites is wrong; confirm against the declaration in sockcomm.h
    //

    if (SUCCEEDED(hRetval))
    {
        hRetval DBGCHK = SendMsg(s, cbBuf, pBuf) ? S_OK : GetLastErrorAsHResult();
    }

    //
    // NOTE(review): ReceiveMsg is treated as an HRESULT here but as a BOOL
    // in ReadMessage -- confirm against the declaration in sockcomm.h
    //

    if (SUCCEEDED(hRetval))
    {
        hRetval DBGCHK = ReceiveMsg(s, sizeof(cbWritten), &cbWritten, &cbRead);
    }

    //
    // the peer must acknowledge exactly the number of bytes that were sent
    //

    if (SUCCEEDED(hRetval))
    {
        hRetval DBGCHK = (cbWritten == cbBuf) && (cbRead == sizeof(cbWritten)) ? S_OK : E_FAIL;
    }

    if (SUCCEEDED(hRetval))
    {
        CHAR szBanner[MAX_PATH] = {0};

        //
        // _snprintf does not terminate on truncation; hold back the last
        // (zeroed) byte so the banner is always a valid string
        //

        _snprintf(szBanner, sizeof(szBanner) - 1, "*******Message #%#x sent %#x bytes:*********", (*pMessageNum)++, cbWritten);

        DebugPrintHex(SSPI_MSG, szBanner, min(cbBuf, g_MsgHeaderLen), pBuf);
    }
    else
    {
        DebugPrintf(SSPI_ERROR, "cbWritten %#x, cbBuf %#x, cbRead %#x\n", cbWritten, cbBuf, cbRead);
    }

    return hRetval;
}

//
// GetPerThreadpMessageNum
//
// Fetches the calling thread's message counter pointer from a TLS slot.
//
// Index        - TLS index to read; TLS_OUT_OF_INDEXES is rejected
// ppMessageNum - receives the stored pointer on success, NULL on failure
//
// Returns S_OK on success; E_INVALIDARG if Index is TLS_OUT_OF_INDEXES;
// E_POINTER if the slot is empty and no error was recorded; otherwise the
// result of GetLastErrorAsHResult.
//

HRESULT
GetPerThreadpMessageNum(
    IN ULONG Index,
    OUT ULONG** ppMessageNum
    )
{
    THResult hRetval = E_FAIL;

    ULONG* pMsgNum = NULL;

    *ppMessageNum = NULL;

    hRetval DBGCHK = (TLS_OUT_OF_INDEXES != Index) ? S_OK : E_INVALIDARG;

    //
    // only query the slot (and only replace the status) when the index is
    // valid, so that E_INVALIDARG is not overwritten by a stale last error
    //

    if (SUCCEEDED(hRetval))
    {
        pMsgNum = (ULONG*) TlsGetValue(Index);

        hRetval DBGCHK = pMsgNum ? S_OK : GetLastErrorAsHResult();
    }

    //
    // last error can be NO_ERROR
    //

    if (SUCCEEDED(hRetval) && !pMsgNum)
    {
        hRetval DBGCHK = E_POINTER;
    }

    if (SUCCEEDED(hRetval))
    {
        *ppMessageNum = pMsgNum;
    }

    return hRetval;
}

//
// ReadMessage
//
// Receives a message of up to cbBuf bytes from socket s into pBuf, then
// sends the number of bytes received back to the peer as a ULONG. On
// success the calling thread's message counter is incremented and the
// leading bytes of the message are hex-dumped.
//
// s       - socket to read from
// cbBuf   - size of pBuf in bytes
// pBuf    - receives the message
// pcbRead - receives the number of bytes read
//
// Returns S_OK on success, otherwise the failure from the TLS lookup or the
// socket helpers.
//

HRESULT
ReadMessage(
    IN SOCKET s,
    IN ULONG cbBuf,
    IN OUT VOID* pBuf,
    OUT ULONG* pcbRead
    )
{
    THResult hRetval = S_OK;

    ULONG* pMessageNum = NULL;

    hRetval DBGCHK = GetPerThreadpMessageNum(g_MessageNumTlsIndex, &pMessageNum);

    //
    // NOTE(review): ReceiveMsg is treated as a BOOL here but WriteMessage
    // assigns its result straight to an HRESULT -- one of the two call
    // sites is wrong; confirm against the declaration in sockcomm.h
    //

    if (SUCCEEDED(hRetval))
    {
        hRetval DBGCHK = ReceiveMsg(s, cbBuf, pBuf, pcbRead) ? S_OK : GetLastErrorAsHResult();
    }

    //
    // acknowledge by echoing the received byte count
    //
    // NOTE(review): SendMsg is treated as an HRESULT here but as a BOOL in
    // WriteMessage -- confirm against the declaration in sockcomm.h
    //

    if (SUCCEEDED(hRetval))
    {
        hRetval DBGCHK = SendMsg(s, sizeof(*pcbRead), pcbRead);
    }

    if (SUCCEEDED(hRetval))
    {
        CHAR szBanner[MAX_PATH] = {0};

        //
        // _snprintf does not terminate on truncation; hold back the last
        // (zeroed) byte so the banner is always a valid string
        //

        _snprintf(szBanner, sizeof(szBanner) - 1, "*********Message #%#x received %#x bytes:**********", (*pMessageNum)++, *pcbRead);

        DebugPrintHex(SSPI_MSG, szBanner, min(*pcbRead, g_MsgHeaderLen), pBuf);
    }

    return hRetval;
}