Source code of Windows XP (NT5)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1168 lines
23 KiB

  1. //#include <StdAfx.h>
  2. #include "AdmtCrypt.h"
  3. #include <NtSecApi.h>
  4. #pragma comment( lib, "AdvApi32.lib" )
  5. namespace
  6. {
  7. void __stdcall CreateByteArray(DWORD cb, _variant_t& vntByteArray)
  8. {
  9. vntByteArray.Clear();
  10. vntByteArray.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  11. if (vntByteArray.parray == NULL)
  12. {
  13. _com_issue_error(E_OUTOFMEMORY);
  14. }
  15. vntByteArray.vt = VT_UI1|VT_ARRAY;
  16. }
  17. _variant_t operator +(const _variant_t& vntByteArrayA, const _variant_t& vntByteArrayB)
  18. {
  19. _variant_t vntByteArrayC;
  20. // validate parameters
  21. if ((vntByteArrayA.vt != (VT_UI1|VT_ARRAY)) || ((vntByteArrayA.parray == NULL)))
  22. {
  23. _com_issue_error(E_INVALIDARG);
  24. }
  25. if ((vntByteArrayB.vt != (VT_UI1|VT_ARRAY)) || ((vntByteArrayB.parray == NULL)))
  26. {
  27. _com_issue_error(E_INVALIDARG);
  28. }
  29. // concatenate byte arrays
  30. DWORD cbA = vntByteArrayA.parray->rgsabound[0].cElements;
  31. DWORD cbB = vntByteArrayB.parray->rgsabound[0].cElements;
  32. CreateByteArray(cbA + cbB, vntByteArrayC);
  33. memcpy(vntByteArrayC.parray->pvData, vntByteArrayA.parray->pvData, cbA);
  34. memcpy((BYTE*)vntByteArrayC.parray->pvData + cbA, vntByteArrayB.parray->pvData, cbB);
  35. return vntByteArrayC;
  36. }
  37. #ifdef _DEBUG
  38. _bstr_t __stdcall DebugByteArray(const _variant_t& vnt)
  39. {
  40. _bstr_t strArray;
  41. if ((vnt.vt == (VT_UI1|VT_ARRAY)) && ((vnt.parray != NULL)))
  42. {
  43. _TCHAR szArray[256] = _T("");
  44. DWORD c = vnt.parray->rgsabound[0].cElements;
  45. BYTE* pb = (BYTE*) vnt.parray->pvData;
  46. for (DWORD i = 0; i < c; i++, pb++)
  47. {
  48. _TCHAR sz[48];
  49. wsprintf(sz, _T("%02X"), (UINT)(USHORT)*pb);
  50. if (i > 0)
  51. {
  52. _tcscat(szArray, _T(" "));
  53. }
  54. _tcscat(szArray, sz);
  55. }
  56. strArray = szArray;
  57. }
  58. return strArray;
  59. }
  60. #define TRACE_BUFFER_SIZE 1024
  61. void _cdecl Trace(LPCTSTR pszFormat, ...)
  62. {
  63. _TCHAR szMessage[TRACE_BUFFER_SIZE];
  64. if (pszFormat)
  65. {
  66. va_list args;
  67. va_start(args, pszFormat);
  68. _vsntprintf(szMessage, TRACE_BUFFER_SIZE, pszFormat, args);
  69. va_end(args);
  70. #if 0
  71. OutputDebugString(szMessage);
  72. #else
  73. HANDLE hFile = CreateFile(L"C:\\AdmtCrypt.log", GENERIC_WRITE, FILE_SHARE_READ, NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
  74. if (hFile != INVALID_HANDLE_VALUE)
  75. {
  76. SetFilePointer(hFile, 0, NULL, FILE_END);
  77. DWORD dwWritten;
  78. WriteFile(hFile, szMessage, _tcslen(szMessage) * sizeof(_TCHAR), &dwWritten, NULL);
  79. CloseHandle(hFile);
  80. }
  81. #endif
  82. }
  83. }
  84. #else
  85. _bstr_t __stdcall DebugByteArray(const _variant_t& vnt)
  86. {
  87. return _T("");
  88. }
  89. void _cdecl Trace(LPCTSTR pszFormat, ...)
  90. {
  91. }
  92. #endif
  93. }
  94. //---------------------------------------------------------------------------
  95. // Target Crypt Class
  96. //---------------------------------------------------------------------------
  97. // Constructor
  98. CTargetCrypt::CTargetCrypt()
  99. {
  100. Trace(_T("CTargetCrypt::CTargetCrypt()\r\n"));
  101. }
  102. // Destructor
  103. CTargetCrypt::~CTargetCrypt()
  104. {
  105. Trace(_T("CTargetCrypt::~CTargetCrypt()\r\n"));
  106. }
  107. // CreateEncryptionKey Method
  108. _variant_t CTargetCrypt::CreateEncryptionKey(LPCTSTR pszKeyId, LPCTSTR pszPassword)
  109. {
  110. Trace(_T("CreateEncryptionKey(pszKeyId='%s', pszPassword='%s')\r\n"), pszKeyId, pszPassword);
  111. // generate encryption key bytes
  112. _variant_t vntBytes = GenerateRandom(ENCRYPTION_KEY_SIZE);
  113. Trace(_T(" vntBytes={ %s }\r\n"), (LPCTSTR)DebugByteArray(vntBytes));
  114. // store encryption key bytes
  115. StoreBytes(pszKeyId, vntBytes);
  116. // create key from password
  117. CCryptHash hashPassword(CreateHash(CALG_SHA1));
  118. if (pszPassword && pszPassword[0])
  119. {
  120. hashPassword.Hash(pszPassword);
  121. }
  122. else
  123. {
  124. BYTE b = 0;
  125. hashPassword.Hash(&b, 1);
  126. }
  127. CCryptKey keyPassword(DeriveKey(CALG_3DES, hashPassword));
  128. _variant_t vntPasswordFlag;
  129. CreateByteArray(1, vntPasswordFlag);
  130. *((BYTE*)vntPasswordFlag.parray->pvData) = (pszPassword && pszPassword[0]) ? 0xFF : 0x00;
  131. // concatenate encryption key bytes and hash of encryption key bytes
  132. CCryptHash hashBytes(CreateHash(CALG_SHA1));
  133. hashBytes.Hash(vntBytes);
  134. _variant_t vntDecrypted = vntBytes + hashBytes.GetValue();
  135. // Trace(_T(" vntDecrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntDecrypted));
  136. // encrypt bytes / hash pair
  137. _variant_t vntEncrypted = keyPassword.Encrypt(NULL, true, vntDecrypted);
  138. // Trace(_T(" vntEncrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntEncrypted));
  139. return vntPasswordFlag + vntEncrypted;
  140. }
  141. // CreateSession Method
  142. _variant_t CTargetCrypt::CreateSession(LPCTSTR pszKeyId)
  143. {
  144. Trace(_T("CreateSession(pszKeyId='%s')\r\n"), pszKeyId);
  145. // get encryption key
  146. CCryptHash hashEncryption(CreateHash(CALG_SHA1));
  147. hashEncryption.Hash(RetrieveBytes(pszKeyId));
  148. CCryptKey keyEncryption(DeriveKey(CALG_3DES, hashEncryption));
  149. // generate session key bytes
  150. _variant_t vntBytes = GenerateRandom(SESSION_KEY_SIZE);
  151. // create session key
  152. CCryptHash hash(CreateHash(CALG_SHA1));
  153. hash.Hash(vntBytes);
  154. m_keySession.Attach(DeriveKey(CALG_3DES, hash));
  155. // concatenate session key bytes and hash of session key bytes
  156. _variant_t vntDecrypted = vntBytes + hash.GetValue();
  157. // encrypt session bytes and include hash
  158. return keyEncryption.Encrypt(NULL, true, vntDecrypted);
  159. }
  160. // Encrypt Method
  161. _variant_t CTargetCrypt::Encrypt(_bstr_t strData)
  162. {
  163. Trace(_T("Encrypt(strData='%s')\r\n"), (LPCTSTR)strData);
  164. // convert string to byte array
  165. _variant_t vnt;
  166. HRESULT hr = VectorFromBstr(strData, &vnt.parray);
  167. if (FAILED(hr))
  168. {
  169. _com_issue_error(hr);
  170. }
  171. vnt.vt = VT_UI1|VT_ARRAY;
  172. // encrypt data
  173. return m_keySession.Encrypt(NULL, true, vnt);
  174. }
  175. //---------------------------------------------------------------------------
  176. // Source Crypt Class
  177. //---------------------------------------------------------------------------
  178. // Constructor
  179. CSourceCrypt::CSourceCrypt()
  180. {
  181. Trace(_T("CSourceCrypt::CSourceCrypt()\r\n"));
  182. }
  183. // Destructor
  184. CSourceCrypt::~CSourceCrypt()
  185. {
  186. Trace(_T("CSourceCrypt::~CSourceCrypt()\r\n"));
  187. }
  188. // ImportEncryptionKey Method
  189. void CSourceCrypt::ImportEncryptionKey(const _variant_t& vntEncryptedKey, LPCTSTR pszPassword)
  190. {
  191. Trace(_T("ImportEncryptionKey(vntEncryptedKey={ %s }, pszPassword='%s')\r\n"), (LPCTSTR)DebugByteArray(vntEncryptedKey), pszPassword);
  192. // validate parameters
  193. if ((vntEncryptedKey.vt != (VT_UI1|VT_ARRAY)) || ((vntEncryptedKey.parray == NULL)))
  194. {
  195. _com_issue_error(E_INVALIDARG);
  196. }
  197. // extract password flag and verify with password
  198. bool bPassword = *((BYTE*)vntEncryptedKey.parray->pvData) ? true : false;
  199. if (bPassword)
  200. {
  201. if ((pszPassword == NULL) || (pszPassword[0] == NULL))
  202. {
  203. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  204. }
  205. }
  206. else
  207. {
  208. if (pszPassword && pszPassword[0])
  209. {
  210. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  211. }
  212. }
  213. // create key from password
  214. CCryptHash hashPassword(CreateHash(CALG_SHA1));
  215. if (pszPassword && pszPassword[0])
  216. {
  217. hashPassword.Hash(pszPassword);
  218. }
  219. else
  220. {
  221. BYTE b = 0;
  222. hashPassword.Hash(&b, 1);
  223. }
  224. CCryptKey keyPassword(DeriveKey(CALG_3DES, hashPassword));
  225. // encrypted data
  226. _variant_t vntEncrypted;
  227. DWORD cbEncrypted = vntEncryptedKey.parray->rgsabound[0].cElements - 1;
  228. CreateByteArray(cbEncrypted, vntEncrypted);
  229. memcpy(vntEncrypted.parray->pvData, (BYTE*)vntEncryptedKey.parray->pvData + 1, cbEncrypted);
  230. // Trace(_T(" vntEncrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntEncrypted));
  231. // decrypt encryption key bytes plus hash
  232. _variant_t vntDecrypted = keyPassword.Decrypt(NULL, true, vntEncrypted);
  233. // Trace(_T(" vntDecrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntDecrypted));
  234. // extract encryption key bytes
  235. _variant_t vntBytes;
  236. CreateByteArray(ENCRYPTION_KEY_SIZE, vntBytes);
  237. memcpy(vntBytes.parray->pvData, (BYTE*)vntDecrypted.parray->pvData, ENCRYPTION_KEY_SIZE);
  238. Trace(_T(" vntBytes={ %s }\r\n"), (LPCTSTR)DebugByteArray(vntBytes));
  239. // extract hash of encryption key bytes
  240. _variant_t vntHashValue;
  241. DWORD cbHashValue = vntDecrypted.parray->rgsabound[0].cElements - ENCRYPTION_KEY_SIZE;
  242. CreateByteArray(cbHashValue, vntHashValue);
  243. memcpy(vntHashValue.parray->pvData, (BYTE*)vntDecrypted.parray->pvData + ENCRYPTION_KEY_SIZE, cbHashValue);
  244. // Trace(_T(" vntHashValue={ %s }\n"), (LPCTSTR)DebugByteArray(vntHashValue));
  245. // create hash from bytes and create hash from hash value
  246. CCryptHash hashA(CreateHash(CALG_SHA1));
  247. hashA.Hash(vntBytes);
  248. CCryptHash hashB(CreateHash(CALG_SHA1));
  249. hashB.SetValue(vntHashValue);
  250. // if hashes compare store encryption key bytes
  251. if (hashA == hashB)
  252. {
  253. StoreBytes(m_szIdPrefix, vntBytes);
  254. }
  255. else
  256. {
  257. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  258. }
  259. }
  260. // ImportSessionKey Method
  261. void CSourceCrypt::ImportSessionKey(const _variant_t& vntEncryptedKey)
  262. {
  263. Trace(_T("ImportSessionKey(vntEncryptedKey={ %s })\r\n"), (LPCTSTR)DebugByteArray(vntEncryptedKey));
  264. // validate parameters
  265. if ((vntEncryptedKey.vt != (VT_UI1|VT_ARRAY)) || ((vntEncryptedKey.parray == NULL)))
  266. {
  267. _com_issue_error(E_INVALIDARG);
  268. }
  269. // get encryption key
  270. CCryptKey keyEncryption(GetEncryptionKey(m_szIdPrefix));
  271. // decrypt session key bytes plus hash
  272. _variant_t vntDecrypted = keyEncryption.Decrypt(NULL, true, vntEncryptedKey);
  273. // extract session key bytes
  274. _variant_t vntBytes;
  275. CreateByteArray(SESSION_KEY_SIZE, vntBytes);
  276. memcpy(vntBytes.parray->pvData, vntDecrypted.parray->pvData, SESSION_KEY_SIZE);
  277. // extract hash of session key bytes
  278. _variant_t vntHashValue;
  279. DWORD cbHashValue = vntDecrypted.parray->rgsabound[0].cElements - SESSION_KEY_SIZE;
  280. CreateByteArray(cbHashValue, vntHashValue);
  281. memcpy(vntHashValue.parray->pvData, (BYTE*)vntDecrypted.parray->pvData + SESSION_KEY_SIZE, cbHashValue);
  282. // create hash from bytes and create hash from hash value
  283. CCryptHash hashA(CreateHash(CALG_SHA1));
  284. hashA.Hash(vntBytes);
  285. CCryptHash hashB(CreateHash(CALG_SHA1));
  286. hashB.SetValue(vntHashValue);
  287. // if hashes compare
  288. if (hashA == hashB)
  289. {
  290. // derive session key from session key bytes hash
  291. m_keySession.Attach(DeriveKey(CALG_3DES, hashA));
  292. }
  293. else
  294. {
  295. _com_issue_error(E_FAIL);
  296. }
  297. }
  298. // Decrypt Method
  299. _bstr_t CSourceCrypt::Decrypt(const _variant_t& vntData)
  300. {
  301. Trace(_T("Decrypt(vntData={ %s })\r\n"), (LPCTSTR)DebugByteArray(vntData));
  302. // decrypt data
  303. _variant_t vnt = m_keySession.Decrypt(NULL, true, vntData);
  304. // convert into string
  305. BSTR bstr;
  306. HRESULT hr = BstrFromVector(vnt.parray, &bstr);
  307. if (FAILED(hr))
  308. {
  309. _com_issue_error(hr);
  310. }
  311. return bstr;
  312. }
  313. //---------------------------------------------------------------------------
  314. // Domain Crypt Class
  315. //---------------------------------------------------------------------------
  316. // Constructor
  317. CDomainCrypt::CDomainCrypt()
  318. {
  319. Trace(_T("CDomainCrypt::CDomainCrypt()\r\n"));
  320. }
  321. // Destructor
  322. CDomainCrypt::~CDomainCrypt()
  323. {
  324. Trace(_T("CDomainCrypt::~CDomainCrypt()\r\n"));
  325. }
  326. // GetEncryptionKey Method
  327. HCRYPTKEY CDomainCrypt::GetEncryptionKey(LPCTSTR pszKeyId)
  328. {
  329. // retrieve bytes
  330. _variant_t vntBytes = RetrieveBytes(pszKeyId);
  331. // set hash value
  332. CCryptHash hash;
  333. hash.Attach(CreateHash(CALG_SHA1));
  334. hash.Hash(vntBytes);
  335. // create encryption key derived from bytes
  336. return DeriveKey(CALG_3DES, hash);
  337. }
  338. // StoreBytes Method
  339. void CDomainCrypt::StoreBytes(LPCTSTR pszId, const _variant_t& vntBytes)
  340. {
  341. // validate parameters
  342. if ((pszId == NULL) || (pszId[0] == NULL))
  343. {
  344. _com_issue_error(E_INVALIDARG);
  345. }
  346. if ((vntBytes.vt != VT_EMPTY) && (vntBytes.vt != (VT_UI1|VT_ARRAY)))
  347. {
  348. _com_issue_error(E_INVALIDARG);
  349. }
  350. if ((vntBytes.vt == (VT_UI1|VT_ARRAY)) && (vntBytes.parray == NULL))
  351. {
  352. _com_issue_error(E_INVALIDARG);
  353. }
  354. LSA_HANDLE hPolicy = NULL;
  355. try
  356. {
  357. // open policy object
  358. LSA_OBJECT_ATTRIBUTES loa = { sizeof(LSA_OBJECT_ATTRIBUTES), NULL, NULL, 0, NULL, NULL };
  359. NTSTATUS ntsStatus = LsaOpenPolicy(NULL, &loa, POLICY_CREATE_SECRET, &hPolicy);
  360. if (!LSA_SUCCESS(ntsStatus))
  361. {
  362. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  363. }
  364. // store data
  365. PWSTR pwsKey = const_cast<PWSTR>(pszId);
  366. USHORT cbKey = _tcslen(pszId) * sizeof(_TCHAR);
  367. PWSTR pwsData = NULL;
  368. USHORT cbData = 0;
  369. if (vntBytes.vt != VT_EMPTY)
  370. {
  371. pwsData = reinterpret_cast<PWSTR>(vntBytes.parray->pvData);
  372. cbData = (USHORT) vntBytes.parray->rgsabound[0].cElements;
  373. }
  374. LSA_UNICODE_STRING lusKey = { cbKey, cbKey, pwsKey };
  375. LSA_UNICODE_STRING lusData = { cbData, cbData, pwsData };
  376. ntsStatus = LsaStorePrivateData(hPolicy, &lusKey, &lusData);
  377. if (!LSA_SUCCESS(ntsStatus))
  378. {
  379. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  380. }
  381. // close policy object
  382. LsaClose(hPolicy);
  383. }
  384. catch (...)
  385. {
  386. if (hPolicy)
  387. {
  388. LsaClose(hPolicy);
  389. }
  390. throw;
  391. }
  392. }
  393. // RetrievePrivateData Method
  394. _variant_t CDomainCrypt::RetrieveBytes(LPCTSTR pszId)
  395. {
  396. _variant_t vntBytes;
  397. // validate parameters
  398. if ((pszId == NULL) || (pszId[0] == NULL))
  399. {
  400. _com_issue_error(E_INVALIDARG);
  401. }
  402. LSA_HANDLE hPolicy = NULL;
  403. try
  404. {
  405. // open policy object
  406. LSA_OBJECT_ATTRIBUTES loa = { sizeof(LSA_OBJECT_ATTRIBUTES), NULL, NULL, 0, NULL, NULL };
  407. NTSTATUS ntsStatus = LsaOpenPolicy(NULL, &loa, POLICY_GET_PRIVATE_INFORMATION, &hPolicy);
  408. if (!LSA_SUCCESS(ntsStatus))
  409. {
  410. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  411. }
  412. // retrieve data
  413. PWSTR pwsKey = const_cast<PWSTR>(pszId);
  414. USHORT cbKey = _tcslen(pszId) * sizeof(_TCHAR);
  415. LSA_UNICODE_STRING lusKey = { cbKey, cbKey, pwsKey };
  416. PLSA_UNICODE_STRING plusData;
  417. ntsStatus = LsaRetrievePrivateData(hPolicy, &lusKey, &plusData);
  418. if (!LSA_SUCCESS(ntsStatus))
  419. {
  420. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  421. }
  422. vntBytes.parray = SafeArrayCreateVector(VT_UI1, 0, plusData->Length);
  423. if (vntBytes.parray == NULL)
  424. {
  425. LsaFreeMemory(plusData);
  426. _com_issue_error(E_OUTOFMEMORY);
  427. }
  428. vntBytes.vt = VT_UI1|VT_ARRAY;
  429. memcpy(vntBytes.parray->pvData, plusData->Buffer, plusData->Length);
  430. LsaFreeMemory(plusData);
  431. // close policy object
  432. LsaClose(hPolicy);
  433. }
  434. catch (...)
  435. {
  436. if (hPolicy)
  437. {
  438. LsaClose(hPolicy);
  439. }
  440. throw;
  441. }
  442. return vntBytes;
  443. }
  444. // private data key identifier
  445. _TCHAR CDomainCrypt::m_szIdPrefix[] = _T("L$6A2899C0-CECE-459A-B5EB-7ED04DE61388");
  446. //---------------------------------------------------------------------------
  447. // Crypt Provider Class
  448. //---------------------------------------------------------------------------
  449. // Constructors
  450. //
  451. // Notes:
  452. // If the enhanced provider is not installed, CryptAcquireContext() generates
  453. // the following error: (0x80090019) The keyset is not defined.
  454. CCryptProvider::CCryptProvider() :
  455. m_hProvider(NULL)
  456. {
  457. Trace(_T("E CCryptProvider::CCryptProvider(this=0x%p)\r\n"), this);
  458. if (!CryptAcquireContext(&m_hProvider, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_MACHINE_KEYSET|CRYPT_VERIFYCONTEXT))
  459. {
  460. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  461. }
  462. #ifdef _DEBUG
  463. char szProvider[256];
  464. DWORD cbProvider = sizeof(szProvider);
  465. if (CryptGetProvParam(m_hProvider, PP_NAME, (BYTE*) szProvider, &cbProvider, 0))
  466. {
  467. }
  468. DWORD dwVersion;
  469. DWORD cbVersion = sizeof(dwVersion);
  470. if (CryptGetProvParam(m_hProvider, PP_VERSION, (BYTE*) &dwVersion, &cbVersion, 0))
  471. {
  472. }
  473. // char szContainer[256];
  474. // DWORD cbContainer = sizeof(szContainer);
  475. // if (CryptGetProvParam(m_hProvider, PP_CONTAINER, (BYTE*) szContainer, &cbContainer, 0))
  476. // {
  477. // }
  478. #endif
  479. Trace(_T("L CCryptProvider::CCryptProvider()\r\n"));
  480. }
  481. CCryptProvider::CCryptProvider(const CCryptProvider& r) :
  482. m_hProvider(r.m_hProvider)
  483. {
  484. // if (!CryptContextAddRef(r.m_hProvider, NULL, 0))
  485. // {
  486. // _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  487. // }
  488. }
  489. // Destructor
  490. CCryptProvider::~CCryptProvider()
  491. {
  492. Trace(_T("E CCryptProvider::~CCryptProvider()\r\n"));
  493. if (m_hProvider)
  494. {
  495. if (!CryptReleaseContext(m_hProvider, 0))
  496. {
  497. #ifdef _DEBUG
  498. DebugBreak();
  499. #endif
  500. }
  501. }
  502. Trace(_T("L CCryptProvider::~CCryptProvider()\r\n"));
  503. }
  504. // assignment operators
  505. CCryptProvider& CCryptProvider::operator =(const CCryptProvider& r)
  506. {
  507. m_hProvider = r.m_hProvider;
  508. // if (!CryptContextAddRef(r.m_hProvider, NULL, 0))
  509. // {
  510. // _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  511. // }
  512. return *this;
  513. }
  514. // CreateHash Method
  515. HCRYPTHASH CCryptProvider::CreateHash(ALG_ID aid)
  516. {
  517. HCRYPTHASH hHash;
  518. if (!CryptCreateHash(m_hProvider, aid, 0, 0, &hHash))
  519. {
  520. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  521. }
  522. return hHash;
  523. }
  524. // DeriveKey Method
  525. HCRYPTKEY CCryptProvider::DeriveKey(ALG_ID aid, HCRYPTHASH hHash, DWORD dwFlags)
  526. {
  527. HCRYPTKEY hKey;
  528. if (!CryptDeriveKey(m_hProvider, aid, hHash, dwFlags, &hKey))
  529. {
  530. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  531. }
  532. return hKey;
  533. }
  534. // GenerateRandom Method
  535. //
  536. // Generates a specified number of random bytes.
  537. _variant_t CCryptProvider::GenerateRandom(DWORD dwNumberOfBytes) const
  538. {
  539. _variant_t vntRandom;
  540. // create byte array of specified length
  541. vntRandom.parray = SafeArrayCreateVector(VT_UI1, 0, dwNumberOfBytes);
  542. if (vntRandom.parray == NULL)
  543. {
  544. _com_issue_error(E_OUTOFMEMORY);
  545. }
  546. vntRandom.vt = VT_UI1|VT_ARRAY;
  547. // generate specified number of random bytes
  548. GenerateRandom((BYTE*)vntRandom.parray->pvData, dwNumberOfBytes);
  549. return vntRandom;
  550. }
  551. // GenerateRandom Method
  552. //
  553. // Generates a specified number of random bytes.
  554. void CCryptProvider::GenerateRandom(BYTE* pbData, DWORD cbData) const
  555. {
  556. // generate specified number of random bytes
  557. if (!CryptGenRandom(m_hProvider, cbData, pbData))
  558. {
  559. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  560. }
  561. }
  562. //---------------------------------------------------------------------------
  563. // Crypt Key Class
  564. //---------------------------------------------------------------------------
  565. // Constructor
  566. CCryptKey::CCryptKey(HCRYPTKEY hKey) :
  567. m_hKey(hKey)
  568. {
  569. }
  570. // Destructor
  571. CCryptKey::~CCryptKey()
  572. {
  573. if (m_hKey)
  574. {
  575. if (!CryptDestroyKey(m_hKey))
  576. {
  577. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  578. }
  579. }
  580. }
  581. // Encrypt Method
  582. _variant_t CCryptKey::Encrypt(HCRYPTHASH hHash, bool bFinal, const _variant_t& vntData)
  583. {
  584. _variant_t vntEncrypted;
  585. // validate parameters
  586. if ((vntData.vt != (VT_UI1|VT_ARRAY)) || ((vntData.parray == NULL)))
  587. {
  588. _com_issue_error(E_INVALIDARG);
  589. }
  590. // get encrypted data size
  591. DWORD cbData = vntData.parray->rgsabound[0].cElements;
  592. DWORD cbBuffer = cbData;
  593. if (!CryptEncrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, NULL, &cbBuffer, 0))
  594. {
  595. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  596. }
  597. // create encrypted data buffer
  598. vntEncrypted.parray = SafeArrayCreateVector(VT_UI1, 0, cbBuffer);
  599. if (vntEncrypted.parray == NULL)
  600. {
  601. _com_issue_error(E_OUTOFMEMORY);
  602. }
  603. vntEncrypted.vt = VT_UI1|VT_ARRAY;
  604. // copy data to encrypted buffer
  605. memcpy(vntEncrypted.parray->pvData, vntData.parray->pvData, cbData);
  606. // encrypt data
  607. BYTE* pbData = (BYTE*) vntEncrypted.parray->pvData;
  608. if (!CryptEncrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, pbData, &cbData, cbBuffer))
  609. {
  610. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  611. }
  612. return vntEncrypted;
  613. }
  614. // Decrypt Method
  615. _variant_t CCryptKey::Decrypt(HCRYPTHASH hHash, bool bFinal, const _variant_t& vntData)
  616. {
  617. _variant_t vntDecrypted;
  618. // validate parameters
  619. if ((vntData.vt != (VT_UI1|VT_ARRAY)) || ((vntData.parray == NULL)))
  620. {
  621. _com_issue_error(E_INVALIDARG);
  622. }
  623. // decrypt data
  624. _variant_t vnt = vntData;
  625. BYTE* pb = (BYTE*) vnt.parray->pvData;
  626. DWORD cb = vnt.parray->rgsabound[0].cElements;
  627. if (!CryptDecrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, pb, &cb))
  628. {
  629. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  630. }
  631. // create decrypted byte array
  632. // the number of decrypted bytes may be less than
  633. // the number of encrypted bytes
  634. vntDecrypted.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  635. if (vntDecrypted.parray == NULL)
  636. {
  637. _com_issue_error(E_OUTOFMEMORY);
  638. }
  639. vntDecrypted.vt = VT_UI1|VT_ARRAY;
  640. memcpy(vntDecrypted.parray->pvData, vnt.parray->pvData, cb);
  641. return vntDecrypted;
  642. }
  643. //---------------------------------------------------------------------------
  644. // Crypt Hash Class
  645. //---------------------------------------------------------------------------
  646. // Constructor
  647. CCryptHash::CCryptHash(HCRYPTHASH hHash) :
  648. m_hHash(hHash)
  649. {
  650. }
  651. // Destructor
  652. CCryptHash::~CCryptHash()
  653. {
  654. if (m_hHash)
  655. {
  656. if (!CryptDestroyHash(m_hHash))
  657. {
  658. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  659. }
  660. }
  661. }
  662. // GetValue Method
  663. _variant_t CCryptHash::GetValue() const
  664. {
  665. _variant_t vntValue;
  666. // get hash size
  667. DWORD dwHashSize;
  668. DWORD cbHashSize = sizeof(DWORD);
  669. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwHashSize, &cbHashSize, 0))
  670. {
  671. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  672. }
  673. // allocate buffer
  674. vntValue.parray = SafeArrayCreateVector(VT_UI1, 0, dwHashSize);
  675. if (vntValue.parray == NULL)
  676. {
  677. _com_issue_error(E_OUTOFMEMORY);
  678. }
  679. vntValue.vt = VT_UI1|VT_ARRAY;
  680. // get hash value
  681. if (!CryptGetHashParam(m_hHash, HP_HASHVAL, (BYTE*)vntValue.parray->pvData, &dwHashSize, 0))
  682. {
  683. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  684. }
  685. return vntValue;
  686. }
  687. // SetValue Method
  688. void CCryptHash::SetValue(const _variant_t& vntValue)
  689. {
  690. // if parameter is valid
  691. if ((vntValue.vt == (VT_UI1|VT_ARRAY)) && ((vntValue.parray != NULL)))
  692. {
  693. // get hash size
  694. DWORD dwHashSize;
  695. DWORD cbHashSize = sizeof(DWORD);
  696. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwHashSize, &cbHashSize, 0))
  697. {
  698. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  699. }
  700. // validate hash size
  701. BYTE* pbValue = (BYTE*)vntValue.parray->pvData;
  702. DWORD cbValue = vntValue.parray->rgsabound[0].cElements;
  703. if (cbValue != dwHashSize)
  704. {
  705. _com_issue_error(E_INVALIDARG);
  706. }
  707. // set hash value
  708. if (!CryptSetHashParam(m_hHash, HP_HASHVAL, (BYTE*)pbValue, 0))
  709. {
  710. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  711. }
  712. }
  713. else
  714. {
  715. _com_issue_error(E_INVALIDARG);
  716. }
  717. }
  718. // Hash Method
  719. void CCryptHash::Hash(LPCTSTR pszData)
  720. {
  721. if (pszData && pszData[0])
  722. {
  723. Hash((BYTE*)pszData, _tcslen(pszData) * sizeof(_TCHAR));
  724. }
  725. else
  726. {
  727. _com_issue_error(E_INVALIDARG);
  728. }
  729. }
  730. // Hash Method
  731. void CCryptHash::Hash(const _variant_t& vntData)
  732. {
  733. if ((vntData.vt == (VT_UI1|VT_ARRAY)) && ((vntData.parray != NULL)))
  734. {
  735. Hash((BYTE*)vntData.parray->pvData, vntData.parray->rgsabound[0].cElements);
  736. }
  737. else
  738. {
  739. _com_issue_error(E_INVALIDARG);
  740. }
  741. }
  742. // Hash Method
  743. void CCryptHash::Hash(BYTE* pbData, DWORD cbData)
  744. {
  745. if ((pbData != NULL) && (cbData > 0))
  746. {
  747. if (!CryptHashData(m_hHash, pbData, cbData, 0))
  748. {
  749. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  750. }
  751. }
  752. else
  753. {
  754. _com_issue_error(E_INVALIDARG);
  755. }
  756. }
  757. bool CCryptHash::operator ==(const CCryptHash& hash)
  758. {
  759. bool bEqual = false;
  760. DWORD cbSize = sizeof(DWORD);
  761. // compare hash sizes
  762. DWORD dwSizeA;
  763. DWORD dwSizeB;
  764. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwSizeA, &cbSize, 0))
  765. {
  766. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  767. }
  768. if (!CryptGetHashParam(hash.m_hHash, HP_HASHSIZE, (BYTE*)&dwSizeB, &cbSize, 0))
  769. {
  770. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  771. }
  772. // if sizes are equal
  773. if (dwSizeA == dwSizeB)
  774. {
  775. // compare hashes
  776. BYTE* pbA;
  777. BYTE* pbB;
  778. try
  779. {
  780. pbA = (BYTE*) _alloca(dwSizeA);
  781. pbB = (BYTE*) _alloca(dwSizeB);
  782. }
  783. catch (...)
  784. {
  785. _com_issue_error(E_OUTOFMEMORY);
  786. }
  787. if (!CryptGetHashParam(m_hHash, HP_HASHVAL, pbA, &dwSizeA, 0))
  788. {
  789. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  790. }
  791. if (!CryptGetHashParam(hash.m_hHash, HP_HASHVAL, pbB, &dwSizeB, 0))
  792. {
  793. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  794. }
  795. if (memcmp(pbA, pbB, dwSizeA) == 0)
  796. {
  797. bEqual = true;
  798. }
  799. }
  800. return bEqual;
  801. }