Leaked source code of Windows Server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1185 lines
25 KiB

  1. #include <StdAfx.h>
  2. #include "AdmtCrypt.h"
  3. #include "Array.h"
  4. #include <NtSecApi.h>
  5. #pragma comment( lib, "AdvApi32.lib" )
  6. namespace
  7. {
  8. void __stdcall CreateByteArray(DWORD cb, _variant_t& vntByteArray)
  9. {
  10. vntByteArray.Clear();
  11. vntByteArray.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  12. if (vntByteArray.parray == NULL)
  13. {
  14. _com_issue_error(E_OUTOFMEMORY);
  15. }
  16. vntByteArray.vt = VT_UI1|VT_ARRAY;
  17. }
  18. _variant_t operator +(const _variant_t& vntByteArrayA, const _variant_t& vntByteArrayB)
  19. {
  20. _variant_t vntByteArrayC;
  21. // validate parameters
  22. if ((vntByteArrayA.vt != (VT_UI1|VT_ARRAY)) || ((vntByteArrayA.parray == NULL)))
  23. {
  24. _com_issue_error(E_INVALIDARG);
  25. }
  26. if ((vntByteArrayB.vt != (VT_UI1|VT_ARRAY)) || ((vntByteArrayB.parray == NULL)))
  27. {
  28. _com_issue_error(E_INVALIDARG);
  29. }
  30. // concatenate byte arrays
  31. DWORD cbA = vntByteArrayA.parray->rgsabound[0].cElements;
  32. DWORD cbB = vntByteArrayB.parray->rgsabound[0].cElements;
  33. CreateByteArray(cbA + cbB, vntByteArrayC);
  34. memcpy(vntByteArrayC.parray->pvData, vntByteArrayA.parray->pvData, cbA);
  35. memcpy((BYTE*)vntByteArrayC.parray->pvData + cbA, vntByteArrayB.parray->pvData, cbB);
  36. return vntByteArrayC;
  37. }
  38. #ifdef _DEBUG
  39. _bstr_t __stdcall DebugByteArray(const _variant_t& vnt)
  40. {
  41. _bstr_t strArray;
  42. if ((vnt.vt == (VT_UI1|VT_ARRAY)) && ((vnt.parray != NULL)))
  43. {
  44. _TCHAR szArray[256] = _T("");
  45. DWORD c = vnt.parray->rgsabound[0].cElements;
  46. BYTE* pb = (BYTE*) vnt.parray->pvData;
  47. for (DWORD i = 0; i < c; i++, pb++)
  48. {
  49. _TCHAR sz[48];
  50. wsprintf(sz, _T("%02X"), (UINT)(USHORT)*pb);
  51. if (i > 0)
  52. {
  53. _tcscat(szArray, _T(" "));
  54. }
  55. _tcscat(szArray, sz);
  56. }
  57. strArray = szArray;
  58. }
  59. return strArray;
  60. }
  61. #define TRACE_BUFFER_SIZE 1024
  62. void _cdecl Trace(LPCTSTR pszFormat, ...)
  63. {
  64. _TCHAR szMessage[TRACE_BUFFER_SIZE];
  65. if (pszFormat)
  66. {
  67. va_list args;
  68. va_start(args, pszFormat);
  69. _vsntprintf(szMessage, TRACE_BUFFER_SIZE, pszFormat, args);
  70. va_end(args);
  71. #if 0
  72. OutputDebugString(szMessage);
  73. #else
  74. HANDLE hFile = CreateFile(L"C:\\AdmtCrypt.log", GENERIC_WRITE, FILE_SHARE_READ, NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
  75. if (hFile != INVALID_HANDLE_VALUE)
  76. {
  77. SetFilePointer(hFile, 0, NULL, FILE_END);
  78. DWORD dwWritten;
  79. WriteFile(hFile, szMessage, _tcslen(szMessage) * sizeof(_TCHAR), &dwWritten, NULL);
  80. CloseHandle(hFile);
  81. }
  82. #endif
  83. }
  84. }
  85. #else
  86. _bstr_t __stdcall DebugByteArray(const _variant_t& vnt)
  87. {
  88. return _T("");
  89. }
  90. void _cdecl Trace(LPCTSTR pszFormat, ...)
  91. {
  92. }
  93. #endif
  94. }
  95. //---------------------------------------------------------------------------
  96. // Target Crypt Class
  97. //---------------------------------------------------------------------------
  98. // Constructor
  99. CTargetCrypt::CTargetCrypt()
  100. {
  101. Trace(_T("CTargetCrypt::CTargetCrypt()\r\n"));
  102. }
  103. // Destructor
  104. CTargetCrypt::~CTargetCrypt()
  105. {
  106. Trace(_T("CTargetCrypt::~CTargetCrypt()\r\n"));
  107. }
  108. // CreateEncryptionKey Method
  109. _variant_t CTargetCrypt::CreateEncryptionKey(LPCTSTR pszKeyId, LPCTSTR pszPassword)
  110. {
  111. Trace(_T("CreateEncryptionKey(pszKeyId='%s', pszPassword='%s')\r\n"), pszKeyId, pszPassword);
  112. // generate encryption key bytes
  113. _variant_t vntBytes = GenerateRandom(ENCRYPTION_KEY_SIZE);
  114. Trace(_T(" vntBytes={ %s }\r\n"), (LPCTSTR)DebugByteArray(vntBytes));
  115. // store encryption key bytes
  116. StoreBytes(pszKeyId, vntBytes);
  117. // create key from password
  118. CCryptHash hashPassword(CreateHash(CALG_SHA1));
  119. if (pszPassword && pszPassword[0])
  120. {
  121. hashPassword.Hash(pszPassword);
  122. }
  123. else
  124. {
  125. BYTE b = 0;
  126. hashPassword.Hash(&b, 1);
  127. }
  128. CCryptKey keyPassword(DeriveKey(CALG_3DES, hashPassword));
  129. _variant_t vntPasswordFlag;
  130. CreateByteArray(1, vntPasswordFlag);
  131. *((BYTE*)vntPasswordFlag.parray->pvData) = (pszPassword && pszPassword[0]) ? 0xFF : 0x00;
  132. // concatenate encryption key bytes and hash of encryption key bytes
  133. CCryptHash hashBytes(CreateHash(CALG_SHA1));
  134. hashBytes.Hash(vntBytes);
  135. _variant_t vntDecrypted = vntBytes + hashBytes.GetValue();
  136. // Trace(_T(" vntDecrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntDecrypted));
  137. // encrypt bytes / hash pair
  138. _variant_t vntEncrypted = keyPassword.Encrypt(NULL, true, vntDecrypted);
  139. // Trace(_T(" vntEncrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntEncrypted));
  140. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntBytes), GET_BYTE_ARRAY_SIZE(vntBytes));
  141. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntDecrypted), GET_BYTE_ARRAY_SIZE(vntDecrypted));
  142. return vntPasswordFlag + vntEncrypted;
  143. }
  144. // CreateSession Method
  145. _variant_t CTargetCrypt::CreateSession(LPCTSTR pszKeyId)
  146. {
  147. Trace(_T("CreateSession(pszKeyId='%s')\r\n"), pszKeyId);
  148. // get encryption key
  149. CCryptHash hashEncryption(CreateHash(CALG_SHA1));
  150. hashEncryption.Hash(RetrieveBytes(pszKeyId));
  151. CCryptKey keyEncryption(DeriveKey(CALG_3DES, hashEncryption));
  152. // generate session key bytes
  153. _variant_t vntBytes = GenerateRandom(SESSION_KEY_SIZE);
  154. // create session key
  155. CCryptHash hash(CreateHash(CALG_SHA1));
  156. hash.Hash(vntBytes);
  157. m_keySession.Attach(DeriveKey(CALG_3DES, hash));
  158. // concatenate session key bytes and hash of session key bytes
  159. _variant_t vntDecrypted = vntBytes + hash.GetValue();
  160. // encrypt session bytes and include hash
  161. _variant_t varEncrypted = keyEncryption.Encrypt(NULL, true, vntDecrypted);
  162. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntBytes), GET_BYTE_ARRAY_SIZE(vntBytes));
  163. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntDecrypted), GET_BYTE_ARRAY_SIZE(vntDecrypted));
  164. return varEncrypted;
  165. }
  166. // Encrypt Method
  167. _variant_t CTargetCrypt::Encrypt(_bstr_t strData)
  168. {
  169. Trace(_T("Encrypt(strData='%s')\r\n"), (LPCTSTR)strData);
  170. // convert string to byte array
  171. _variant_t vnt;
  172. HRESULT hr = VectorFromBstr(strData, &vnt.parray);
  173. if (FAILED(hr))
  174. {
  175. _com_issue_error(hr);
  176. }
  177. vnt.vt = VT_UI1|VT_ARRAY;
  178. // encrypt data
  179. _variant_t varEncrypted = m_keySession.Encrypt(NULL, true, vnt);
  180. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vnt), GET_BYTE_ARRAY_SIZE(vnt));
  181. return varEncrypted;
  182. }
  183. //---------------------------------------------------------------------------
  184. // Source Crypt Class
  185. //---------------------------------------------------------------------------
  186. // Constructor
  187. CSourceCrypt::CSourceCrypt()
  188. {
  189. Trace(_T("CSourceCrypt::CSourceCrypt()\r\n"));
  190. }
  191. // Destructor
  192. CSourceCrypt::~CSourceCrypt()
  193. {
  194. Trace(_T("CSourceCrypt::~CSourceCrypt()\r\n"));
  195. }
  196. // ImportEncryptionKey Method
  197. void CSourceCrypt::ImportEncryptionKey(const _variant_t& vntEncryptedKey, LPCTSTR pszPassword)
  198. {
  199. Trace(_T("ImportEncryptionKey(vntEncryptedKey={ %s }, pszPassword='%s')\r\n"), (LPCTSTR)DebugByteArray(vntEncryptedKey), pszPassword);
  200. // validate parameters
  201. if ((vntEncryptedKey.vt != (VT_UI1|VT_ARRAY)) || ((vntEncryptedKey.parray == NULL)))
  202. {
  203. _com_issue_error(E_INVALIDARG);
  204. }
  205. // extract password flag and verify with password
  206. bool bPassword = *((BYTE*)vntEncryptedKey.parray->pvData) ? true : false;
  207. if (bPassword)
  208. {
  209. if ((pszPassword == NULL) || (pszPassword[0] == NULL))
  210. {
  211. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  212. }
  213. }
  214. else
  215. {
  216. if (pszPassword && pszPassword[0])
  217. {
  218. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  219. }
  220. }
  221. // create key from password
  222. CCryptHash hashPassword(CreateHash(CALG_SHA1));
  223. if (pszPassword && pszPassword[0])
  224. {
  225. hashPassword.Hash(pszPassword);
  226. }
  227. else
  228. {
  229. BYTE b = 0;
  230. hashPassword.Hash(&b, 1);
  231. }
  232. CCryptKey keyPassword(DeriveKey(CALG_3DES, hashPassword));
  233. // encrypted data
  234. _variant_t vntEncrypted;
  235. DWORD cbEncrypted = vntEncryptedKey.parray->rgsabound[0].cElements - 1;
  236. CreateByteArray(cbEncrypted, vntEncrypted);
  237. memcpy(vntEncrypted.parray->pvData, (BYTE*)vntEncryptedKey.parray->pvData + 1, cbEncrypted);
  238. // Trace(_T(" vntEncrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntEncrypted));
  239. // decrypt encryption key bytes plus hash
  240. _variant_t vntDecrypted = keyPassword.Decrypt(NULL, true, vntEncrypted);
  241. // Trace(_T(" vntDecrypted={ %s }\n"), (LPCTSTR)DebugByteArray(vntDecrypted));
  242. // extract encryption key bytes
  243. _variant_t vntBytes;
  244. CreateByteArray(ENCRYPTION_KEY_SIZE, vntBytes);
  245. memcpy(vntBytes.parray->pvData, (BYTE*)vntDecrypted.parray->pvData, ENCRYPTION_KEY_SIZE);
  246. Trace(_T(" vntBytes={ %s }\r\n"), (LPCTSTR)DebugByteArray(vntBytes));
  247. // extract hash of encryption key bytes
  248. _variant_t vntHashValue;
  249. DWORD cbHashValue = vntDecrypted.parray->rgsabound[0].cElements - ENCRYPTION_KEY_SIZE;
  250. CreateByteArray(cbHashValue, vntHashValue);
  251. memcpy(vntHashValue.parray->pvData, (BYTE*)vntDecrypted.parray->pvData + ENCRYPTION_KEY_SIZE, cbHashValue);
  252. // Trace(_T(" vntHashValue={ %s }\n"), (LPCTSTR)DebugByteArray(vntHashValue));
  253. // create hash from bytes and create hash from hash value
  254. CCryptHash hashA(CreateHash(CALG_SHA1));
  255. hashA.Hash(vntBytes);
  256. CCryptHash hashB(CreateHash(CALG_SHA1));
  257. hashB.SetValue(vntHashValue);
  258. // if hashes compare store encryption key bytes
  259. if (hashA == hashB)
  260. {
  261. StoreBytes(m_szIdPrefix, vntBytes);
  262. }
  263. else
  264. {
  265. _com_issue_error(HRESULT_FROM_WIN32(ERROR_INVALID_PASSWORD));
  266. }
  267. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntDecrypted), GET_BYTE_ARRAY_SIZE(vntDecrypted));
  268. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntBytes), GET_BYTE_ARRAY_SIZE(vntBytes));
  269. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntHashValue), GET_BYTE_ARRAY_SIZE(vntHashValue));
  270. }
  271. // ImportSessionKey Method
  272. void CSourceCrypt::ImportSessionKey(const _variant_t& vntEncryptedKey)
  273. {
  274. Trace(_T("ImportSessionKey(vntEncryptedKey={ %s })\r\n"), (LPCTSTR)DebugByteArray(vntEncryptedKey));
  275. // validate parameters
  276. if ((vntEncryptedKey.vt != (VT_UI1|VT_ARRAY)) || ((vntEncryptedKey.parray == NULL)))
  277. {
  278. _com_issue_error(E_INVALIDARG);
  279. }
  280. // get encryption key
  281. CCryptKey keyEncryption(GetEncryptionKey(m_szIdPrefix));
  282. // decrypt session key bytes plus hash
  283. _variant_t vntDecrypted = keyEncryption.Decrypt(NULL, true, vntEncryptedKey);
  284. // extract session key bytes
  285. _variant_t vntBytes;
  286. CreateByteArray(SESSION_KEY_SIZE, vntBytes);
  287. memcpy(vntBytes.parray->pvData, vntDecrypted.parray->pvData, SESSION_KEY_SIZE);
  288. // extract hash of session key bytes
  289. _variant_t vntHashValue;
  290. DWORD cbHashValue = vntDecrypted.parray->rgsabound[0].cElements - SESSION_KEY_SIZE;
  291. CreateByteArray(cbHashValue, vntHashValue);
  292. memcpy(vntHashValue.parray->pvData, (BYTE*)vntDecrypted.parray->pvData + SESSION_KEY_SIZE, cbHashValue);
  293. // create hash from bytes and create hash from hash value
  294. CCryptHash hashA(CreateHash(CALG_SHA1));
  295. hashA.Hash(vntBytes);
  296. CCryptHash hashB(CreateHash(CALG_SHA1));
  297. hashB.SetValue(vntHashValue);
  298. // if hashes compare
  299. if (hashA == hashB)
  300. {
  301. // derive session key from session key bytes hash
  302. m_keySession.Attach(DeriveKey(CALG_3DES, hashA));
  303. }
  304. else
  305. {
  306. _com_issue_error(E_FAIL);
  307. }
  308. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntDecrypted), GET_BYTE_ARRAY_SIZE(vntDecrypted));
  309. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntBytes), GET_BYTE_ARRAY_SIZE(vntBytes));
  310. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntHashValue), GET_BYTE_ARRAY_SIZE(vntHashValue));
  311. }
  312. // Decrypt Method
  313. _bstr_t CSourceCrypt::Decrypt(const _variant_t& vntData)
  314. {
  315. Trace(_T("Decrypt(vntData={ %s })\r\n"), (LPCTSTR)DebugByteArray(vntData));
  316. // decrypt data
  317. _variant_t vnt = m_keySession.Decrypt(NULL, true, vntData);
  318. // convert into string
  319. BSTR bstr;
  320. HRESULT hr = BstrFromVector(vnt.parray, &bstr);
  321. if (FAILED(hr))
  322. {
  323. _com_issue_error(hr);
  324. }
  325. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vnt), GET_BYTE_ARRAY_SIZE(vnt));
  326. return bstr;
  327. }
  328. //---------------------------------------------------------------------------
  329. // Domain Crypt Class
  330. //---------------------------------------------------------------------------
  331. // Constructor
  332. CDomainCrypt::CDomainCrypt()
  333. {
  334. Trace(_T("CDomainCrypt::CDomainCrypt()\r\n"));
  335. }
  336. // Destructor
  337. CDomainCrypt::~CDomainCrypt()
  338. {
  339. Trace(_T("CDomainCrypt::~CDomainCrypt()\r\n"));
  340. }
  341. // GetEncryptionKey Method
  342. HCRYPTKEY CDomainCrypt::GetEncryptionKey(LPCTSTR pszKeyId)
  343. {
  344. // retrieve bytes
  345. _variant_t vntBytes = RetrieveBytes(pszKeyId);
  346. // set hash value
  347. CCryptHash hash;
  348. hash.Attach(CreateHash(CALG_SHA1));
  349. hash.Hash(vntBytes);
  350. // create encryption key derived from bytes
  351. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vntBytes), GET_BYTE_ARRAY_SIZE(vntBytes));
  352. return DeriveKey(CALG_3DES, hash);
  353. }
  354. // StoreBytes Method
  355. void CDomainCrypt::StoreBytes(LPCTSTR pszId, const _variant_t& vntBytes)
  356. {
  357. // validate parameters
  358. if ((pszId == NULL) || (pszId[0] == NULL))
  359. {
  360. _com_issue_error(E_INVALIDARG);
  361. }
  362. if ((vntBytes.vt != VT_EMPTY) && (vntBytes.vt != (VT_UI1|VT_ARRAY)))
  363. {
  364. _com_issue_error(E_INVALIDARG);
  365. }
  366. if ((vntBytes.vt == (VT_UI1|VT_ARRAY)) && (vntBytes.parray == NULL))
  367. {
  368. _com_issue_error(E_INVALIDARG);
  369. }
  370. LSA_HANDLE hPolicy = NULL;
  371. try
  372. {
  373. // open policy object
  374. LSA_OBJECT_ATTRIBUTES loa = { sizeof(LSA_OBJECT_ATTRIBUTES), NULL, NULL, 0, NULL, NULL };
  375. NTSTATUS ntsStatus = LsaOpenPolicy(NULL, &loa, POLICY_CREATE_SECRET, &hPolicy);
  376. if (!LSA_SUCCESS(ntsStatus))
  377. {
  378. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  379. }
  380. // store data
  381. PWSTR pwsKey = const_cast<PWSTR>(pszId);
  382. USHORT cbKey = _tcslen(pszId) * sizeof(_TCHAR);
  383. PWSTR pwsData = NULL;
  384. USHORT cbData = 0;
  385. if (vntBytes.vt != VT_EMPTY)
  386. {
  387. pwsData = reinterpret_cast<PWSTR>(vntBytes.parray->pvData);
  388. cbData = (USHORT) vntBytes.parray->rgsabound[0].cElements;
  389. }
  390. LSA_UNICODE_STRING lusKey = { cbKey, cbKey, pwsKey };
  391. LSA_UNICODE_STRING lusData = { cbData, cbData, pwsData };
  392. ntsStatus = LsaStorePrivateData(hPolicy, &lusKey, &lusData);
  393. if (!LSA_SUCCESS(ntsStatus))
  394. {
  395. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  396. }
  397. // close policy object
  398. LsaClose(hPolicy);
  399. }
  400. catch (...)
  401. {
  402. if (hPolicy)
  403. {
  404. LsaClose(hPolicy);
  405. }
  406. throw;
  407. }
  408. }
  409. // RetrievePrivateData Method
  410. _variant_t CDomainCrypt::RetrieveBytes(LPCTSTR pszId)
  411. {
  412. _variant_t vntBytes;
  413. // validate parameters
  414. if ((pszId == NULL) || (pszId[0] == NULL))
  415. {
  416. _com_issue_error(E_INVALIDARG);
  417. }
  418. LSA_HANDLE hPolicy = NULL;
  419. try
  420. {
  421. // open policy object
  422. LSA_OBJECT_ATTRIBUTES loa = { sizeof(LSA_OBJECT_ATTRIBUTES), NULL, NULL, 0, NULL, NULL };
  423. NTSTATUS ntsStatus = LsaOpenPolicy(NULL, &loa, POLICY_GET_PRIVATE_INFORMATION, &hPolicy);
  424. if (!LSA_SUCCESS(ntsStatus))
  425. {
  426. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  427. }
  428. // retrieve data
  429. PWSTR pwsKey = const_cast<PWSTR>(pszId);
  430. USHORT cbKey = _tcslen(pszId) * sizeof(_TCHAR);
  431. LSA_UNICODE_STRING lusKey = { cbKey, cbKey, pwsKey };
  432. PLSA_UNICODE_STRING plusData;
  433. ntsStatus = LsaRetrievePrivateData(hPolicy, &lusKey, &plusData);
  434. if (!LSA_SUCCESS(ntsStatus))
  435. {
  436. _com_issue_error(HRESULT_FROM_WIN32(LsaNtStatusToWinError(ntsStatus)));
  437. }
  438. vntBytes.parray = SafeArrayCreateVector(VT_UI1, 0, plusData->Length);
  439. if (vntBytes.parray == NULL)
  440. {
  441. LsaFreeMemory(plusData);
  442. _com_issue_error(E_OUTOFMEMORY);
  443. }
  444. vntBytes.vt = VT_UI1|VT_ARRAY;
  445. memcpy(vntBytes.parray->pvData, plusData->Buffer, plusData->Length);
  446. LsaFreeMemory(plusData);
  447. // close policy object
  448. LsaClose(hPolicy);
  449. }
  450. catch (...)
  451. {
  452. if (hPolicy)
  453. {
  454. LsaClose(hPolicy);
  455. }
  456. throw;
  457. }
  458. return vntBytes;
  459. }
  460. // private data key identifier
  461. _TCHAR CDomainCrypt::m_szIdPrefix[] = _T("L$6A2899C0-CECE-459A-B5EB-7ED04DE61388");
  462. //---------------------------------------------------------------------------
  463. // Crypt Provider Class
  464. //---------------------------------------------------------------------------
  465. // Constructors
  466. //
  467. // Notes:
  468. // If the enhanced provider is not installed, CryptAcquireContext() generates
  469. // the following error: (0x80090019) The keyset is not defined.
  470. CCryptProvider::CCryptProvider() :
  471. m_hProvider(NULL)
  472. {
  473. Trace(_T("E CCryptProvider::CCryptProvider(this=0x%p)\r\n"), this);
  474. if (!CryptAcquireContext(&m_hProvider, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_MACHINE_KEYSET|CRYPT_VERIFYCONTEXT))
  475. {
  476. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  477. }
  478. #ifdef _DEBUG
  479. char szProvider[256];
  480. DWORD cbProvider = sizeof(szProvider);
  481. if (CryptGetProvParam(m_hProvider, PP_NAME, (BYTE*) szProvider, &cbProvider, 0))
  482. {
  483. }
  484. DWORD dwVersion;
  485. DWORD cbVersion = sizeof(dwVersion);
  486. if (CryptGetProvParam(m_hProvider, PP_VERSION, (BYTE*) &dwVersion, &cbVersion, 0))
  487. {
  488. }
  489. // char szContainer[256];
  490. // DWORD cbContainer = sizeof(szContainer);
  491. // if (CryptGetProvParam(m_hProvider, PP_CONTAINER, (BYTE*) szContainer, &cbContainer, 0))
  492. // {
  493. // }
  494. #endif
  495. Trace(_T("L CCryptProvider::CCryptProvider()\r\n"));
  496. }
  497. CCryptProvider::CCryptProvider(const CCryptProvider& r) :
  498. m_hProvider(r.m_hProvider)
  499. {
  500. // if (!CryptContextAddRef(r.m_hProvider, NULL, 0))
  501. // {
  502. // _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  503. // }
  504. }
  505. // Destructor
  506. CCryptProvider::~CCryptProvider()
  507. {
  508. Trace(_T("E CCryptProvider::~CCryptProvider()\r\n"));
  509. if (m_hProvider)
  510. {
  511. if (!CryptReleaseContext(m_hProvider, 0))
  512. {
  513. #ifdef _DEBUG
  514. DebugBreak();
  515. #endif
  516. }
  517. }
  518. Trace(_T("L CCryptProvider::~CCryptProvider()\r\n"));
  519. }
  520. // assignment operators
  521. CCryptProvider& CCryptProvider::operator =(const CCryptProvider& r)
  522. {
  523. m_hProvider = r.m_hProvider;
  524. // if (!CryptContextAddRef(r.m_hProvider, NULL, 0))
  525. // {
  526. // _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  527. // }
  528. return *this;
  529. }
  530. // CreateHash Method
  531. HCRYPTHASH CCryptProvider::CreateHash(ALG_ID aid)
  532. {
  533. HCRYPTHASH hHash;
  534. if (!CryptCreateHash(m_hProvider, aid, 0, 0, &hHash))
  535. {
  536. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  537. }
  538. return hHash;
  539. }
  540. // DeriveKey Method
  541. HCRYPTKEY CCryptProvider::DeriveKey(ALG_ID aid, HCRYPTHASH hHash, DWORD dwFlags)
  542. {
  543. HCRYPTKEY hKey;
  544. if (!CryptDeriveKey(m_hProvider, aid, hHash, dwFlags, &hKey))
  545. {
  546. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  547. }
  548. return hKey;
  549. }
  550. // GenerateRandom Method
  551. //
  552. // Generates a specified number of random bytes.
  553. _variant_t CCryptProvider::GenerateRandom(DWORD dwNumberOfBytes) const
  554. {
  555. _variant_t vntRandom;
  556. // create byte array of specified length
  557. vntRandom.parray = SafeArrayCreateVector(VT_UI1, 0, dwNumberOfBytes);
  558. if (vntRandom.parray == NULL)
  559. {
  560. _com_issue_error(E_OUTOFMEMORY);
  561. }
  562. vntRandom.vt = VT_UI1|VT_ARRAY;
  563. // generate specified number of random bytes
  564. GenerateRandom((BYTE*)vntRandom.parray->pvData, dwNumberOfBytes);
  565. return vntRandom;
  566. }
  567. // GenerateRandom Method
  568. //
  569. // Generates a specified number of random bytes.
  570. void CCryptProvider::GenerateRandom(BYTE* pbData, DWORD cbData) const
  571. {
  572. // generate specified number of random bytes
  573. if (!CryptGenRandom(m_hProvider, cbData, pbData))
  574. {
  575. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  576. }
  577. }
  578. //---------------------------------------------------------------------------
  579. // Crypt Key Class
  580. //---------------------------------------------------------------------------
  581. // Constructor
  582. CCryptKey::CCryptKey(HCRYPTKEY hKey) :
  583. m_hKey(hKey)
  584. {
  585. }
  586. // Destructor
  587. CCryptKey::~CCryptKey()
  588. {
  589. if (m_hKey)
  590. {
  591. if (!CryptDestroyKey(m_hKey))
  592. {
  593. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  594. }
  595. }
  596. }
  597. // Encrypt Method
  598. _variant_t CCryptKey::Encrypt(HCRYPTHASH hHash, bool bFinal, const _variant_t& vntData)
  599. {
  600. _variant_t vntEncrypted;
  601. // validate parameters
  602. if ((vntData.vt != (VT_UI1|VT_ARRAY)) || ((vntData.parray == NULL)))
  603. {
  604. _com_issue_error(E_INVALIDARG);
  605. }
  606. // get encrypted data size
  607. DWORD cbData = vntData.parray->rgsabound[0].cElements;
  608. DWORD cbBuffer = cbData;
  609. if (!CryptEncrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, NULL, &cbBuffer, 0))
  610. {
  611. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  612. }
  613. // create encrypted data buffer
  614. vntEncrypted.parray = SafeArrayCreateVector(VT_UI1, 0, cbBuffer);
  615. if (vntEncrypted.parray == NULL)
  616. {
  617. _com_issue_error(E_OUTOFMEMORY);
  618. }
  619. vntEncrypted.vt = VT_UI1|VT_ARRAY;
  620. // copy data to encrypted buffer
  621. memcpy(vntEncrypted.parray->pvData, vntData.parray->pvData, cbData);
  622. // encrypt data
  623. BYTE* pbData = (BYTE*) vntEncrypted.parray->pvData;
  624. if (!CryptEncrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, pbData, &cbData, cbBuffer))
  625. {
  626. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  627. }
  628. return vntEncrypted;
  629. }
  630. // Decrypt Method
  631. _variant_t CCryptKey::Decrypt(HCRYPTHASH hHash, bool bFinal, const _variant_t& vntData)
  632. {
  633. _variant_t vntDecrypted;
  634. // validate parameters
  635. if ((vntData.vt != (VT_UI1|VT_ARRAY)) || ((vntData.parray == NULL)))
  636. {
  637. _com_issue_error(E_INVALIDARG);
  638. }
  639. // decrypt data
  640. _variant_t vnt = vntData;
  641. BYTE* pb = (BYTE*) vnt.parray->pvData;
  642. DWORD cb = vnt.parray->rgsabound[0].cElements;
  643. if (!CryptDecrypt(m_hKey, hHash, bFinal ? TRUE : FALSE, 0, pb, &cb))
  644. {
  645. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  646. }
  647. // create decrypted byte array
  648. // the number of decrypted bytes may be less than
  649. // the number of encrypted bytes
  650. vntDecrypted.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  651. if (vntDecrypted.parray == NULL)
  652. {
  653. _com_issue_error(E_OUTOFMEMORY);
  654. }
  655. vntDecrypted.vt = VT_UI1|VT_ARRAY;
  656. memcpy(vntDecrypted.parray->pvData, vnt.parray->pvData, cb);
  657. SecureZeroMemory(GET_BYTE_ARRAY_DATA(vnt), GET_BYTE_ARRAY_SIZE(vnt));
  658. return vntDecrypted;
  659. }
  660. //---------------------------------------------------------------------------
  661. // Crypt Hash Class
  662. //---------------------------------------------------------------------------
  663. // Constructor
  664. CCryptHash::CCryptHash(HCRYPTHASH hHash) :
  665. m_hHash(hHash)
  666. {
  667. }
  668. // Destructor
  669. CCryptHash::~CCryptHash()
  670. {
  671. if (m_hHash)
  672. {
  673. if (!CryptDestroyHash(m_hHash))
  674. {
  675. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  676. }
  677. }
  678. }
  679. // GetValue Method
  680. _variant_t CCryptHash::GetValue() const
  681. {
  682. _variant_t vntValue;
  683. // get hash size
  684. DWORD dwHashSize;
  685. DWORD cbHashSize = sizeof(DWORD);
  686. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwHashSize, &cbHashSize, 0))
  687. {
  688. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  689. }
  690. // allocate buffer
  691. vntValue.parray = SafeArrayCreateVector(VT_UI1, 0, dwHashSize);
  692. if (vntValue.parray == NULL)
  693. {
  694. _com_issue_error(E_OUTOFMEMORY);
  695. }
  696. vntValue.vt = VT_UI1|VT_ARRAY;
  697. // get hash value
  698. if (!CryptGetHashParam(m_hHash, HP_HASHVAL, (BYTE*)vntValue.parray->pvData, &dwHashSize, 0))
  699. {
  700. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  701. }
  702. return vntValue;
  703. }
  704. // SetValue Method
  705. void CCryptHash::SetValue(const _variant_t& vntValue)
  706. {
  707. // if parameter is valid
  708. if ((vntValue.vt == (VT_UI1|VT_ARRAY)) && ((vntValue.parray != NULL)))
  709. {
  710. // get hash size
  711. DWORD dwHashSize;
  712. DWORD cbHashSize = sizeof(DWORD);
  713. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwHashSize, &cbHashSize, 0))
  714. {
  715. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  716. }
  717. // validate hash size
  718. BYTE* pbValue = (BYTE*)vntValue.parray->pvData;
  719. DWORD cbValue = vntValue.parray->rgsabound[0].cElements;
  720. if (cbValue != dwHashSize)
  721. {
  722. _com_issue_error(E_INVALIDARG);
  723. }
  724. // set hash value
  725. if (!CryptSetHashParam(m_hHash, HP_HASHVAL, (BYTE*)pbValue, 0))
  726. {
  727. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  728. }
  729. }
  730. else
  731. {
  732. _com_issue_error(E_INVALIDARG);
  733. }
  734. }
  735. // Hash Method
  736. void CCryptHash::Hash(LPCTSTR pszData)
  737. {
  738. if (pszData && pszData[0])
  739. {
  740. Hash((BYTE*)pszData, _tcslen(pszData) * sizeof(_TCHAR));
  741. }
  742. else
  743. {
  744. _com_issue_error(E_INVALIDARG);
  745. }
  746. }
  747. // Hash Method
  748. void CCryptHash::Hash(const _variant_t& vntData)
  749. {
  750. if ((vntData.vt == (VT_UI1|VT_ARRAY)) && ((vntData.parray != NULL)))
  751. {
  752. Hash((BYTE*)vntData.parray->pvData, vntData.parray->rgsabound[0].cElements);
  753. }
  754. else
  755. {
  756. _com_issue_error(E_INVALIDARG);
  757. }
  758. }
  759. // Hash Method
  760. void CCryptHash::Hash(BYTE* pbData, DWORD cbData)
  761. {
  762. if ((pbData != NULL) && (cbData > 0))
  763. {
  764. if (!CryptHashData(m_hHash, pbData, cbData, 0))
  765. {
  766. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  767. }
  768. }
  769. else
  770. {
  771. _com_issue_error(E_INVALIDARG);
  772. }
  773. }
  774. bool CCryptHash::operator ==(const CCryptHash& hash)
  775. {
  776. bool bEqual = false;
  777. DWORD cbSize = sizeof(DWORD);
  778. // compare hash sizes
  779. DWORD dwSizeA;
  780. DWORD dwSizeB;
  781. if (!CryptGetHashParam(m_hHash, HP_HASHSIZE, (BYTE*)&dwSizeA, &cbSize, 0))
  782. {
  783. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  784. }
  785. if (!CryptGetHashParam(hash.m_hHash, HP_HASHSIZE, (BYTE*)&dwSizeB, &cbSize, 0))
  786. {
  787. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  788. }
  789. // if sizes are equal
  790. if (dwSizeA == dwSizeB)
  791. {
  792. // compare hashes
  793. c_array<BYTE> pbA(dwSizeA);
  794. c_array<BYTE> pbB(dwSizeB);
  795. if (!CryptGetHashParam(m_hHash, HP_HASHVAL, pbA, &dwSizeA, 0))
  796. {
  797. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  798. }
  799. if (!CryptGetHashParam(hash.m_hHash, HP_HASHVAL, pbB, &dwSizeB, 0))
  800. {
  801. _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));
  802. }
  803. if (memcmp(pbA, pbB, dwSizeA) == 0)
  804. {
  805. bEqual = true;
  806. }
  807. }
  808. return bEqual;
  809. }