Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1037 lines
20 KiB

  1. #include "StdAfx.h"
  2. #include "ADMTScript.h"
  3. #include "MigrationBase.h"
  4. #include <LM.h>
  5. #include "Error.h"
  6. #include "VarSetAccounts.h"
  7. #include "VarSetServers.h"
  8. #include "FixHierarchy.h"
  9. using namespace _com_util;
  10. namespace MigrationBase
  11. {
  12. void GetNamesFromData(VARIANT& vntData, StringSet& setNames);
  13. void GetNamesFromVariant(VARIANT* pvnt, StringSet& setNames);
  14. void GetNamesFromString(BSTR bstr, StringSet& setNames);
  15. void GetNamesFromStringArray(SAFEARRAY* psa, StringSet& setNames);
  16. void GetNamesFromVariantArray(SAFEARRAY* psa, StringSet& setNames);
  17. void GetNamesFromFile(VARIANT& vntData, StringSet& setNames);
  18. void GetNamesFromFile(LPCTSTR pszFileName, StringSet& setNames);
  19. void GetNamesFromStringA(LPCSTR pchString, DWORD cchString, StringSet& setNames);
  20. void GetNamesFromStringW(LPCWSTR pchString, DWORD cchString, StringSet& setNames);
  21. _bstr_t RemoveTrailingDollarSign(LPCTSTR pszName);
  22. }
  23. using namespace MigrationBase;
  24. //---------------------------------------------------------------------------
  25. // MigrationBase Class
  26. //---------------------------------------------------------------------------
  27. // Constructor
  28. CMigrationBase::CMigrationBase() :
  29. m_nRecurseMaintain(0),
  30. m_Mutex(ADMT_MUTEX)
  31. {
  32. }
  33. // Destructor
  34. CMigrationBase::~CMigrationBase()
  35. {
  36. }
  37. // InitSourceDomainAndContainer Method
  38. void CMigrationBase::InitSourceDomainAndContainer()
  39. {
  40. m_SourceDomain.Initialize(m_spInternal->SourceDomain);
  41. m_SourceContainer = m_SourceDomain.GetContainer(m_spInternal->SourceOu);
  42. }
  43. // InitTargetDomainAndContainer Method
  44. void CMigrationBase::InitTargetDomainAndContainer()
  45. {
  46. m_TargetDomain.Initialize(m_spInternal->TargetDomain);
  47. m_TargetContainer = m_TargetDomain.GetContainer(m_spInternal->TargetOu);
  48. // verify target domain is in native mode
  49. if (m_TargetDomain.NativeMode() == false)
  50. {
  51. AdmtThrowError(
  52. GUID_NULL, GUID_NULL,
  53. E_INVALIDARG, IDS_E_TARGET_DOMAIN_NOT_NATIVE_MODE,
  54. (LPCTSTR)m_TargetDomain.Name()
  55. );
  56. }
  57. VerifyTargetContainerPathLength();
  58. }
  59. // VerifyInterIntraForest Method
  60. void CMigrationBase::VerifyInterIntraForest()
  61. {
  62. // if the source and target domains have the same forest name then they are intra-forest
  63. bool bIntraForest = m_spInternal->IntraForest ? true : false;
  64. if (m_SourceDomain.ForestName() == m_TargetDomain.ForestName())
  65. {
  66. // intra-forest must be set to true to match the domains
  67. if (!bIntraForest)
  68. {
  69. AdmtThrowError(
  70. GUID_NULL, GUID_NULL,
  71. E_INVALIDARG, IDS_E_NOT_INTER_FOREST,
  72. (LPCTSTR)m_SourceDomain.Name(), (LPCTSTR)m_TargetDomain.Name()
  73. );
  74. }
  75. }
  76. else
  77. {
  78. // intra-forest must be set to false to match the domains
  79. if (bIntraForest)
  80. {
  81. AdmtThrowError(
  82. GUID_NULL, GUID_NULL,
  83. E_INVALIDARG, IDS_E_NOT_INTRA_FOREST,
  84. (LPCTSTR)m_SourceDomain.Name(), (LPCTSTR)m_TargetDomain.Name()
  85. );
  86. }
  87. }
  88. }
  89. // DoOption Method
  90. void CMigrationBase::DoOption(long lOptions, VARIANT& vntInclude, VARIANT& vntExclude)
  91. {
  92. m_setIncludeNames.clear();
  93. m_setExcludeNames.clear();
  94. InitRecurseMaintainOption(lOptions);
  95. GetExcludeNames(vntExclude, m_setExcludeNames);
  96. switch (lOptions & 0xFF)
  97. {
  98. case admtNone:
  99. {
  100. DoNone();
  101. break;
  102. }
  103. case admtData:
  104. {
  105. GetNamesFromData(vntInclude, m_setIncludeNames);
  106. DoNames();
  107. break;
  108. }
  109. case admtFile:
  110. {
  111. GetNamesFromFile(vntInclude, m_setIncludeNames);
  112. DoNames();
  113. break;
  114. }
  115. case admtDomain:
  116. {
  117. m_setIncludeNames.clear();
  118. DoDomain();
  119. break;
  120. }
  121. default:
  122. {
  123. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_INVALID_OPTION);
  124. break;
  125. }
  126. }
  127. }
  128. // DoNone Method
  129. void CMigrationBase::DoNone()
  130. {
  131. }
  132. // DoNames Method
  133. void CMigrationBase::DoNames()
  134. {
  135. }
  136. // DoDomain Method
  137. void CMigrationBase::DoDomain()
  138. {
  139. }
  140. // InitRecurseMaintainOption Method
  141. void CMigrationBase::InitRecurseMaintainOption(long lOptions)
  142. {
  143. switch (lOptions & 0xFF)
  144. {
  145. case admtData:
  146. case admtFile:
  147. {
  148. if (lOptions & 0xFF00)
  149. {
  150. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_DATA_OPTION_FLAGS_NOT_ALLOWED);
  151. }
  152. m_nRecurseMaintain = 0;
  153. break;
  154. }
  155. case admtDomain:
  156. {
  157. m_nRecurseMaintain = 0;
  158. if (lOptions & admtRecurse)
  159. {
  160. ++m_nRecurseMaintain;
  161. if (lOptions & admtMaintainHierarchy)
  162. {
  163. ++m_nRecurseMaintain;
  164. }
  165. }
  166. break;
  167. }
  168. default:
  169. {
  170. m_nRecurseMaintain = 0;
  171. break;
  172. }
  173. }
  174. }
  175. // GetExcludeNames Method
  176. void CMigrationBase::GetExcludeNames(VARIANT& vntExclude, StringSet& setExcludeNames)
  177. {
  178. try
  179. {
  180. switch (V_VT(&vntExclude))
  181. {
  182. case VT_EMPTY:
  183. case VT_ERROR:
  184. {
  185. setExcludeNames.clear();
  186. break;
  187. }
  188. case VT_BSTR:
  189. {
  190. GetNamesFromFile(V_BSTR(&vntExclude), setExcludeNames);
  191. break;
  192. }
  193. case VT_BSTR|VT_BYREF:
  194. {
  195. BSTR* pbstr = V_BSTRREF(&vntExclude);
  196. if (pbstr)
  197. {
  198. GetNamesFromFile(*pbstr, setExcludeNames);
  199. }
  200. break;
  201. }
  202. case VT_BSTR|VT_ARRAY:
  203. {
  204. GetNamesFromStringArray(V_ARRAY(&vntExclude), setExcludeNames);
  205. break;
  206. }
  207. case VT_BSTR|VT_ARRAY|VT_BYREF:
  208. {
  209. SAFEARRAY** ppsa = V_ARRAYREF(&vntExclude);
  210. if (ppsa)
  211. {
  212. GetNamesFromStringArray(*ppsa, setExcludeNames);
  213. }
  214. break;
  215. }
  216. case VT_VARIANT|VT_BYREF:
  217. {
  218. VARIANT* pvnt = V_VARIANTREF(&vntExclude);
  219. if (pvnt)
  220. {
  221. GetExcludeNames(*pvnt, setExcludeNames);
  222. }
  223. break;
  224. }
  225. case VT_VARIANT|VT_ARRAY:
  226. {
  227. GetNamesFromVariantArray(V_ARRAY(&vntExclude), setExcludeNames);
  228. break;
  229. }
  230. case VT_VARIANT|VT_ARRAY|VT_BYREF:
  231. {
  232. SAFEARRAY** ppsa = V_ARRAYREF(&vntExclude);
  233. if (ppsa)
  234. {
  235. GetNamesFromVariantArray(*ppsa, setExcludeNames);
  236. }
  237. break;
  238. }
  239. default:
  240. {
  241. _com_issue_error(E_INVALIDARG);
  242. break;
  243. }
  244. }
  245. }
  246. catch (_com_error& ce)
  247. {
  248. AdmtThrowError(GUID_NULL, GUID_NULL, ce.Error(), IDS_E_INVALID_EXCLUDE_DATA_TYPE);
  249. }
  250. catch (...)
  251. {
  252. AdmtThrowError(GUID_NULL, GUID_NULL, E_FAIL, IDS_E_INVALID_EXCLUDE_DATA_TYPE);
  253. }
  254. }
  255. // FillInVarSetForUsers Method
  256. void CMigrationBase::FillInVarSetForUsers(CDomainAccounts& rUsers, CVarSet& rVarSet)
  257. {
  258. CVarSetAccounts aAccounts(rVarSet);
  259. for (CDomainAccounts::iterator it = rUsers.begin(); it != rUsers.end(); it++)
  260. {
  261. aAccounts.AddAccount(_T("User"), it->GetADsPath(), it->GetName(), it->GetUserPrincipalName());
  262. }
  263. }
  264. // FillInVarSetForGroups Method
  265. void CMigrationBase::FillInVarSetForGroups(CDomainAccounts& rGroups, CVarSet& rVarSet)
  266. {
  267. CVarSetAccounts aAccounts(rVarSet);
  268. for (CDomainAccounts::iterator it = rGroups.begin(); it != rGroups.end(); it++)
  269. {
  270. aAccounts.AddAccount(_T("Group"), it->GetADsPath(), it->GetName());
  271. }
  272. }
  273. // FillInVarSetForComputers Method
  274. void CMigrationBase::FillInVarSetForComputers(CDomainAccounts& rComputers, bool bMigrateOnly, bool bMoveToTarget, bool bReboot, long lRebootDelay, CVarSet& rVarSet)
  275. {
  276. CVarSetAccounts aAccounts(rVarSet);
  277. CVarSetServers aServers(rVarSet);
  278. for (CDomainAccounts::iterator it = rComputers.begin(); it != rComputers.end(); it++)
  279. {
  280. // remove trailing '$'
  281. // ADMT doesn't accept true SAM account name
  282. _bstr_t strName = RemoveTrailingDollarSign(it->GetSamAccountName());
  283. aAccounts.AddAccount(_T("Computer"), strName);
  284. aServers.AddServer(strName, bMigrateOnly, bMoveToTarget, bReboot, lRebootDelay);
  285. }
  286. }
  287. // VerifyRenameConflictPrefixSuffixValid Method
  288. void CMigrationBase::VerifyRenameConflictPrefixSuffixValid()
  289. {
  290. int nTotalPrefixSuffixLength = 0;
  291. long lRenameOption = m_spInternal->RenameOption;
  292. if ((lRenameOption == admtRenameWithPrefix) || (lRenameOption == admtRenameWithSuffix))
  293. {
  294. _bstr_t strPrefixSuffix = m_spInternal->RenamePrefixOrSuffix;
  295. nTotalPrefixSuffixLength += strPrefixSuffix.length();
  296. }
  297. long lConflictOption = m_spInternal->ConflictOptions & 0x0F;
  298. if ((lConflictOption == admtRenameConflictingWithSuffix) || (lConflictOption == admtRenameConflictingWithPrefix))
  299. {
  300. _bstr_t strPrefixSuffix = m_spInternal->ConflictPrefixOrSuffix;
  301. nTotalPrefixSuffixLength += strPrefixSuffix.length();
  302. }
  303. if (nTotalPrefixSuffixLength > MAXIMUM_PREFIX_SUFFIX_LENGTH)
  304. {
  305. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_PREFIX_SUFFIX_TOO_LONG, MAXIMUM_PREFIX_SUFFIX_LENGTH);
  306. }
  307. }
  308. // VerifyCanAddSidHistory Method
  309. void CMigrationBase::VerifyCanAddSidHistory()
  310. {
  311. #define F_WORKS 0x00000000
  312. #define F_WRONGOS 0x00000001
  313. #define F_NO_REG_KEY 0x00000002
  314. #define F_NO_AUDITING_SOURCE 0x00000004
  315. #define F_NO_AUDITING_TARGET 0x00000008
  316. #define F_NO_LOCAL_GROUP 0x00000010
  317. try
  318. {
  319. long lErrorFlags = 0;
  320. IAccessCheckerPtr spAccessChecker(__uuidof(AccessChecker));
  321. spAccessChecker->CanUseAddSidHistory(m_SourceDomain.Name(), m_TargetDomain.Name(), &lErrorFlags);
  322. if (lErrorFlags != 0)
  323. {
  324. _bstr_t strError;
  325. CComBSTR str;
  326. if (lErrorFlags & F_NO_AUDITING_SOURCE)
  327. {
  328. str.LoadString(IDS_E_NO_AUDITING_SOURCE);
  329. strError += str.operator BSTR();
  330. }
  331. if (lErrorFlags & F_NO_AUDITING_TARGET)
  332. {
  333. str.LoadString(IDS_E_NO_AUDITING_TARGET);
  334. strError += str.operator BSTR();
  335. }
  336. if (lErrorFlags & F_NO_LOCAL_GROUP)
  337. {
  338. str.LoadString(IDS_E_NO_SID_HISTORY_LOCAL_GROUP);
  339. strError += str.operator BSTR();
  340. }
  341. if (lErrorFlags & F_NO_REG_KEY)
  342. {
  343. str.LoadString(IDS_E_NO_SID_HISTORY_REGISTRY_ENTRY);
  344. strError += str.operator BSTR();
  345. }
  346. AdmtThrowError(GUID_NULL, GUID_NULL, E_FAIL, IDS_E_SID_HISTORY_CONFIGURATION, (LPCTSTR)strError);
  347. }
  348. }
  349. catch (_com_error& ce)
  350. {
  351. AdmtThrowError(GUID_NULL, GUID_NULL, ce, IDS_E_CAN_ADD_SID_HISTORY);
  352. }
  353. catch (...)
  354. {
  355. AdmtThrowError(GUID_NULL, GUID_NULL, E_FAIL, IDS_E_CAN_ADD_SID_HISTORY);
  356. }
  357. }
  358. // VerifyTargetContainerPathLength Method
  359. void CMigrationBase::VerifyTargetContainerPathLength()
  360. {
  361. _bstr_t strPath = GetTargetContainer().GetPath();
  362. if (strPath.length() > 999)
  363. {
  364. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_TARGET_CONTAINER_PATH_TOO_LONG);
  365. }
  366. }
  367. // VerifyPasswordServer Method
  368. void CMigrationBase::VerifyPasswordOption()
  369. {
  370. if (m_spInternal->PasswordOption == admtCopyPassword)
  371. {
  372. _bstr_t strServer = m_spInternal->PasswordServer;
  373. // a password server must be specified for copy password option
  374. if (strServer.length() == 0)
  375. {
  376. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_PASSWORD_DC_NOT_SPECIFIED);
  377. }
  378. //
  379. // verify that password server exists and is a domain controller
  380. //
  381. _bstr_t strPrefixedServer;
  382. if (_tcsncmp(strServer, _T("\\\\"), 2) == 0)
  383. {
  384. strPrefixedServer = strServer;
  385. }
  386. else
  387. {
  388. strPrefixedServer = _T("\\\\") + strServer;
  389. }
  390. PSERVER_INFO_101 psiInfo;
  391. NET_API_STATUS nasStatus = NetServerGetInfo(strPrefixedServer, 101, (LPBYTE*)&psiInfo);
  392. if (nasStatus != NERR_Success)
  393. {
  394. AdmtThrowError(GUID_NULL, GUID_NULL, HRESULT_FROM_WIN32(nasStatus), IDS_E_PASSWORD_DC_NOT_FOUND, (LPCTSTR)strServer);
  395. }
  396. UINT uMsgId = 0;
  397. if (psiInfo->sv101_platform_id != PLATFORM_ID_NT)
  398. {
  399. uMsgId = IDS_E_PASSWORD_DC_NOT_NT;
  400. }
  401. else if (!(psiInfo->sv101_type & SV_TYPE_DOMAIN_CTRL) && !(psiInfo->sv101_type & SV_TYPE_DOMAIN_BAKCTRL))
  402. {
  403. uMsgId = IDS_E_PASSWORD_DC_NOT_DC;
  404. }
  405. NetApiBufferFree(psiInfo);
  406. if (uMsgId)
  407. {
  408. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, uMsgId, (LPCTSTR)strServer);
  409. }
  410. //
  411. // verify that password server is configured properly
  412. //
  413. IPasswordMigrationPtr spPasswordMigration(__uuidof(PasswordMigration));
  414. spPasswordMigration->EstablishSession(strServer, m_TargetDomain.DomainControllerName());
  415. }
  416. }
  417. // PerformMigration Method
  418. void CMigrationBase::PerformMigration(CVarSet& rVarSet)
  419. {
  420. IPerformMigrationTaskPtr spMigrator(__uuidof(Migrator));
  421. try
  422. {
  423. spMigrator->PerformMigrationTask(IUnknownPtr(rVarSet.GetInterface()), 0);
  424. }
  425. catch (_com_error& ce)
  426. {
  427. if (ce.Error() == MIGRATOR_E_PROCESSES_STILL_RUNNING)
  428. {
  429. AdmtThrowError(GUID_NULL, GUID_NULL, ce.Error(), IDS_E_ADMT_PROCESS_RUNNING);
  430. }
  431. else
  432. {
  433. throw;
  434. }
  435. }
  436. }
  437. // FixObjectsInHierarchy Method
  438. void CMigrationBase::FixObjectsInHierarchy(LPCTSTR pszType)
  439. {
  440. CFixObjectsInHierarchy fix;
  441. fix.SetObjectType(pszType);
  442. long lOptions = m_spInternal->ConflictOptions;
  443. long lOption = lOptions & 0x0F;
  444. long lFlags = lOptions & 0xF0;
  445. fix.SetFixReplaced((lOption == admtReplaceConflicting) && (lFlags & admtMoveReplacedAccounts));
  446. fix.SetSourceContainerPath(m_SourceContainer.GetPath());
  447. fix.SetTargetContainerPath(m_TargetContainer.GetPath());
  448. fix.FixObjects();
  449. }
  450. //---------------------------------------------------------------------------
  451. namespace MigrationBase
  452. {
  453. // GetNamesFromData Method
  454. void GetNamesFromData(VARIANT& vntData, StringSet& setNames)
  455. {
  456. try
  457. {
  458. GetNamesFromVariant(&vntData, setNames);
  459. }
  460. catch (_com_error& ce)
  461. {
  462. AdmtThrowError(GUID_NULL, GUID_NULL, ce.Error(), IDS_E_INVALID_DATA_OPTION_DATA_TYPE);
  463. }
  464. catch (...)
  465. {
  466. AdmtThrowError(GUID_NULL, GUID_NULL, E_FAIL, IDS_E_INVALID_DATA_OPTION_DATA_TYPE);
  467. }
  468. }
  469. // GetNamesFromVariant Method
  470. void GetNamesFromVariant(VARIANT* pvntData, StringSet& setNames)
  471. {
  472. switch (V_VT(pvntData))
  473. {
  474. case VT_BSTR:
  475. {
  476. GetNamesFromString(V_BSTR(pvntData), setNames);
  477. break;
  478. }
  479. case VT_BSTR|VT_BYREF:
  480. {
  481. BSTR* pbstr = V_BSTRREF(pvntData);
  482. if (pbstr)
  483. {
  484. GetNamesFromString(*pbstr, setNames);
  485. }
  486. break;
  487. }
  488. case VT_BSTR|VT_ARRAY:
  489. {
  490. GetNamesFromStringArray(V_ARRAY(pvntData), setNames);
  491. break;
  492. }
  493. case VT_BSTR|VT_ARRAY|VT_BYREF:
  494. {
  495. SAFEARRAY** ppsa = V_ARRAYREF(pvntData);
  496. if (ppsa)
  497. {
  498. GetNamesFromStringArray(*ppsa, setNames);
  499. }
  500. break;
  501. }
  502. case VT_VARIANT|VT_BYREF:
  503. {
  504. VARIANT* pvnt = V_VARIANTREF(pvntData);
  505. if (pvnt)
  506. {
  507. GetNamesFromVariant(pvnt, setNames);
  508. }
  509. break;
  510. }
  511. case VT_VARIANT|VT_ARRAY:
  512. {
  513. GetNamesFromVariantArray(V_ARRAY(pvntData), setNames);
  514. break;
  515. }
  516. case VT_VARIANT|VT_ARRAY|VT_BYREF:
  517. {
  518. SAFEARRAY** ppsa = V_ARRAYREF(pvntData);
  519. if (ppsa)
  520. {
  521. GetNamesFromVariantArray(*ppsa, setNames);
  522. }
  523. break;
  524. }
  525. case VT_EMPTY:
  526. {
  527. // ignore empty variants
  528. break;
  529. }
  530. default:
  531. {
  532. _com_issue_error(E_INVALIDARG);
  533. break;
  534. }
  535. }
  536. }
  537. // GetNamesFromString Method
  538. void GetNamesFromString(BSTR bstr, StringSet& setNames)
  539. {
  540. if (bstr)
  541. {
  542. UINT cch = SysStringLen(bstr);
  543. if (cch > 0)
  544. {
  545. GetNamesFromStringW(bstr, cch, setNames);
  546. }
  547. }
  548. }
  549. // GetNamesFromStringArray Method
  550. void GetNamesFromStringArray(SAFEARRAY* psa, StringSet& setNames)
  551. {
  552. BSTR* pbstr;
  553. HRESULT hr = SafeArrayAccessData(psa, (void**)&pbstr);
  554. if (SUCCEEDED(hr))
  555. {
  556. try
  557. {
  558. UINT uDimensionCount = psa->cDims;
  559. for (UINT uDimension = 0; uDimension < uDimensionCount; uDimension++)
  560. {
  561. UINT uElementCount = psa->rgsabound[uDimension].cElements;
  562. for (UINT uElement = 0; uElement < uElementCount; uElement++)
  563. {
  564. setNames.insert(_bstr_t(*pbstr++));
  565. }
  566. }
  567. SafeArrayUnaccessData(psa);
  568. }
  569. catch (...)
  570. {
  571. SafeArrayUnaccessData(psa);
  572. throw;
  573. }
  574. }
  575. }
  576. // GetNamesFromVariantArray Method
  577. void GetNamesFromVariantArray(SAFEARRAY* psa, StringSet& setNames)
  578. {
  579. VARIANT* pvnt;
  580. HRESULT hr = SafeArrayAccessData(psa, (void**)&pvnt);
  581. if (SUCCEEDED(hr))
  582. {
  583. try
  584. {
  585. UINT uDimensionCount = psa->cDims;
  586. for (UINT uDimension = 0; uDimension < uDimensionCount; uDimension++)
  587. {
  588. UINT uElementCount = psa->rgsabound[uDimension].cElements;
  589. for (UINT uElement = 0; uElement < uElementCount; uElement++)
  590. {
  591. GetNamesFromVariant(pvnt++, setNames);
  592. }
  593. }
  594. SafeArrayUnaccessData(psa);
  595. }
  596. catch (...)
  597. {
  598. SafeArrayUnaccessData(psa);
  599. throw;
  600. }
  601. }
  602. }
  603. // GetNamesFromFile Method
  604. //
  605. // - the maximum file size this implementation can handle is 4,294,967,295 bytes
  606. void GetNamesFromFile(VARIANT& vntData, StringSet& setNames)
  607. {
  608. bool bInvalidArg = false;
  609. switch (V_VT(&vntData))
  610. {
  611. case VT_BSTR:
  612. {
  613. BSTR bstr = V_BSTR(&vntData);
  614. if (bstr)
  615. {
  616. GetNamesFromFile(bstr, setNames);
  617. }
  618. else
  619. {
  620. bInvalidArg = true;
  621. }
  622. break;
  623. }
  624. case VT_BSTR|VT_BYREF:
  625. {
  626. BSTR* pbstr = V_BSTRREF(&vntData);
  627. if (pbstr && *pbstr)
  628. {
  629. GetNamesFromFile(*pbstr, setNames);
  630. }
  631. else
  632. {
  633. bInvalidArg = true;
  634. }
  635. break;
  636. }
  637. case VT_VARIANT|VT_BYREF:
  638. {
  639. VARIANT* pvnt = V_VARIANTREF(&vntData);
  640. if (pvnt)
  641. {
  642. GetNamesFromFile(*pvnt, setNames);
  643. }
  644. else
  645. {
  646. bInvalidArg = true;
  647. }
  648. break;
  649. }
  650. default:
  651. {
  652. bInvalidArg = true;
  653. break;
  654. }
  655. }
  656. if (bInvalidArg)
  657. {
  658. AdmtThrowError(GUID_NULL, GUID_NULL, E_INVALIDARG, IDS_E_INVALID_FILE_OPTION_DATA_TYPE);
  659. }
  660. }
  661. // GetNamesFromFile Method
  662. //
  663. // - the maximum file size this implementation can handle is 4,294,967,295 bytes
  664. void GetNamesFromFile(LPCTSTR pszFileName, StringSet& setNames)
  665. {
  666. HRESULT hr = S_OK;
  667. if (pszFileName)
  668. {
  669. HANDLE hFile = CreateFile(pszFileName, GENERIC_READ, 0, NULL, OPEN_EXISTING, 0, NULL);
  670. if (hFile != INVALID_HANDLE_VALUE)
  671. {
  672. DWORD dwFileSize = GetFileSize(hFile, NULL);
  673. if (dwFileSize > 0)
  674. {
  675. HANDLE hFileMappingObject = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
  676. if (hFileMappingObject != NULL)
  677. {
  678. LPVOID pvBase = MapViewOfFile(hFileMappingObject, FILE_MAP_READ, 0, 0, 0);
  679. if (pvBase != NULL)
  680. {
  681. // if Unicode signature assume Unicode file
  682. // otherwise it must be an ANSI file
  683. LPCWSTR pwcs = (LPCWSTR)pvBase;
  684. if ((dwFileSize >= 2) && (*pwcs == L'\xFEFF'))
  685. {
  686. GetNamesFromStringW(pwcs + 1, dwFileSize / sizeof(WCHAR) - 1, setNames);
  687. }
  688. else
  689. {
  690. GetNamesFromStringA((LPCSTR)pvBase, dwFileSize, setNames);
  691. }
  692. UnmapViewOfFile(pvBase);
  693. }
  694. else
  695. {
  696. hr = HRESULT_FROM_WIN32(GetLastError());
  697. }
  698. CloseHandle(hFileMappingObject);
  699. }
  700. else
  701. {
  702. hr = HRESULT_FROM_WIN32(GetLastError());
  703. }
  704. }
  705. CloseHandle(hFile);
  706. }
  707. else
  708. {
  709. hr = HRESULT_FROM_WIN32(GetLastError());
  710. }
  711. }
  712. else
  713. {
  714. hr = E_INVALIDARG;
  715. }
  716. if (FAILED(hr))
  717. {
  718. AdmtThrowError(GUID_NULL, GUID_NULL, hr, IDS_E_INCLUDE_NAMES_FILE, pszFileName);
  719. }
  720. }
  721. // GetNamesFromStringA Method
  722. void GetNamesFromStringA(LPCSTR pchString, DWORD cchString, StringSet& setNames)
  723. {
  724. static const CHAR chSeparators[] = "\t\n\r";
  725. LPCSTR pchStringEnd = &pchString[cchString];
  726. for (LPCSTR pch = pchString; pch < pchStringEnd; pch++)
  727. {
  728. // skip space characters
  729. while ((pch < pchStringEnd) && (*pch == ' '))
  730. {
  731. ++pch;
  732. }
  733. // beginning of name
  734. LPCSTR pchBeg = pch;
  735. // scan for separator saving pointer to last non-whitespace character
  736. LPCSTR pchEnd = pch;
  737. while ((pch < pchStringEnd) && (strchr(chSeparators, *pch) == NULL))
  738. {
  739. if (*pch++ != ' ')
  740. {
  741. pchEnd = pch;
  742. }
  743. }
  744. // insert name which doesn't contain any leading or trailing whitespace characters
  745. if (pchEnd > pchBeg)
  746. {
  747. size_t cchName = pchEnd - pchBeg;
  748. LPSTR pszName = (LPSTR) _alloca((cchName + 1) * sizeof(CHAR));
  749. strncpy(pszName, pchBeg, cchName);
  750. pszName[cchName] = '\0';
  751. setNames.insert(_bstr_t(pszName));
  752. }
  753. }
  754. }
  755. // GetNamesFromStringW Method
  756. void GetNamesFromStringW(LPCWSTR pchString, DWORD cchString, StringSet& setNames)
  757. {
  758. static const WCHAR chSeparators[] = L"\t\n\r";
  759. LPCWSTR pchStringEnd = &pchString[cchString];
  760. for (LPCWSTR pch = pchString; pch < pchStringEnd; pch++)
  761. {
  762. // skip space characters
  763. while ((pch < pchStringEnd) && (*pch == L' '))
  764. {
  765. ++pch;
  766. }
  767. // beginning of name
  768. LPCWSTR pchBeg = pch;
  769. // scan for separator saving pointer to last non-whitespace character
  770. LPCWSTR pchEnd = pch;
  771. while ((pch < pchStringEnd) && (wcschr(chSeparators, *pch) == NULL))
  772. {
  773. if (*pch++ != L' ')
  774. {
  775. pchEnd = pch;
  776. }
  777. }
  778. // insert name which doesn't contain any leading or trailing whitespace characters
  779. if (pchEnd > pchBeg)
  780. {
  781. _bstr_t strName(SysAllocStringLen(pchBeg, pchEnd - pchBeg), false);
  782. setNames.insert(strName);
  783. }
  784. }
  785. }
  786. // RemoveTrailingDollarSign Method
  787. _bstr_t RemoveTrailingDollarSign(LPCTSTR pszName)
  788. {
  789. LPTSTR psz = _T("");
  790. if (pszName)
  791. {
  792. size_t cch = _tcslen(pszName);
  793. if (cch > 0)
  794. {
  795. psz = reinterpret_cast<LPTSTR>(_alloca((cch + 1) * sizeof(_TCHAR)));
  796. _tcscpy(psz, pszName);
  797. LPTSTR p = &psz[cch - 1];
  798. if (*p == _T('$'))
  799. {
  800. *p = _T('\0');
  801. }
  802. }
  803. }
  804. return psz;
  805. }
  806. } // namespace