/*++

Copyright (c) 1999-2000 Microsoft Corporation

Module Name :

    sessmgr.cpp

Abstract:

    Keeps track of the collection of session objects for the RDP device
    redirector

Revision History:
--*/

#include "precomp.hxx"
#define TRC_FILE "sessmgr"
#include "trc.h"

//
//  Constructor: registers the class name used by the tracing/debug support.
//
DrSessionManager::DrSessionManager()
{
    BEGIN_FN("DrSessionManager::DrSessionManager");
    SetClassName("DrSessionManager");
}

DrSessionManager::~DrSessionManager()
{
    BEGIN_FN("DrSessionManager::~DrSessionManager");
}

//
//  Adds a session to the tracked list unless a session with the same
//  session id is already present.
//
//  Arguments:
//      Session - the session to track; must not be NULL.
//
//  Return Value:
//      TRUE if the session was added. FALSE if a session with the same id
//      is already in the list, or if the list entry could not be created.
//
//  Notes:
//      The list stores a raw DrSession pointer, so a reference is taken
//      with AddRef() on behalf of the list before the entry is created and
//      dropped again if the insert fails.
//
//      FindSessionById() takes the list lock shared while this function
//      already holds it exclusive; this relies on the list lock allowing
//      that nested acquisition by the owning thread (TODO confirm against
//      the list implementation).
//
BOOL DrSessionManager::AddSession(SmartPtr &Session)
{
    DrSession *SessionT;
    SmartPtr SessionFound;
    BOOL rc = FALSE;

    BEGIN_FN("DrSessionManager::AddSession");

    ASSERT(Session != NULL);

    _SessionList.LockExclusive();

    if (!FindSessionById(Session->GetSessionId(), SessionFound)) {

        //
        // Take the list's own reference on the raw pointer
        //

        SessionT = Session;
        SessionT->AddRef();

        //
        // Add it to the list
        //

        if (_SessionList.CreateEntry((PVOID)SessionT)) {

            //
            // successfully added this entry
            //

            rc = TRUE;
        } else {

            //
            // Unable to add it to the list, clean up
            //

            SessionT->Release();
            rc = FALSE;
        }
    }

    _SessionList.Unlock();
    return rc;
}

//
//  Handles a channel connect notification: finds the existing session for
//  the given session id (reconnect) or creates and initializes a new one,
//  connects it, and makes sure it is tracked in the session list.
//
//  Arguments:
//      ConnectIn  - connect request; hdr.sessionID selects the session.
//      ConnectOut - connect response; hdr.contextData is set to
//                   (UINT_PTR)-1 when the session is tracked, 0 otherwise.
//
//  Return Value:
//      TRUE if the session connected and is tracked in the list.
//
//  Notes:
//      On success an extra reference is taken on the session; OnDisconnect
//      releases it when it sees the (UINT_PTR)-1 marker in contextData.
//
BOOL DrSessionManager::OnConnect(PCHANNEL_CONNECT_IN ConnectIn,
        PCHANNEL_CONNECT_OUT ConnectOut)
{
    SmartPtr ConnectingSession;
    BOOL Reconnect = FALSE;
    BOOL Connected = FALSE;
    BOOL Added = FALSE;

    BEGIN_FN("DrSessionManager::OnConnect");

    ASSERT(ConnectIn != NULL);
    ASSERT(ConnectOut != NULL);

    // Clear out the output buffer by default
    ConnectOut->hdr.contextData = (UINT_PTR)0;

    Reconnect = FindSessionById(ConnectIn->hdr.sessionID, ConnectingSession);

    if (Reconnect) {
        TRC_DBG((TB, "Reconnecting session %d", ConnectIn->hdr.sessionID));
    } else {
        TRC_DBG((TB, "Connecting session %d", ConnectIn->hdr.sessionID));

        ConnectingSession = new(NonPagedPool) DrSession;

        if (ConnectingSession != NULL) {
            TRC_DBG((TB, "Created new session"));

            if (!ConnectingSession->Initialize()) {
                TRC_DBG((TB, "Session couldn't initialize"));
                ConnectingSession = NULL;
            }
        } else {
            TRC_ERR((TB, "Failed to allocate new session"));
        }
    }

    if (ConnectingSession != NULL) {
        Connected = ConnectingSession->Connect(ConnectIn, ConnectOut);
    }

    if (Connected) {
        TRC_DBG((TB, "Session connected, adding"));

        if (!Reconnect) {
            Added = AddSession(ConnectingSession);

            if (!Added) {

                //
                // AddSession also fails when a session with this id was
                // added between our lookup and the insert; in that case
                // pick up the session that is actually in the list.
                //

                if (FindSessionById(ConnectIn->hdr.sessionID,
                        ConnectingSession)) {
                    Added = TRUE;
                }
            }
        } else {
            // Don't add what we found there anyway
            Added = TRUE;
        }
    }

    if (Added) {
        // Stash this here for the disconnect notification
        TRC_DBG((TB, "Added session"));
        ConnectingSession->AddRef();
        ConnectOut->hdr.contextData = (UINT_PTR)-1;
    }

    return Added;
}

