Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

514 lines
16 KiB

  1. //---------------------------------------------------------------------------
  2. // TrustEnumerator.cpp
  3. //
  4. // definition of a COM object for enumerating trust relationships
  5. //
  6. // (c) Copyright 1999, Mission Critical Software, Inc., All Rights Reserved
  7. //
  8. // Proprietary and confidential to Mission Critical Software, Inc.
  9. //---------------------------------------------------------------------------
  10. // TrustEnumerator.cpp : Implementation of CTrustEnumerator
  11. #include "stdafx.h"
  12. #include "EnumTr.h"
  13. #include "TrEnum.h"
  14. #include <activeds.h>
  15. #include <lm.h>
  16. #include <dsgetdc.h>
  17. #include "ntsecapi.h"
  18. #import "\bin\McsVarSetMin.tlb"
  19. #include "LSAUtils.h"
  20. #include "UString.hpp"
  21. #include "ErrDct.hpp"
  22. #include "EaLen.hpp"
  23. TErrorDct err;
  24. /////////////////////////////////////////////////////////////////////////////
  25. // CTrustEnumerator
  26. #undef DIM
  27. #define DIM(array) (sizeof(array) / sizeof(array[0]))
  28. #define MAX_ELEM 16
  29. using MCSVARSETMINLib::IVarSet;
  30. using MCSVARSETMINLib::IVarSetPtr;
  31. using MCSVARSETMINLib::VarSet;
  32. // Get the LDAP distinguished name of the server
  33. HRESULT
  34. GetNameDc(
  35. _bstr_t & sNameDc ,// out-LDAP distinguished name
  36. const BSTR sServer // in -command line arguments
  37. )
  38. {
  39. HRESULT hr;
  40. CComPtr<IADs> pIRootDse = NULL;
  41. _bstr_t sLdapPath;
  42. sLdapPath = L"LDAP://";
  43. if( sServer && wcslen(sServer) > 0 )
  44. {
  45. sLdapPath += sServer;
  46. sLdapPath += L"/";
  47. }
  48. sLdapPath += L"RootDse";
  49. hr = ADsGetObject(static_cast<WCHAR*>(sLdapPath), IID_IADs, (void **) &pIRootDse);
  50. if ( FAILED(hr) )
  51. {
  52. return hr;
  53. }
  54. else
  55. {
  56. VARIANT var;
  57. VariantInit(&var);
  58. hr = pIRootDse->Get(L"DefaultNamingContext", &var);
  59. if ( FAILED(hr) )
  60. {
  61. VariantClear(&var);
  62. return hr;
  63. }
  64. else
  65. {
  66. sNameDc = var.bstrVal;
  67. }
  68. VariantClear(&var);
  69. }
  70. return S_OK;
  71. }
  72. // given a server name, get a pointer to an IADsContainer
  73. HRESULT
  74. getIADContainer
  75. (
  76. BSTR const server , // in, the server to use
  77. IADsContainer ** ppIContainer // out
  78. )
  79. {
  80. HRESULT hr;
  81. _bstr_t sNameDc; // LDAP distinguished name
  82. _bstr_t sLdapPath(L"LDAP://");
  83. hr = GetNameDc(sNameDc, server);
  84. if( FAILED(hr) ) { return hr; }
  85. if( server )
  86. {
  87. sLdapPath += server;
  88. sLdapPath += L"/";
  89. }
  90. sLdapPath += L"CN=System,";
  91. sLdapPath += sNameDc;
  92. hr = ADsGetObject(sLdapPath, IID_IADsContainer, (void **) ppIContainer);
  93. return hr;
  94. }
  95. // clear all the elements of an array of variants
  96. void
  97. clearVariantArray(
  98. VARIANT * array ,
  99. unsigned int const numItems
  100. )
  101. {
  102. for( VARIANT *currItem = array;
  103. currItem < array + numItems;
  104. ++currItem )
  105. {
  106. VariantClear(currItem);
  107. }
  108. }
  109. // This class maintains the state across multiple insertions into the VarSet.
  110. // Mainly this is needed to name the servers properly: Server1, Server2, Server3, etc
  111. class AdItemHandler
  112. {
  113. private:
  114. IVarSetPtr pIVarset; // The VarSet in which to insert all items processed
  115. int counter; // the number (converted to a string) to append to "Server" to get the name
  116. AdItemHandler(){} // don't allow use of the default constructor
  117. public:
  118. AdItemHandler( IVarSetPtr & p ) : pIVarset(p), counter(1) {}
  119. void insertItem( IADs * pObject );
  120. };
  121. // Take in an IADs an insert it into the VarSet if it is a TrustedDomain. The name to put it under
  122. void AdItemHandler::
  123. insertItem(
  124. IADs * pObject // in -AD object
  125. )
  126. {
  127. HRESULT hr;
  128. WCHAR * sClass=NULL; // class
  129. WCHAR * sName=NULL; // name
  130. VARIANT var;
  131. WCHAR numBuffer [3 * sizeof(int)]; // hold result of _itow
  132. VariantInit(&var);
  133. // See if desired class
  134. hr = pObject->get_Class(&sClass);
  135. if ( FAILED(hr) )
  136. { /* so what */ }
  137. else
  138. {
  139. if ( !wcsicmp(L"TrustedDomain", sClass) )
  140. {
  141. hr = pObject->get_Name(&sName);
  142. if ( FAILED(hr) )
  143. { /* so what */ }
  144. else
  145. {
  146. _bstr_t name;
  147. if( wcsncmp(sName, L"CN=", 3) == 0 )
  148. {
  149. name = (sName + 3); // chop off the leading "CN="
  150. }
  151. else
  152. {
  153. name = sName;
  154. }
  155. _bstr_t server(L"Server");
  156. SysFreeString(sName);
  157. server += _itow(counter++, numBuffer, 10); // the servers are named "Server1", "Server2", etc
  158. hr = pIVarset->put(server + L".Name", name);
  159. hr = pObject->Get(L"TrustDirection", &var);
  160. if ( SUCCEEDED(hr) )
  161. {
  162. hr = pIVarset->put( server + L".TrustDirection", var );
  163. }
  164. VariantClear(&var);
  165. hr = pObject->Get(L"TrustType", &var);
  166. if ( SUCCEEDED(hr) )
  167. {
  168. hr = pIVarset->put( server + L".TrustType", var);
  169. }
  170. VariantClear(&var);
  171. hr = pObject->Get(L"TrustAttributes", &var);
  172. if ( SUCCEEDED(hr) )
  173. {
  174. hr = pIVarset->put( server + L".TrustAttributes", var);
  175. }
  176. VariantClear(&var);
  177. }
  178. }
  179. SysFreeString(sClass);
  180. }
  181. }
  182. BOOL IsDownLevel(WCHAR * sComputer)
  183. {
  184. BOOL bDownlevel = TRUE;
  185. WKSTA_INFO_100 * pInfo;
  186. long rc = NetWkstaGetInfo(sComputer,100,(LPBYTE*)&pInfo);
  187. if ( ! rc )
  188. {
  189. if ( pInfo->wki100_ver_major >= 5 )
  190. {
  191. bDownlevel = FALSE;
  192. }
  193. NetApiBufferFree(pInfo);
  194. }
  195. return bDownlevel;
  196. }
  197. STDMETHODIMP CTrustEnumerator::createTrust(/*[in]*/ BSTR trustingDomain,/*[in]*/ BSTR trustedDomain)
  198. {
  199. HRESULT hr = S_OK;
  200. WCHAR trustingComp[MAX_PATH];
  201. WCHAR trustedComp[MAX_PATH];
  202. WCHAR trustingDNSName[MAX_PATH];
  203. WCHAR trustedDNSName[MAX_PATH];
  204. WCHAR name[LEN_Domain];
  205. DWORD lenName = DIM(name);
  206. BYTE trustingSid[200];
  207. BYTE trustedSid[200];
  208. DWORD lenSid = DIM(trustingSid);
  209. SID_NAME_USE snu;
  210. DOMAIN_CONTROLLER_INFO * pInfo;
  211. DWORD rc = 0;
  212. LSA_HANDLE hTrusting = NULL;
  213. LSA_HANDLE hTrusted = NULL;
  214. NTSTATUS status;
  215. LSA_AUTH_INFORMATION curr;
  216. LSA_AUTH_INFORMATION prev;
  217. WCHAR password[] = L"password";
  218. rc = DsGetDcName(NULL, trustingDomain, NULL, NULL, DS_PDC_REQUIRED, &pInfo);
  219. if ( !rc )
  220. {
  221. wcscpy(trustingComp,pInfo->DomainControllerName);
  222. wcscpy(trustingDNSName,pInfo->DomainName);
  223. NetApiBufferFree(pInfo);
  224. }
  225. rc = DsGetDcName(NULL, trustedDomain, NULL, NULL, DS_PDC_REQUIRED, &pInfo);
  226. if ( !rc )
  227. {
  228. wcscpy(trustedComp,pInfo->DomainControllerName);
  229. wcscpy(trustedDNSName,pInfo->DomainName);
  230. NetApiBufferFree(pInfo);
  231. }
  232. // Need to get the computer name and the SIDs for the domains.
  233. if ( ! LookupAccountName(trustingComp,trustingDomain,trustingSid,&lenSid,name,&lenName,&snu) )
  234. {
  235. rc = GetLastError();
  236. return 1;
  237. }
  238. lenSid = DIM(trustedSid);
  239. lenName = DIM(name);
  240. if (! LookupAccountName(trustedComp,trustedDomain,trustedSid,&lenSid,name,&lenName,&snu) )
  241. {
  242. rc = GetLastError();
  243. return 1;
  244. }
  245. // open an LSA handle to each domain
  246. if ( *trustingComp && *trustedComp )
  247. {
  248. status = OpenPolicy(trustedComp,POLICY_VIEW_LOCAL_INFORMATION | POLICY_TRUST_ADMIN | POLICY_CREATE_SECRET,&hTrusted);
  249. rc = LsaNtStatusToWinError(rc);
  250. if ( ! rc )
  251. {
  252. // set up the auth information for the trust relationship
  253. curr.AuthInfo = (LPBYTE)password;
  254. curr.AuthInfoLength = sizeof (password);
  255. curr.AuthType = TRUST_AUTH_TYPE_CLEAR;
  256. curr.LastUpdateTime.QuadPart = 0;
  257. prev.AuthInfo = NULL;
  258. prev.AuthInfoLength = 0;
  259. prev.AuthType = TRUST_AUTH_TYPE_CLEAR;
  260. prev.LastUpdateTime.QuadPart = 0;
  261. // set up the trusted side of the relationship
  262. if ( IsDownLevel(trustedComp) )
  263. {
  264. // create an inter-domain trust account for the trusting domain on the trusted domain
  265. USER_INFO_1 uInfo;
  266. DWORD parmErr;
  267. memset(&uInfo,0,(sizeof uInfo));
  268. UStrCpy(name,trustingDomain);
  269. name[UStrLen(name) + 1] = 0;
  270. name[UStrLen(name)] = L'$';
  271. uInfo.usri1_flags = UF_SCRIPT | UF_INTERDOMAIN_TRUST_ACCOUNT;
  272. uInfo.usri1_name = name;
  273. uInfo.usri1_password = password;
  274. uInfo.usri1_priv = 1;
  275. rc = NetUserAdd(trustedComp,1,(LPBYTE)&uInfo,&parmErr);
  276. }
  277. else
  278. {
  279. // Create the trustedDomain object.
  280. LSA_UNICODE_STRING sTemp;
  281. TRUSTED_DOMAIN_INFORMATION_EX trustedInfo;
  282. TRUSTED_DOMAIN_AUTH_INFORMATION trustAuth;
  283. InitLsaString(&sTemp, const_cast<WCHAR*>(trustingDomain));
  284. trustedInfo.FlatName = sTemp;
  285. InitLsaString(&sTemp, trustingDNSName);
  286. trustedInfo.Name = sTemp;
  287. trustedInfo.Sid = trustingSid;
  288. if ( IsDownLevel(trustingComp) )
  289. {
  290. trustedInfo.TrustAttributes = TRUST_TYPE_DOWNLEVEL;
  291. }
  292. else
  293. {
  294. trustedInfo.TrustAttributes = TRUST_TYPE_UPLEVEL;
  295. }
  296. trustedInfo.TrustDirection = TRUST_DIRECTION_INBOUND;
  297. trustedInfo.TrustType = TRUST_ATTRIBUTE_NON_TRANSITIVE;
  298. trustAuth.IncomingAuthInfos = 1;
  299. trustAuth.OutgoingAuthInfos = 0;
  300. trustAuth.OutgoingAuthenticationInformation = NULL;
  301. trustAuth.OutgoingPreviousAuthenticationInformation = NULL;
  302. trustAuth.IncomingAuthenticationInformation = &curr;
  303. trustAuth.IncomingPreviousAuthenticationInformation = &prev;
  304. status = LsaCreateTrustedDomainEx( hTrusted, &trustedInfo, &trustAuth, POLICY_VIEW_LOCAL_INFORMATION |
  305. POLICY_TRUST_ADMIN | POLICY_CREATE_SECRET, &hTrusting );
  306. rc = LsaNtStatusToWinError(status);
  307. if ( ! rc )
  308. {
  309. LsaClose(hTrusting);
  310. hTrusting = NULL;
  311. }
  312. }
  313. status = OpenPolicy(trustingComp,POLICY_VIEW_LOCAL_INFORMATION
  314. | POLICY_TRUST_ADMIN | POLICY_CREATE_SECRET,&hTrusting);
  315. rc = LsaNtStatusToWinError(rc);
  316. // set up the trusting side of the relationship
  317. if ( IsDownLevel(trustingComp) )
  318. {
  319. TRUSTED_DOMAIN_NAME_INFO nameInfo;
  320. InitLsaString(&nameInfo.Name,const_cast<WCHAR*>(trustedDomain));
  321. status = LsaSetTrustedDomainInformation(hTrusting,trustedSid,TrustedDomainNameInformation,&nameInfo);
  322. rc = LsaNtStatusToWinError(status);
  323. if ( ! rc )
  324. {
  325. // set the password for the new trust
  326. TRUSTED_PASSWORD_INFO pwdInfo;
  327. InitLsaString(&pwdInfo.Password,password);
  328. InitLsaString(&pwdInfo.OldPassword,NULL);
  329. status = LsaSetTrustedDomainInformation(hTrusting,trustedSid,TrustedPasswordInformation,&pwdInfo);
  330. rc = LsaNtStatusToWinError(status);
  331. }
  332. }
  333. else
  334. {
  335. // for Win2K domain, use LsaCreateTrustedDomainEx
  336. // to create the trustedDomain object.
  337. LSA_UNICODE_STRING sTemp;
  338. TRUSTED_DOMAIN_INFORMATION_EX trustedInfo;
  339. TRUSTED_DOMAIN_AUTH_INFORMATION trustAuth;
  340. InitLsaString(&sTemp, const_cast<WCHAR*>(trustedDomain));
  341. trustedInfo.FlatName = sTemp;
  342. InitLsaString(&sTemp, trustedDNSName);
  343. trustedInfo.Name = sTemp;
  344. trustedInfo.Sid = trustedSid;
  345. if ( IsDownLevel(trustedComp) )
  346. {
  347. trustedInfo.TrustAttributes = TRUST_TYPE_DOWNLEVEL;
  348. }
  349. else
  350. {
  351. trustedInfo.TrustAttributes = TRUST_TYPE_UPLEVEL;
  352. }
  353. trustedInfo.TrustDirection = TRUST_DIRECTION_OUTBOUND;
  354. trustedInfo.TrustType = TRUST_ATTRIBUTE_NON_TRANSITIVE;
  355. trustAuth.IncomingAuthInfos = 0;
  356. trustAuth.OutgoingAuthInfos = 1;
  357. trustAuth.IncomingAuthenticationInformation = NULL;
  358. trustAuth.IncomingPreviousAuthenticationInformation = NULL;
  359. trustAuth.OutgoingAuthenticationInformation = &curr;
  360. trustAuth.OutgoingPreviousAuthenticationInformation = &prev;
  361. LSA_HANDLE hTemp;
  362. status = LsaCreateTrustedDomainEx( hTrusting, &trustedInfo, &trustAuth, 0, &hTemp );
  363. rc = LsaNtStatusToWinError(status);
  364. if( ! rc )
  365. {
  366. LsaClose(hTemp);
  367. }
  368. }
  369. }
  370. }
  371. if ( hTrusting )
  372. LsaClose(hTrusting);
  373. if( hTrusted )
  374. LsaClose(hTrusted);
  375. return HRESULT_FROM_WIN32(rc);
  376. }
  377. // externally visible method
  378. STDMETHODIMP CTrustEnumerator::
  379. getTrustRelations
  380. (
  381. BSTR server ,// in, The name of the server from which to get trust information
  382. IUnknown ** enumeration // [out, retval] a pointer to an IVarSet containing the enumerated machines
  383. )
  384. {
  385. *enumeration = NULL;
  386. IVarSetPtr pIVarset = NULL;
  387. CComPtr<IADsContainer> pIContainer = NULL;
  388. HRESULT hr;
  389. IEnumVARIANT * pIContents=NULL;
  390. hr = getIADContainer(server, &pIContainer);
  391. if( FAILED(hr) ) { return hr; }
  392. hr = pIVarset.CreateInstance(__uuidof(VarSet));
  393. if( FAILED(hr) ) { return hr; }
  394. AdItemHandler handler(pIVarset);
  395. hr = ADsBuildEnumerator(pIContainer, &pIContents);
  396. if( FAILED(hr) ) { return hr; }
  397. DWORD nRead=0; // number enumerated items returned
  398. do
  399. {
  400. VARIANT arrayEnumItems[MAX_ELEM]; // array of enumerated items
  401. VARIANT * pEnumItem; // enumerated item
  402. nRead = 0;
  403. memset(arrayEnumItems, 0, sizeof arrayEnumItems);
  404. hr = ADsEnumerateNext(
  405. pIContents,
  406. DIM(arrayEnumItems),
  407. arrayEnumItems,
  408. &nRead );
  409. if( FAILED(hr) ) { return hr; }
  410. const VARIANT *pEndOfItems = arrayEnumItems + nRead;
  411. for ( pEnumItem = arrayEnumItems;
  412. pEnumItem < pEndOfItems;
  413. pEnumItem++ )
  414. {
  415. CComPtr<IDispatch> pDispatch (pEnumItem->pdispVal);
  416. CComPtr<IADs> pObject;
  417. hr = pDispatch->QueryInterface(
  418. IID_IADs,
  419. (void **) &pObject );
  420. if ( FAILED(hr) )
  421. {
  422. clearVariantArray(arrayEnumItems, nRead);
  423. return hr;
  424. }
  425. if ( SUCCEEDED(hr) )
  426. handler.insertItem(pObject); // insert the item into the varset if necessary
  427. }
  428. clearVariantArray(arrayEnumItems, nRead);
  429. } while ( nRead );
  430. if ( pIContents )
  431. {
  432. ADsFreeEnumerator( pIContents );
  433. // pIContents->Release();
  434. pIContents = NULL;
  435. }
  436. hr = pIVarset->QueryInterface(__uuidof(IUnknown), reinterpret_cast<void**>(enumeration));
  437. return hr;
  438. }