|
|
//
// Protocol.cpp
//
// Install, uninstall, and related code for protocols such as TCP/IP.
//
// History:
//
// 2/02/1999 KenSh Created for JetNet
// 9/29/1999 KenSh Repurposed for Home Networking Wizard
//
#include "stdafx.h"
#include "NetConn.h"
#include "nconnwrap.h"
#include "ParseInf.h"
#include "TheApp.h"
#include "HookUI.h"
#include "NetCli.h"
// Local functions
//
// Forward declarations for helpers defined later in this file:
//   RemoveOrphanedProtocol   - deletes Class\NetTrans instances of a protocol
//                              that nothing references (called by
//                              InstallProtocol before a fresh install).
//   CreateNewProtocolBinding - clones a protocol's Enum key (and its Class
//                              "Driver" key) to make a new binding such as
//                              "MSTCP\0001", returning its name in pszBuf.
//
VOID RemoveOrphanedProtocol(LPCSTR pszProtocolID); HRESULT CreateNewProtocolBinding(LPCSTR pszProtocolDeviceID, LPSTR pszBuf, int cchBuf, LPCSTR pszClientBinding, LPCSTR pszServiceBinding);
// IsProtocolInstalled
//
// Returns TRUE if one or more instances of the given protocol
// (e.g. "MSTCP") is bound to a network adapter.
//
// If bExhaustive is TRUE, the protocol's "<DeviceID>.Install" section of
// nettrans.inf must additionally pass CheckInfSectionInstallation.
// A device ID too long to form that section name is reported as not
// installed instead of overrunning the stack buffer.
//
BOOL WINAPI IsProtocolInstalled(LPCTSTR pszProtocolDeviceID, BOOL bExhaustive)
{
    if (!IsProtocolBoundToAnyAdapter(pszProtocolDeviceID))
        return FALSE;

    if (bExhaustive)
    {
        static const TCHAR c_szInstallSuffix[] = ".Install";
        TCHAR szInfSection[MAX_PATH];

        // lstrcpy/lstrcat (like wsprintf) do no bounds checking, so make sure
        // "<ID>.Install" plus its terminator fits before building it.
        if (lstrlen(pszProtocolDeviceID) + lstrlen(c_szInstallSuffix) >= (int)_countof(szInfSection))
            return FALSE;

        lstrcpy(szInfSection, pszProtocolDeviceID);
        lstrcat(szInfSection, c_szInstallSuffix);

        if (!CheckInfSectionInstallation("nettrans.inf", szInfSection))
            return FALSE;
    }

    return TRUE;
}
// InstallProtocol (public)
//
// Installs the given protocol via NETDI, and binds it to all adapters.
// The standard progress UI is (mostly) suppressed, and instead the given
// callback function is called so a custom progress UI can be implemented.
//
// If the protocol is already bound to at least one adapter, the NETDI
// install is skipped and only the "bind to every adapter" step runs.
//
// Returns a NETCONN_xxx result, defined in NetConn.h:
//   - NETCONN_USER_ABORT if the user cancelled the NETDI install;
//   - the class installer's failure HRESULT if the install failed (the
//     binding step is then skipped);
//   - otherwise NETCONN_NEED_RESTART after a fresh install, or whatever
//     non-NETCONN_SUCCESS code the binding step produced.
// NOTE(review): whether binding still runs after an abort depends on whether
// NETCONN_USER_ABORT is a SUCCEEDED() code -- confirm in NetConn.h.
//
// Parameters:
//
// pszProtocolID   device ID of the protocol to install, e.g. "MSTCP"
// hwndParent      parent window helpful to use in NETDI call
// pfnProgress     function to call with install progress reports
// pvProgressParam user-supplied parameter to pass to pfnProgress
//
// History:
//
// 2/23/1999 KenSh Created
// 3/26/1999 KenSh Check if already installed before reinstalling
//
HRESULT WINAPI InstallProtocol(LPCSTR pszProtocolID, HWND hwndParent, PROGRESS_CALLBACK pfnProgress, LPVOID pvProgressParam)
{
    HRESULT hr = NETCONN_SUCCESS;

    if (!IsProtocolBoundToAnyAdapter(pszProtocolID))
    {
        // Clear out stale, unreferenced Class keys from earlier installs.
        RemoveOrphanedProtocol(pszProtocolID);

        BeginSuppressNetdiUI(hwndParent, pfnProgress, pvProgressParam);
        DWORD dwResult = CallClassInstaller16(hwndParent, SZ_CLASS_PROTOCOL, pszProtocolID);
        EndSuppressNetdiUI();

        if (g_bUserAbort)
        {
            hr = NETCONN_USER_ABORT;
        }
        else
        {
            // Propagate an installer failure instead of silently carrying on
            // and trying to bind a protocol that was never installed.
            HRESULT hrInstall = HresultFromCCI(dwResult);
            hr = SUCCEEDED(hrInstall) ? NETCONN_NEED_RESTART : hrInstall;
        }

        // Total hack to work around JetNet bug 1193
        // DoDummyDialog(hwndParent);
    }

    if (SUCCEEDED(hr))
    {
        // Ensure the protocol is bound exactly once to every NIC
        HRESULT hr2 = BindProtocolToAllAdapters(pszProtocolID);
        if (hr2 != NETCONN_SUCCESS)
            hr = hr2;
    }

    return hr;
}
// InstallTCPIP (public)
//
// Convenience wrapper that installs the TCP/IP protocol (SZ_PROTOCOL_TCPIP)
// and binds it to all adapters. It behaves exactly like calling
// InstallProtocol with SZ_PROTOCOL_TCPIP; see InstallProtocol for the
// meaning of the parameters and the NETCONN_xxx return codes.
//
HRESULT WINAPI InstallTCPIP(HWND hwndParent, PROGRESS_CALLBACK pfnProgress, LPVOID pvProgressParam)
{
    LPCSTR pszTcpipID = SZ_PROTOCOL_TCPIP;
    return InstallProtocol(pszTcpipID, hwndParent, pfnProgress, pvProgressParam);
}
// RemoveProtocol (public)
//
// Removes the given protocol (e.g. "MSTCP") from the registry in three steps:
//   1. deletes every binding to it from each adapter's Enum "Bindings" key;
//   2. deletes HKLM\Enum\Network\<pszProtocol> and all of its subkeys;
//   3. removes the protocol's Class key references (DeleteClassKeyReferences).
//
// Returns NETCONN_NEED_RESTART if at least one adapter binding was found
// (and a delete attempted), otherwise NETCONN_SUCCESS. Failures of the
// individual registry deletions are not reported.
//
HRESULT WINAPI RemoveProtocol(LPCSTR pszProtocol)
{
    HRESULT hrResult = NETCONN_SUCCESS;

    // Step 1: strip every pointer to the protocol out of the adapters.
    NETADAPTER* rgAdapters;
    const int nAdapters = EnumNetAdapters(&rgAdapters);
    for (int i = 0; i < nAdapters; ++i)
    {
        NETADAPTER* pCurrent = rgAdapters + i;

        LPTSTR* rgBindingNames;
        const int nBindings = EnumMatchingNetBindings(pCurrent->szEnumKey, pszProtocol, &rgBindingNames);
        if (nBindings > 0)
        {
            CRegistry regAdapterBindings;
            const BOOL bOpened = regAdapterBindings.OpenKey(HKEY_LOCAL_MACHINE, pCurrent->szEnumKey)
                              && regAdapterBindings.OpenSubKey("Bindings");
            if (bOpened)
            {
                for (int j = 0; j < nBindings; ++j)
                {
                    regAdapterBindings.DeleteValue(rgBindingNames[j]);
                    hrResult = NETCONN_NEED_RESTART;
                }
            }
        }
        NetConnFree(rgBindingNames);
    }
    NetConnFree(rgAdapters);

    // Step 2: drop the protocol's Enum\Network subtree.
    CRegistry regNetworkEnum;
    if (regNetworkEnum.OpenKey(HKEY_LOCAL_MACHINE, "Enum\\Network"))
        RegDeleteKeyAndSubKeys(regNetworkEnum.m_hKey, pszProtocol);

    // Step 3: drop the protocol's class key(s).
    DeleteClassKeyReferences(SZ_CLASS_PROTOCOL, pszProtocol);

    return hrResult;
}
// IsProtocolBoundToAnyAdapter (public)
//
// Walks the adapters reported by EnumNetAdapters and returns TRUE as soon as
// IsProtocolBoundToAdapter reports the protocol (e.g. "MSTCP") bound to one
// of them; returns FALSE if none has it.
//
// History:
//
// 3/26/1999 KenSh Created
//
BOOL WINAPI IsProtocolBoundToAnyAdapter(LPCSTR pszProtocolID)
{
    NETADAPTER* rgAdapters;
    const int nAdapters = EnumNetAdapters(&rgAdapters);

    BOOL bFound = FALSE;
    for (int i = 0; i < nAdapters && !bFound; ++i)
        bFound = IsProtocolBoundToAdapter(pszProtocolID, &rgAdapters[i]) ? TRUE : FALSE;

    NetConnFree(rgAdapters);
    return bFound;
}
// IsProtocolBoundToAdapter (public)
//
// Given a protocol ID, such as "MSTCP", and an adapter struct, determines
// whether the protocol is bound to the adapter.
//
// Returns TRUE if EnumMatchingNetBindings finds at least one binding of the
// protocol under the adapter's Enum key, FALSE otherwise. The result is
// normalized to TRUE/FALSE rather than the raw binding count, so callers may
// compare it against TRUE, and a non-positive count never reads as "bound".
BOOL WINAPI IsProtocolBoundToAdapter(LPCSTR pszProtocolID, const NETADAPTER* pAdapter)
{
    LPSTR* prgBindings;
    int cBindings = EnumMatchingNetBindings(pAdapter->szEnumKey, pszProtocolID, &prgBindings);
    NetConnFree(prgBindings);

    return (cBindings > 0) ? TRUE : FALSE;
}
// RemoveOrphanedProtocol
//
// Checks to see whether any instances of this protocol in the Class branch of
// the registry (Class\NetTrans\NNNN) are unreferenced, and deletes them if so.
//
// pszProtocolID is the generic device ID of the protocol, e.g. "MSTCP". An
// instance matches when its Ndi\DeviceID value equals it (case-insensitive).
VOID RemoveOrphanedProtocol(LPCSTR pszProtocolID)
{
    // REVIEW: Should we first delete references in Enum that are not in use?

    CRegistry reg;
    if (!reg.OpenKey(HKEY_LOCAL_MACHINE, "System\\CurrentControlSet\\Services\\Class\\NetTrans", KEY_ALL_ACCESS))
        return;

    // Count the instance subkeys (e.g. "0000", "0001", ...) up front.
    DWORD cSubKeys = 0;
    if (ERROR_SUCCESS != RegQueryInfoKey(reg.m_hKey, NULL, NULL, NULL, &cSubKeys, NULL, NULL, NULL, NULL, NULL, NULL, NULL))
        return;

    // Enumerate from the last index down to 0. Deleting the subkey at index i
    // shifts the indices of the subkeys after it; walking forward would then
    // skip the instance that slid into slot i. Walking backward only ever
    // revisits lower indices, which a deletion does not disturb.
    for (DWORD iKey = cSubKeys; iKey-- > 0; )
    {
        CHAR szSubKey[MAX_PATH];
        DWORD cchSubKey = _countof(szSubKey);
        if (ERROR_SUCCESS != RegEnumKeyEx(reg.m_hKey, iKey, szSubKey, &cchSubKey, NULL, NULL, NULL, NULL))
            continue;

        // Make sure "<instance>\Ndi" fits before appending to the buffer.
        if (cchSubKey + sizeof("\\Ndi") > _countof(szSubKey))
            continue;

        //
        // Open the "Ndi" subkey so we can see what kind of protocol this is
        //
        lstrcpy(szSubKey + cchSubKey, "\\Ndi");
        CRegistry regNode;
        if (!regNode.OpenKey(reg.m_hKey, szSubKey, KEY_ALL_ACCESS))
            continue;

        CHAR szDeviceID[40];
        if (!regNode.QueryStringValue("DeviceID", szDeviceID, _countof(szDeviceID)))
            continue;
        regNode.CloseKey(); // close the key before we try to delete it

        if (0 != lstrcmpi(szDeviceID, pszProtocolID))
            continue;

        //
        // Found the right protocol, now check if it's referenced
        //
        // NOTE(review): this passes "NNNN\Ndi" (relative to Class\NetTrans);
        // confirm that is the form IsNetClassKeyReferenced expects.
        if (IsNetClassKeyReferenced(szSubKey))
            continue;

        // Not referenced, so delete it: strip "\Ndi" to get back to the
        // instance key name (e.g. "0000") and remove that whole subtree.
        szSubKey[cchSubKey] = '\0';
        RegDeleteKeyAndSubKeys(reg.m_hKey, szSubKey);
    }
}
// BindProtocolToAdapter
//
// Creates a brand-new binding of the given protocol and attaches it to one
// adapter:
//   * if bEnableSharing, two new File and Printer Sharing bindings are made
//     (CreateNewFilePrintSharing), otherwise both names stay empty;
//   * a new Client for Microsoft Networks binding is made, given the first
//     sharing binding (CreateNewClientForMSNet);
//   * a new protocol binding (e.g. "MSTCP\0001") is made, given the client
//     binding and the second sharing binding (CreateNewProtocolBinding);
//   * the protocol binding is added to the adapter as an empty REG_SZ value.
//
// hkeyAdapterBindings is an open, writable handle to the adapter's Enum
// "Bindings" key.
//
// pszProtocolDeviceID is the generic device ID of the protocol, e.g. "MSTCP".
//
// bEnableSharing determines whether file and printer sharing will be bound
// to the protocol when running through the given adapter.
//
// Returns NETCONN_NEED_RESTART on success, the failing helper's HRESULT if a
// creation step fails, or NETCONN_UNKNOWN_ERROR if the final registry write
// fails. Bindings created before a failure are not cleaned up.
//
// History:
//
// 3/26/1999 KenSh Created
// 4/09/1999 KenSh Added bEnableSharing flag
//
HRESULT BindProtocolToAdapter(HKEY hkeyAdapterBindings, LPCSTR pszProtocolDeviceID, BOOL bEnableSharing)
{
    // NetTrans seems to always clone the Enum key, and also clone the Class key.
    // (the new clone of the Enum key points to the new clone of the Class key;
    // each new Enum's MasterCopy points to itself)
    // NetService and NetClient seem to clone the Enum key, but not the Class key.
    CHAR szClientBinding[MAX_PATH];
    CHAR szSharingForClient[MAX_PATH];
    CHAR szSharingForProtocol[MAX_PATH];
    CHAR szProtocolBinding[MAX_PATH];
    HRESULT hrStep;

    if (!bEnableSharing)
    {
        szSharingForClient[0] = '\0';
        szSharingForProtocol[0] = '\0';
    }
    else
    {
        hrStep = CreateNewFilePrintSharing(szSharingForClient, _countof(szSharingForClient));
        if (FAILED(hrStep))
            return hrStep;

        hrStep = CreateNewFilePrintSharing(szSharingForProtocol, _countof(szSharingForProtocol));
        if (FAILED(hrStep))
            return hrStep;
    }

    hrStep = CreateNewClientForMSNet(szClientBinding, _countof(szClientBinding), szSharingForClient);
    if (FAILED(hrStep))
        return hrStep;

    hrStep = CreateNewProtocolBinding(pszProtocolDeviceID, szProtocolBinding, _countof(szProtocolBinding),
                                      szClientBinding, szSharingForProtocol);
    if (FAILED(hrStep))
        return hrStep;

    // Hook the new protocol binding up to the adapter.
    const LONG lSet = RegSetValueEx(hkeyAdapterBindings, szProtocolBinding, 0, REG_SZ,
                                    reinterpret_cast<CONST BYTE*>(""), 1);
    return (lSet == ERROR_SUCCESS) ? NETCONN_NEED_RESTART : NETCONN_UNKNOWN_ERROR;
}
// BindProtocolToAllAdapters_Helper
//
// Reconciles the bindings of one protocol against every network adapter.
//
// pszProtocolDeviceID  device ID of the protocol, e.g. "MSTCP".
// pszAdapterKey        Enum key of the only adapter the protocol should be
//                      bound to, or NULL to target every adapter.
// bIgnoreVirtualNics   if TRUE, a binding on a virtual NIC is never removed
//                      for being on the wrong NIC / lacking a common
//                      interface (duplicate bindings are still removed).
//
// Two passes are made over the adapters:
//   pass 0 adds a binding to each target adapter that has none and whose
//          UpperRange shares an interface with the protocol's LowerRange;
//   pass 1 removes bindings from adapters that should not have one, and
//          extra bindings so each adapter keeps at most one.
// IPX/SPX is special-cased: it is never targeted at virtual adapters, nor at
// broadband adapters when pszAdapterKey is NULL.
//
// Returns NETCONN_SUCCESS if nothing changed; otherwise the most recent
// non-success code produced (NETCONN_NEED_RESTART after a registry change,
// or a failure from BindProtocolToAdapter).
// NOTE(review): a later NETCONN_NEED_RESTART can overwrite an earlier failure.
HRESULT BindProtocolToAllAdapters_Helper(LPCSTR pszProtocolDeviceID, LPCSTR pszAdapterKey, BOOL bIgnoreVirtualNics)
{
    HRESULT hr = NETCONN_SUCCESS;

    // Get LowerRange interfaces for protocol
    CHAR szProtocolLower[100];
    GetDeviceLowerRange(SZ_CLASS_PROTOCOL, pszProtocolDeviceID, szProtocolLower, _countof(szProtocolLower));

    // For each adapter, ensure the protocol is bound exactly once
    //
    NETADAPTER* prgAdapters;
    int cAdapters = EnumNetAdapters(&prgAdapters);

    // Pass 0: add new bindings
    // Pass 1: delete inappropriate bindings
    for (int iPass = 0; iPass <= 1; iPass++)
    {
        for (int iAdapter = 0; iAdapter < cAdapters; iAdapter++)
        {
            NETADAPTER* pAdapter = &prgAdapters[iAdapter];
            CRegistry regAdapter;

            // Get UpperRange interfaces for adapter
            CHAR szAdapterUpper[100];
            GetDeviceUpperRange(SZ_CLASS_ADAPTER, pAdapter->szDeviceID, szAdapterUpper, _countof(szAdapterUpper));

            // Check for a match between the protocol and the adapter
            BOOL bMatchingInterface = CheckMatchingInterface(szProtocolLower, szAdapterUpper);

            CHAR szRegKey[MAX_PATH];
            wsprintf(szRegKey, "%s\\Bindings", pAdapter->szEnumKey);

            // Is this an adapter the protocol is supposed to be bound to?
            BOOL bCorrectNic = (NULL == pszAdapterKey) || (0 == lstrcmpi(pAdapter->szEnumKey, pszAdapterKey));

            if (!lstrcmpi(pszProtocolDeviceID, SZ_PROTOCOL_IPXSPX))
            {
                // Bind IPX/SPX to all non-broadband NICs by default (usually a max of 1)
                if (pszAdapterKey == NULL)
                {
                    if (IsAdapterBroadband(pAdapter))
                        bCorrectNic = FALSE;
                }

                // Don't bind IPX/SPX to Dial-Up (or other virtual) adapters (bugs 1163, 1164)
                if (pAdapter->bNicType == NIC_VIRTUAL)
                    bCorrectNic = FALSE;
            }

            //
            // Check the bindings of the current adapter, looking for this protocol
            //
            if (regAdapter.OpenKey(HKEY_LOCAL_MACHINE, szRegKey, KEY_ALL_ACCESS))
            {
                // A too-small name buffer makes RegEnumValue fail with
                // ERROR_MORE_DATA, which used to end the scan early.
                TCHAR szValueName[MAX_PATH];
                int cFound = 0;
                DWORD iValue = 0;
                for (;;)
                {
                    DWORD cchValueName = _countof(szValueName);
                    LONG lEnum = RegEnumValue(regAdapter.m_hKey, iValue, szValueName, &cchValueName, NULL, NULL, NULL, NULL);
                    if (lEnum == ERROR_MORE_DATA)
                    {
                        // Name too long to be a binding we manage; skip it.
                        iValue += 1;
                        continue;
                    }
                    if (lEnum != ERROR_SUCCESS)
                        break;

                    // Binding names look like "MSTCP\0001"; split off the device ID.
                    LPSTR pchSlash = strchr(szValueName, '\\');
                    if (pchSlash == NULL)
                    {
                        // Not a "DEVICE\NNNN" binding. Skip it and keep
                        // scanning; stopping here would hide later bindings
                        // of this protocol and cause a duplicate to be added.
                        iValue += 1;
                        continue;
                    }

                    *pchSlash = '\0';

                    if (0 == lstrcmpi(szValueName, pszProtocolDeviceID))
                    {
                        *pchSlash = '\\';

                        // If this isn't the right NIC, or if there's not a matching
                        // interface, or (optionally) if the NIC is virtual, then
                        // unbind the protocol.
                        BOOL bUnbindFromNic = !bCorrectNic || !bMatchingInterface;
                        if (bIgnoreVirtualNics && (pAdapter->bNicType == NIC_VIRTUAL))
                            bUnbindFromNic = FALSE;

                        if (bUnbindFromNic ||   // bound to the wrong adapter!
                            cFound != 0)        // bound more than once to this NIC!
                        {
                            if (iPass == 1)     // unbind on second pass only
                            {
                                // Remove the binding, then restart our search
                                // for matching protocols (value indices shift
                                // after a deletion).
                                RemoveBindingFromParent(regAdapter.m_hKey, szValueName);
                                iValue = 0;
                                cFound = 0;
                                hr = NETCONN_NEED_RESTART;
                                continue;
                            }
                        }

                        cFound += 1;
                    }

                    iValue += 1;
                }

                // Pass 0: protocol is not yet bound to a correct adapter that
                // shares an interface with it, so create the binding.
                if (bCorrectNic && iPass == 0 && cFound == 0 && bMatchingInterface)
                {
                    BOOL bExternalNic = IsAdapterBroadband(pAdapter);

                    // Enable file/printer sharing if:
                    // * Adapter is an ethernet or IRDA adapter
                    // * Adapter is not a broadband NIC
                    BOOL bEnableSharing = FALSE;
                    if ((pAdapter->bNetType == NETTYPE_LAN || pAdapter->bNetType == NETTYPE_IRDA) && !bExternalNic)
                        bEnableSharing = TRUE;

                    HRESULT hr2 = BindProtocolToAdapter(regAdapter.m_hKey, pszProtocolDeviceID, bEnableSharing);
                    if (hr2 != NETCONN_SUCCESS)
                        hr = hr2;
                }
            }
        }
    }

    NetConnFree(prgAdapters);

    return hr;
}
// BindProtocolToOnlyOneAdapter (public)
//
// Makes the adapter whose Enum key is pszAdapterKey the only adapter bound to
// the protocol: a binding is added there (if the two share an interface) and
// removed from every other adapter, except that virtual NICs keep their
// binding when bIgnoreVirtualNics is TRUE. Returns the same NETCONN_xxx codes
// as BindProtocolToAllAdapters.
HRESULT WINAPI BindProtocolToOnlyOneAdapter(LPCSTR pszProtocolDeviceID, LPCSTR pszAdapterKey, BOOL bIgnoreVirtualNics)
{
    const HRESULT hrBind = BindProtocolToAllAdapters_Helper(pszProtocolDeviceID, pszAdapterKey, bIgnoreVirtualNics);
    return hrBind;
}
// BindProtocolToAllAdapters (public)
//
// Given the device ID ("MSTCP") of a protocol which is already installed,
// makes sure it is bound exactly once to every adapter that shares an
// interface with it. New bindings are chained to Client for Microsoft
// Networks, and to File and Printer Sharing only on non-broadband LAN/IRDA
// adapters. (IPX/SPX is additionally kept off broadband and virtual NICs.)
//
// History:
//
// 3/26/1999 KenSh Created
// 4/23/1999 KenSh Enable file sharing only on real NICs
//
HRESULT WINAPI BindProtocolToAllAdapters(LPCSTR pszProtocolDeviceID)
{
    // A NULL adapter key targets every adapter; FALSE means virtual NICs get
    // no special protection from being unbound.
    const LPCSTR pszEveryAdapter = NULL;
    return BindProtocolToAllAdapters_Helper(pszProtocolDeviceID, pszEveryAdapter, FALSE);
}
// CreateNewProtocolBinding
//
// Given a protocol ID such as "MSTCP", an optional client binding string
// such as "VREDIR\0000", and an optional service binding string such as
// "VSERVER\0000", creates a new protocol binding, and copies the name of
// the new binding into the buffer provided (e.g. "MSTCP\0001").
//
// The new binding gets its own clone of the protocol's Enum key (whose
// MasterCopy points at itself) and its own clone of the Class ("Driver")
// key. For TCP/IP, any IPAddress/IPMask values in the cloned class key are
// reset to "0.0.0.0" so the new binding does not inherit a static address.
//
// Returns NETCONN_SUCCESS, the failing helper's HRESULT, or
// NETCONN_UNKNOWN_ERROR if a name is too long or the Driver key cannot be
// read or cloned. Keys cloned before a failure are not cleaned up.
HRESULT CreateNewProtocolBinding(LPCSTR pszProtocolDeviceID, LPSTR pszBuf, int cchBuf, LPCSTR pszClientBinding, LPCSTR pszServiceBinding)
{
    HRESULT hr;

    if (FAILED(hr = FindAndCloneNetEnumKey(SZ_CLASS_PROTOCOL, pszProtocolDeviceID, pszBuf, cchBuf)))
    {
        ASSERT(FALSE);
        return hr;
    }

    // Now pszBuf contains a string of the form "MSTCP\0001".
    // Make sure the key paths built from it below fit their buffers.
    const int cchName = lstrlen(pszBuf);
    if (cchName + lstrlen("\\Bindings") >= MAX_PATH ||
        cchName + lstrlen("Enum\\Network\\") >= MAX_PATH)
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    CHAR szBindings[MAX_PATH];
    lstrcpy(szBindings, pszBuf);          // "MSTCP\0001"
    lstrcat(szBindings, "\\Bindings");    // "MSTCP\0001\Bindings"

    CRegistry regBindings;
    if (FAILED(hr = OpenNetEnumKey(regBindings, szBindings, KEY_ALL_ACCESS)))
    {
        ASSERT(FALSE);
        return hr;
    }

    // Delete existing bindings
    regBindings.DeleteAllValues();

    // Add the client and server bindings
    if (pszClientBinding != NULL && *pszClientBinding != '\0')
        regBindings.SetStringValue(pszClientBinding, "");
    if (pszServiceBinding != NULL && *pszServiceBinding != '\0')
        regBindings.SetStringValue(pszServiceBinding, "");

    // Change the MasterCopy to point to correct place (the new key itself)
    CHAR szMasterCopy[MAX_PATH];
    wsprintf(szMasterCopy, "Enum\\Network\\%s", pszBuf);
    CRegistry regMaster;
    if (!regMaster.OpenKey(HKEY_LOCAL_MACHINE, szMasterCopy, KEY_ALL_ACCESS))
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }
    regMaster.SetStringValue("MasterCopy", szMasterCopy);

    // Create a clone of the driver (a.k.a. class key). Both buffers must hold
    // real data before use: an unread/unclone Driver would otherwise leave
    // uninitialized stack bytes to be written back into the registry.
    CHAR szExistingDriver[60];
    if (!regMaster.QueryStringValue("Driver", szExistingDriver, _countof(szExistingDriver)))
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    CHAR szNewDriver[60] = "";
    CloneNetClassKey(szExistingDriver, szNewDriver, _countof(szNewDriver));
    if (szNewDriver[0] == '\0')
    {
        ASSERT(FALSE);
        return NETCONN_UNKNOWN_ERROR;
    }

    // Change the Driver to point to the new class key
    CRegistry regEnumSubKey;
    if (FAILED(hr = OpenNetEnumKey(regEnumSubKey, pszBuf, KEY_ALL_ACCESS)))
    {
        ASSERT(FALSE);
        return hr;
    }
    regEnumSubKey.SetStringValue("Driver", szNewDriver);

    // If this is a new TCP/IP binding, ensure we don't have a static IP address
    if (0 == lstrcmpi(pszProtocolDeviceID, SZ_PROTOCOL_TCPIP))
    {
        CHAR szFullClassKey[MAX_PATH];
        wsprintf(szFullClassKey, "System\\CurrentControlSet\\Services\\Class\\%s", szNewDriver);

        CRegistry regClassKey;
        if (regClassKey.OpenKey(HKEY_LOCAL_MACHINE, szFullClassKey))
        {
            CHAR szValue[MAX_PATH];
            if (regClassKey.QueryStringValue("IPAddress", szValue, _countof(szValue)))
                regClassKey.SetStringValue("IPAddress", "0.0.0.0");
            if (regClassKey.QueryStringValue("IPMask", szValue, _countof(szValue)))
                regClassKey.SetStringValue("IPMask", "0.0.0.0");
        }
        else
        {
            ASSERT(FALSE);
        }
    }

    return NETCONN_SUCCESS;
}
|