//
//  Handles a channel disconnect notification. If the context data carries
//  the (UINT_PTR)-1 marker written by OnConnect and the session is still in
//  the list, the session is disconnected and the reference taken in
//  OnConnect is released. The output context data is always cleared.
//
//  Arguments:
//      DisconnectIn  - disconnect request; hdr.sessionID selects the session.
//      DisconnectOut - disconnect response; hdr.contextData is set to 0.
//
VOID DrSessionManager::OnDisconnect(PCHANNEL_DISCONNECT_IN DisconnectIn,
        PCHANNEL_DISCONNECT_OUT DisconnectOut)
{
    SmartPtr DisconnectingSession;

    BEGIN_FN("DrSessionManager::OnDisconnect");

    ASSERT(DisconnectIn != NULL);
    ASSERT(DisconnectOut != NULL);

    if (DisconnectIn->hdr.contextData == (UINT_PTR)-1 &&
            FindSessionById(DisconnectIn->hdr.sessionID,
                    DisconnectingSession)) {
        TRC_NRM((TB, "Closing session for doctored client."));

        ASSERT(DisconnectingSession->IsValid());
        DisconnectingSession->Disconnect(DisconnectIn, DisconnectOut);

        // Drop the reference OnConnect took for this connection
        DisconnectingSession->Release();
    } else {

        //
        // Must not have been a "doctored" client
        //

        TRC_NRM((TB, "Undoctored session ending"));
    }

    //
    // make sure the output context is blank
    //

    DisconnectOut->hdr.contextData = (UINT_PTR)0;
}

//
//  Looks up a tracked session by session id.
//
//  Arguments:
//      SessionId    - id to search for.
//      SessionFound - receives the matching session; left untouched when
//                     no session matches.
//
//  Return Value:
//      TRUE only if a matching session was found in the list. The result
//      does not depend on what SessionFound held on entry, so callers may
//      pass a SmartPtr that is already non-NULL.
//
BOOL DrSessionManager::FindSessionById(ULONG SessionId,
        SmartPtr &SessionFound)
{
    DrSession *SessionEnum;
    ListEntry *ListEnum;
    BOOL Found = FALSE;

    BEGIN_FN("DrSessionManager::FindSessionById");

    _SessionList.LockShared();

    ListEnum = _SessionList.First();
    while (ListEnum != NULL) {

        SessionEnum = (DrSession *)ListEnum->Node();

        if (SessionEnum->GetSessionId() == SessionId) {
            SessionFound = SessionEnum;
            Found = TRUE;

            //
            // These aren't guaranteed valid once the resource is released
            //

            SessionEnum = NULL;
            ListEnum = NULL;
            break;
        }

        ListEnum = _SessionList.Next(ListEnum);
    }

    _SessionList.Unlock();

    return Found;
}

//
//  Looks up a tracked session by session id and client id.
//
//  Arguments:
//      SessionId    - session id to search for.
//      ClientId     - client id the session must also match.
//      SessionFound - receives the matching session; left untouched when
//                     no session matches.
//
//  Return Value:
//      TRUE only if a session matching both ids was found in the list.
//
BOOL DrSessionManager::FindSessionByIdAndClient(ULONG SessionId,
        ULONG ClientId, SmartPtr &SessionFound)
{
    DrSession *SessionEnum;
    ListEntry *ListEnum;
    BOOL Found = FALSE;

    BEGIN_FN("DrSessionManager::FindSessionByIdAndClient");

    _SessionList.LockShared();

    ListEnum = _SessionList.First();
    while (ListEnum != NULL) {

        SessionEnum = (DrSession *)ListEnum->Node();
        ASSERT(SessionEnum->IsValid());

        if ((SessionEnum->GetSessionId() == SessionId) &&
                (SessionEnum->GetClientId() == ClientId)) {
            SessionFound = SessionEnum;
            Found = TRUE;

            //
            // These aren't guaranteed valid once the resource is released
            //

            SessionEnum = NULL;
            ListEnum = NULL;
            break;
        }

        ListEnum = _SessionList.Next(ListEnum);
    }

    _SessionList.Unlock();

    return Found;
}

//
//  Looks up a tracked session by client name, compared case-insensitively.
//
//  Arguments:
//      ClientName   - client name to search for.
//      SessionFound - receives the matching session; left untouched when
//                     no session matches.
//
//  Return Value:
//      TRUE only if a matching session was found in the list. The result
//      does not depend on what SessionFound held on entry.
//
BOOL DrSessionManager::FindSessionByClientName(PWCHAR ClientName,
        SmartPtr &SessionFound)
{
    DrSession *SessionEnum;
    ListEntry *ListEnum;
    BOOL Found = FALSE;

    BEGIN_FN("DrSessionManager::FindSessionByClientName");

    _SessionList.LockShared();

    ListEnum = _SessionList.First();
    while (ListEnum != NULL) {

        SessionEnum = (DrSession *)ListEnum->Node();

        if (_wcsicmp(SessionEnum->GetClientName(), ClientName) == 0) {
            SessionFound = SessionEnum;
            Found = TRUE;

            //
            // These aren't guaranteed valid once the resource is released
            //

            SessionEnum = NULL;
            ListEnum = NULL;
            break;
        }

        ListEnum = _SessionList.Next(ListEnum);
    }

    _SessionList.Unlock();

    return Found;
}

//
//  Removes the list entry that refers to the given session, if any.
//
//  Arguments:
//      Session - the session whose list entry should be removed.
//
//  Notes:
//      Only the list entry is removed here; this function does not call
//      Release() on the session.
//
VOID DrSessionManager::Remove(DrSession *Session)
{
    DrSession *SessionEnum;
    ListEntry *ListEnum;

    BEGIN_FN("DrSessionManager::Remove");

    _SessionList.LockExclusive();

    ListEnum = _SessionList.First();
    while (ListEnum != NULL) {

        SessionEnum = (DrSession *)ListEnum->Node();
        ASSERT(SessionEnum->IsValid());

        if (SessionEnum == Session) {
            _SessionList.RemoveEntry(ListEnum);

            //
            // The entry has been removed; don't touch these again
            //

            SessionEnum = NULL;
            ListEnum = NULL;
            break;
        }

        ListEnum = _SessionList.Next(ListEnum);
    }

    _SessionList.Unlock();
}