Leaked source code of Windows Server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

486 lines
9.7 KiB

  //#include <StdAfx.h>
  #include "AdmtCrypt2.h"
  // Explicit for SecureZeroMemory; included after the project header so any
  // configuration macros that header may define still take effect.
  #include <windows.h>
  #include <NtSecApi.h>
  #pragma comment( lib, "AdvApi32.lib" )
  5. namespace AdmtCrypt2
  6. {
  7. #define SESSION_KEY_SIZE 16 // in bytes
  8. HCRYPTKEY __stdcall DeriveEncryptionKey(HCRYPTPROV hProvider);
  9. bool __stdcall IsDataMatchHash(HCRYPTPROV hProvider, const _variant_t& vntData, const _variant_t& vntHash);
  10. // Provider Methods
  11. HCRYPTKEY __stdcall DeriveKey(HCRYPTPROV hProvider, const _variant_t& vntBytes);
  12. HCRYPTHASH __stdcall CreateHash(HCRYPTPROV hProvider);
  13. bool __stdcall GenRandom(HCRYPTPROV hProvider, BYTE* pbData, DWORD cbData);
  14. // Key Methods
  15. void __stdcall DestroyKey(HCRYPTKEY hKey);
  16. bool __stdcall Decrypt(HCRYPTKEY hKey, const _variant_t& vntEncrypted, _variant_t& vntDecrypted);
  17. // Hash Methods
  18. void __stdcall DestroyHash(HCRYPTHASH hHash);
  19. bool __stdcall HashData(HCRYPTHASH hHash, const _variant_t& vntData);
  20. // Miscellaneous Helpers
  21. bool __stdcall RetrieveEncryptionBytes(_variant_t& vntBytes);
  22. // Variant Helpers
  23. bool __stdcall CreateByteArray(DWORD cb, _variant_t& vntByteArray);
  24. }
  25. using namespace AdmtCrypt2;
  26. //---------------------------------------------------------------------------
  27. // Source Crypt API
  28. //---------------------------------------------------------------------------
  29. // AdmtAcquireContext Method
  30. HCRYPTPROV __stdcall AdmtAcquireContext()
  31. {
  32. HCRYPTPROV hProvider = 0;
  33. BOOL bAcquire = CryptAcquireContext(
  34. &hProvider,
  35. NULL,
  36. MS_ENHANCED_PROV,
  37. PROV_RSA_FULL,
  38. CRYPT_MACHINE_KEYSET|CRYPT_VERIFYCONTEXT
  39. );
  40. if (!bAcquire)
  41. {
  42. hProvider = 0;
  43. }
  44. return hProvider;
  45. }
  46. // AdmtReleaseContext Method
  47. void __stdcall AdmtReleaseContext(HCRYPTPROV hProvider)
  48. {
  49. if (hProvider)
  50. {
  51. CryptReleaseContext(hProvider, 0);
  52. }
  53. }
  54. // AdmtImportSessionKey Method
  55. HCRYPTKEY __stdcall AdmtImportSessionKey(HCRYPTPROV hProvider, const _variant_t& vntEncryptedSessionBytes)
  56. {
  57. HCRYPTKEY hSessionKey = 0;
  58. if (hProvider && (vntEncryptedSessionBytes.vt == (VT_UI1|VT_ARRAY)) && ((vntEncryptedSessionBytes.parray != NULL)))
  59. {
  60. HCRYPTKEY hEncryptionKey = DeriveEncryptionKey(hProvider);
  61. if (hEncryptionKey)
  62. {
  63. _variant_t vntDecryptedSessionBytes;
  64. if (Decrypt(hEncryptionKey, vntEncryptedSessionBytes, vntDecryptedSessionBytes))
  65. {
  66. if (vntDecryptedSessionBytes.parray->rgsabound[0].cElements > SESSION_KEY_SIZE)
  67. {
  68. // extract session key bytes
  69. _variant_t vntBytes;
  70. if (CreateByteArray(SESSION_KEY_SIZE, vntBytes))
  71. {
  72. memcpy(vntBytes.parray->pvData, vntDecryptedSessionBytes.parray->pvData, SESSION_KEY_SIZE);
  73. // extract hash of session key bytes
  74. _variant_t vntHashValue;
  75. DWORD cbHashValue = vntDecryptedSessionBytes.parray->rgsabound[0].cElements - SESSION_KEY_SIZE;
  76. if (CreateByteArray(cbHashValue, vntHashValue))
  77. {
  78. memcpy(vntHashValue.parray->pvData, (BYTE*)vntDecryptedSessionBytes.parray->pvData + SESSION_KEY_SIZE, cbHashValue);
  79. if (IsDataMatchHash(hProvider, vntBytes, vntHashValue))
  80. {
  81. hSessionKey = DeriveKey(hProvider, vntBytes);
  82. }
  83. }
  84. }
  85. }
  86. else
  87. {
  88. SetLastError(ERROR_INVALID_PARAMETER);
  89. }
  90. }
  91. DestroyKey(hEncryptionKey);
  92. }
  93. }
  94. else
  95. {
  96. SetLastError(ERROR_INVALID_PARAMETER);
  97. }
  98. return hSessionKey;
  99. }
  100. // AdmtDecrypt Method
  101. _bstr_t __stdcall AdmtDecrypt(HCRYPTKEY hSessionKey, const _variant_t& vntEncrypted)
  102. {
  103. BSTR bstr = NULL;
  104. _variant_t vntDecrypted;
  105. if (Decrypt(hSessionKey, vntEncrypted, vntDecrypted))
  106. {
  107. HRESULT hr = BstrFromVector(vntDecrypted.parray, &bstr);
  108. if (FAILED(hr))
  109. {
  110. SetLastError(ERROR_NOT_ENOUGH_MEMORY);
  111. }
  112. }
  113. return _bstr_t(bstr, false);
  114. }
  115. // AdmtDestroyKey Method
  116. void __stdcall AdmtDestroyKey(HCRYPTKEY hKey)
  117. {
  118. DestroyKey(hKey);
  119. }
  120. //---------------------------------------------------------------------------
  121. // Private Helpers
  122. //---------------------------------------------------------------------------
  123. namespace AdmtCrypt2
  124. {
  125. HCRYPTKEY __stdcall DeriveEncryptionKey(HCRYPTPROV hProvider)
  126. {
  127. HCRYPTKEY hKey = 0;
  128. _variant_t vntBytes;
  129. if (RetrieveEncryptionBytes(vntBytes))
  130. {
  131. hKey = DeriveKey(hProvider, vntBytes);
  132. }
  133. return hKey;
  134. }
  135. bool __stdcall IsDataMatchHash(HCRYPTPROV hProvider, const _variant_t& vntData, const _variant_t& vntHash)
  136. {
  137. bool bMatch = false;
  138. HCRYPTHASH hHash = CreateHash(hProvider);
  139. if (hHash)
  140. {
  141. if (HashData(hHash, vntData))
  142. {
  143. DWORD dwSizeA;
  144. DWORD cbSize = sizeof(DWORD);
  145. if (CryptGetHashParam(hHash, HP_HASHSIZE, (BYTE*)&dwSizeA, &cbSize, 0))
  146. {
  147. DWORD dwSizeB = vntHash.parray->rgsabound[0].cElements;
  148. if (dwSizeA == dwSizeB)
  149. {
  150. BYTE* pbA = new BYTE[dwSizeA];
  151. if (pbA)
  152. {
  153. if (CryptGetHashParam(hHash, HP_HASHVAL, pbA, &dwSizeA, 0))
  154. {
  155. BYTE* pbB = (BYTE*) vntHash.parray->pvData;
  156. if (memcmp(pbA, pbB, dwSizeA) == 0)
  157. {
  158. bMatch = true;
  159. }
  160. }
  161. delete [] pbA;
  162. }
  163. else
  164. {
  165. SetLastError(ERROR_NOT_ENOUGH_MEMORY);
  166. }
  167. }
  168. }
  169. }
  170. }
  171. return bMatch;
  172. }
  173. // Provider Methods
  174. HCRYPTKEY __stdcall DeriveKey(HCRYPTPROV hProvider, const _variant_t& vntBytes)
  175. {
  176. HCRYPTKEY hKey = 0;
  177. HCRYPTHASH hHash = CreateHash(hProvider);
  178. if (hHash)
  179. {
  180. if (HashData(hHash, vntBytes))
  181. {
  182. if (!CryptDeriveKey(hProvider, CALG_3DES, hHash, 0, &hKey))
  183. {
  184. hKey = 0;
  185. }
  186. }
  187. DestroyHash(hHash);
  188. }
  189. return hKey;
  190. }
  191. HCRYPTHASH __stdcall CreateHash(HCRYPTPROV hProvider)
  192. {
  193. HCRYPTHASH hHash;
  194. if (!CryptCreateHash(hProvider, CALG_SHA1, 0, 0, &hHash))
  195. {
  196. hHash = 0;
  197. }
  198. return hHash;
  199. }
  200. bool __stdcall GenRandom(HCRYPTPROV hProvider, BYTE* pbData, DWORD cbData)
  201. {
  202. return CryptGenRandom(hProvider, cbData, pbData) ? true : false;
  203. }
  204. // Key Methods --------------------------------------------------------------
  205. // DestroyKey Method
  206. void __stdcall DestroyKey(HCRYPTKEY hKey)
  207. {
  208. if (hKey)
  209. {
  210. CryptDestroyKey(hKey);
  211. }
  212. }
  213. // Decrypt Method
  214. bool __stdcall Decrypt(HCRYPTKEY hKey, const _variant_t& vntEncrypted, _variant_t& vntDecrypted)
  215. {
  216. bool bDecrypted = false;
  217. _variant_t vnt = vntEncrypted;
  218. if ((vnt.vt == (VT_UI1|VT_ARRAY)) && (vnt.parray != NULL))
  219. {
  220. // decrypt data
  221. BYTE* pb = (BYTE*) vnt.parray->pvData;
  222. DWORD cb = vnt.parray->rgsabound[0].cElements;
  223. if (CryptDecrypt(hKey, NULL, TRUE, 0, pb, &cb))
  224. {
  225. // create decrypted byte array
  226. // the number of decrypted bytes may be less than
  227. // the number of encrypted bytes
  228. vntDecrypted.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  229. if (vntDecrypted.parray != NULL)
  230. {
  231. vntDecrypted.vt = VT_UI1|VT_ARRAY;
  232. memcpy(vntDecrypted.parray->pvData, vnt.parray->pvData, cb);
  233. bDecrypted = true;
  234. }
  235. else
  236. {
  237. SetLastError(ERROR_NOT_ENOUGH_MEMORY);
  238. }
  239. }
  240. }
  241. else
  242. {
  243. SetLastError(ERROR_INVALID_PARAMETER);
  244. }
  245. return bDecrypted;
  246. }
  247. // Hash Methods -------------------------------------------------------------
  248. // DestroyHash Method
  249. void __stdcall DestroyHash(HCRYPTHASH hHash)
  250. {
  251. if (hHash)
  252. {
  253. CryptDestroyHash(hHash);
  254. }
  255. }
  256. // HashData Method
  257. bool __stdcall HashData(HCRYPTHASH hHash, const _variant_t& vntData)
  258. {
  259. bool bHash = false;
  260. if ((vntData.vt == (VT_UI1|VT_ARRAY)) && ((vntData.parray != NULL)))
  261. {
  262. if (CryptHashData(hHash, (BYTE*)vntData.parray->pvData, vntData.parray->rgsabound[0].cElements, 0))
  263. {
  264. bHash = true;
  265. }
  266. }
  267. else
  268. {
  269. SetLastError(ERROR_INVALID_PARAMETER);
  270. }
  271. return bHash;
  272. }
  273. // Miscellaneous Helpers ----------------------------------------------------
  274. // RetrieveEncryptionBytes Method
  275. bool __stdcall RetrieveEncryptionBytes(_variant_t& vntBytes)
  276. {
  277. // private data key identifier
  278. _TCHAR c_szIdPrefix[] = _T("L$6A2899C0-CECE-459A-B5EB-7ED04DE61388");
  279. const USHORT c_cbIdPrefix = sizeof(c_szIdPrefix) - sizeof(_TCHAR);
  280. bool bRetrieve = false;
  281. // open policy object
  282. LSA_HANDLE hPolicy;
  283. LSA_OBJECT_ATTRIBUTES lsaoa = { sizeof(LSA_OBJECT_ATTRIBUTES), NULL, NULL, 0, NULL, NULL };
  284. NTSTATUS ntsStatus = LsaOpenPolicy(NULL, &lsaoa, POLICY_GET_PRIVATE_INFORMATION, &hPolicy);
  285. if (LSA_SUCCESS(ntsStatus))
  286. {
  287. // retrieve data
  288. LSA_UNICODE_STRING lsausKey = { c_cbIdPrefix, c_cbIdPrefix, c_szIdPrefix };
  289. PLSA_UNICODE_STRING plsausData;
  290. ntsStatus = LsaRetrievePrivateData(hPolicy, &lsausKey, &plsausData);
  291. if (LSA_SUCCESS(ntsStatus))
  292. {
  293. vntBytes.Clear();
  294. vntBytes.parray = SafeArrayCreateVector(VT_UI1, 0, plsausData->Length);
  295. if (vntBytes.parray != NULL)
  296. {
  297. vntBytes.vt = VT_UI1|VT_ARRAY;
  298. memcpy(vntBytes.parray->pvData, plsausData->Buffer, plsausData->Length);
  299. bRetrieve = true;
  300. }
  301. else
  302. {
  303. SetLastError(ERROR_NOT_ENOUGH_MEMORY);
  304. }
  305. LsaFreeMemory(plsausData);
  306. }
  307. else
  308. {
  309. SetLastError(LsaNtStatusToWinError(ntsStatus));
  310. }
  311. // close policy object
  312. LsaClose(hPolicy);
  313. }
  314. else
  315. {
  316. SetLastError(LsaNtStatusToWinError(ntsStatus));
  317. }
  318. return bRetrieve;
  319. }
  320. // Variant Helpers ----------------------------------------------------------
  321. // CreateByteArray Method
  322. bool __stdcall CreateByteArray(DWORD cb, _variant_t& vntByteArray)
  323. {
  324. bool bCreate = false;
  325. vntByteArray.Clear();
  326. vntByteArray.parray = SafeArrayCreateVector(VT_UI1, 0, cb);
  327. if (vntByteArray.parray)
  328. {
  329. bCreate = true;
  330. }
  331. else
  332. {
  333. SetLastError(ERROR_NOT_ENOUGH_MEMORY);
  334. }
  335. vntByteArray.vt = VT_UI1|VT_ARRAY;
  336. return bCreate;
  337. }
  338. }