Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

833 lines
24 KiB

  1. // cryptlib.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #ifndef CRYPTOPP_IMPORTS
  4. #include "cryptlib.h"
  5. #include "misc.h"
  6. #include "filters.h"
  7. #include "algparam.h"
  8. #include "fips140.h"
  9. #include "argnames.h"
  10. #include "fltrimpl.h"
  11. #include "trdlocal.h"
  12. #include "osrng.h"
  13. #include <memory>
  14. NAMESPACE_BEGIN(CryptoPP)
// Assert (via CRYPTOPP_COMPILE_ASSERT) that the fixed-width integer typedefs
// used throughout this file have exactly the sizes their names promise.
CRYPTOPP_COMPILE_ASSERT(sizeof(byte) == 1);
CRYPTOPP_COMPILE_ASSERT(sizeof(word16) == 2);
CRYPTOPP_COMPILE_ASSERT(sizeof(word32) == 4);
CRYPTOPP_COMPILE_ASSERT(sizeof(word64) == 8);
#ifdef CRYPTOPP_NATIVE_DWORD_AVAILABLE
// When a native double-width type is available, it must be exactly twice the size of `word`.
CRYPTOPP_COMPILE_ASSERT(sizeof(dword) == 2*sizeof(word));
#endif
// VALVE, changed DEFAULT_CHANNEL to a basic type from std::string
// Name of the default channel: the empty string. The Channel* defaults in this
// file treat an empty channel name as "the default channel".
const char * DEFAULT_CHANNEL = "";
// Named channel "AAD". NOTE(review): presumably the additional-authenticated-data
// channel for authenticated ciphers; confirm against its users.
const std::string AAD_CHANNEL = "AAD";
// Binding this const std::string reference to DEFAULT_CHANNEL (a const char*)
// materializes a temporary std::string. Because the reference has static
// storage duration, that temporary lives as long as the reference. This is safe
// only because DEFAULT_CHANNEL points at a string literal and so is already
// initialized when this line runs.
const std::string &BufferedTransformation::NULL_CHANNEL = DEFAULT_CHANNEL;
  26. class NullNameValuePairs : public NameValuePairs
  27. {
  28. public:
  29. bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const {return false;}
  30. };
// VALVE: Our debug allocator doesn't much care for global objects like this, it registers them
// as a memory leak during validation. So make them const.
//simple_ptr<NullNameValuePairs> s_pNullNameValuePairs(new NullNameValuePairs);
//const NameValuePairs &g_nullNameValuePairs = *s_pNullNameValuePairs.m_p;
// The single shared empty parameter set. It is a namespace-scope const object
// (static storage) rather than a heap allocation, per the VALVE note above.
const NullNameValuePairs s_NullNameValuePairs;
// Public alias for the empty parameter set; refers to s_NullNameValuePairs.
const NameValuePairs &g_nullNameValuePairs = s_NullNameValuePairs;
  37. BufferedTransformation & TheBitBucket()
  38. {
  39. static BitBucket bitBucket;
  40. return bitBucket;
  41. }
  42. Algorithm::Algorithm(bool checkSelfTestStatus)
  43. {
  44. if (checkSelfTestStatus && FIPS_140_2_ComplianceEnabled())
  45. {
  46. if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_NOT_DONE && !PowerUpSelfTestInProgressOnThisThread())
  47. throw SelfTestFailure("Cryptographic algorithms are disabled before the power-up self tests are performed.");
  48. if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_FAILED)
  49. throw SelfTestFailure("Cryptographic algorithms are disabled after a power-up self test failed.");
  50. }
  51. }
  52. void SimpleKeyingInterface::SetKey(const byte *key, size_t length, const NameValuePairs &params)
  53. {
  54. this->ThrowIfInvalidKeyLength(length);
  55. this->UncheckedSetKey(key, (unsigned int)length, params);
  56. }
  57. void SimpleKeyingInterface::SetKeyWithRounds(const byte *key, size_t length, int rounds)
  58. {
  59. SetKey(key, length, MakeParameters(Name::Rounds(), rounds));
  60. }
  61. void SimpleKeyingInterface::SetKeyWithIV(const byte *key, size_t length, const byte *iv, size_t ivLength)
  62. {
  63. SetKey(key, length, MakeParameters(Name::IV(), ConstByteArrayParameter(iv, ivLength)));
  64. }
  65. void SimpleKeyingInterface::ThrowIfInvalidKeyLength(size_t length)
  66. {
  67. if (!IsValidKeyLength(length))
  68. throw InvalidKeyLength(GetAlgorithm().AlgorithmName(), length);
  69. }
  70. void SimpleKeyingInterface::ThrowIfResynchronizable()
  71. {
  72. if (IsResynchronizable())
  73. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object requires an IV");
  74. }
  75. void SimpleKeyingInterface::ThrowIfInvalidIV(const byte *iv)
  76. {
  77. if (!iv && IVRequirement() == UNPREDICTABLE_RANDOM_IV)
  78. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object cannot use a null IV");
  79. }
  80. size_t SimpleKeyingInterface::ThrowIfInvalidIVLength(int size)
  81. {
  82. if (size < 0)
  83. return IVSize();
  84. else if ((size_t)size < MinIVLength())
  85. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " is less than the minimum of " + IntToString(MinIVLength()));
  86. else if ((size_t)size > MaxIVLength())
  87. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " exceeds the maximum of " + IntToString(MaxIVLength()));
  88. else
  89. return size;
  90. }
  91. const byte * SimpleKeyingInterface::GetIVAndThrowIfInvalid(const NameValuePairs &params, size_t &size)
  92. {
  93. ConstByteArrayParameter ivWithLength;
  94. const byte *iv;
  95. bool found = false;
  96. try {found = params.GetValue(Name::IV(), ivWithLength);}
  97. catch (const NameValuePairs::ValueTypeMismatch &) {}
  98. if (found)
  99. {
  100. iv = ivWithLength.begin();
  101. ThrowIfInvalidIV(iv);
  102. size = ThrowIfInvalidIVLength((int)ivWithLength.size());
  103. return iv;
  104. }
  105. else if (params.GetValue(Name::IV(), iv))
  106. {
  107. ThrowIfInvalidIV(iv);
  108. size = IVSize();
  109. return iv;
  110. }
  111. else
  112. {
  113. ThrowIfResynchronizable();
  114. size = 0;
  115. return NULL;
  116. }
  117. }
  118. void SimpleKeyingInterface::GetNextIV(RandomNumberGenerator &rng, byte *IV)
  119. {
  120. rng.GenerateBlock(IV, IVSize());
  121. }
  122. size_t BlockTransformation::AdvancedProcessBlocks(const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) const
  123. {
  124. size_t blockSize = BlockSize();
  125. size_t inIncrement = (flags & (BT_InBlockIsCounter|BT_DontIncrementInOutPointers)) ? 0 : blockSize;
  126. size_t xorIncrement = xorBlocks ? blockSize : 0;
  127. size_t outIncrement = (flags & BT_DontIncrementInOutPointers) ? 0 : blockSize;
  128. if (flags & BT_ReverseDirection)
  129. {
  130. assert(length % blockSize == 0);
  131. inBlocks += length - blockSize;
  132. xorBlocks += length - blockSize;
  133. outBlocks += length - blockSize;
  134. inIncrement = 0-inIncrement;
  135. xorIncrement = 0-xorIncrement;
  136. outIncrement = 0-outIncrement;
  137. }
  138. while (length >= blockSize)
  139. {
  140. if (flags & BT_XorInput)
  141. {
  142. xorbuf(outBlocks, xorBlocks, inBlocks, blockSize);
  143. ProcessBlock(outBlocks);
  144. }
  145. else
  146. ProcessAndXorBlock(inBlocks, xorBlocks, outBlocks);
  147. if (flags & BT_InBlockIsCounter)
  148. const_cast<byte *>(inBlocks)[blockSize-1]++;
  149. inBlocks += inIncrement;
  150. outBlocks += outIncrement;
  151. xorBlocks += xorIncrement;
  152. length -= blockSize;
  153. }
  154. return length;
  155. }
  156. unsigned int BlockTransformation::OptimalDataAlignment() const
  157. {
  158. return GetAlignmentOf<word32>();
  159. }
  160. unsigned int StreamTransformation::OptimalDataAlignment() const
  161. {
  162. return GetAlignmentOf<word32>();
  163. }
  164. unsigned int HashTransformation::OptimalDataAlignment() const
  165. {
  166. return GetAlignmentOf<word32>();
  167. }
  168. void StreamTransformation::ProcessLastBlock(byte *outString, const byte *inString, size_t length)
  169. {
  170. assert(MinLastBlockSize() == 0); // this function should be overriden otherwise
  171. if (length == MandatoryBlockSize())
  172. ProcessData(outString, inString, length);
  173. else if (length != 0)
  174. throw NotImplemented(AlgorithmName() + ": this object does't support a special last block");
  175. }
  176. void AuthenticatedSymmetricCipher::SpecifyDataLengths(lword headerLength, lword messageLength, lword footerLength)
  177. {
  178. if (headerLength > MaxHeaderLength())
  179. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": header length " + IntToString(headerLength) + " exceeds the maximum of " + IntToString(MaxHeaderLength()));
  180. if (messageLength > MaxMessageLength())
  181. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": message length " + IntToString(messageLength) + " exceeds the maximum of " + IntToString(MaxMessageLength()));
  182. if (footerLength > MaxFooterLength())
  183. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": footer length " + IntToString(footerLength) + " exceeds the maximum of " + IntToString(MaxFooterLength()));
  184. UncheckedSpecifyDataLengths(headerLength, messageLength, footerLength);
  185. }
  186. void AuthenticatedSymmetricCipher::EncryptAndAuthenticate(byte *ciphertext, byte *mac, size_t macSize, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *message, size_t messageLength)
  187. {
  188. Resynchronize(iv, ivLength);
  189. SpecifyDataLengths(headerLength, messageLength);
  190. Update(header, headerLength);
  191. ProcessString(ciphertext, message, messageLength);
  192. TruncatedFinal(mac, macSize);
  193. }
  194. bool AuthenticatedSymmetricCipher::DecryptAndVerify(byte *message, const byte *mac, size_t macLength, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *ciphertext, size_t ciphertextLength)
  195. {
  196. Resynchronize(iv, ivLength);
  197. SpecifyDataLengths(headerLength, ciphertextLength);
  198. Update(header, headerLength);
  199. ProcessString(message, ciphertext, ciphertextLength);
  200. return TruncatedVerify(mac, macLength);
  201. }
  202. unsigned int RandomNumberGenerator::GenerateBit()
  203. {
  204. return GenerateByte() & 1;
  205. }
  206. byte RandomNumberGenerator::GenerateByte()
  207. {
  208. byte b;
  209. GenerateBlock(&b, 1);
  210. return b;
  211. }
  212. word32 RandomNumberGenerator::GenerateWord32(word32 min, word32 max)
  213. {
  214. word32 range = max-min;
  215. const int maxBits = BitPrecision(range);
  216. word32 value;
  217. do
  218. {
  219. GenerateBlock((byte *)&value, sizeof(value));
  220. value = Crop(value, maxBits);
  221. } while (value > range);
  222. return value+min;
  223. }
  224. void RandomNumberGenerator::GenerateBlock(byte *output, size_t size)
  225. {
  226. ArraySink s(output, size);
  227. GenerateIntoBufferedTransformation(s, DEFAULT_CHANNEL, size);
  228. }
  229. void RandomNumberGenerator::DiscardBytes(size_t n)
  230. {
  231. GenerateIntoBufferedTransformation(TheBitBucket(), DEFAULT_CHANNEL, n);
  232. }
  233. void RandomNumberGenerator::GenerateIntoBufferedTransformation(BufferedTransformation &target, const std::string &channel, lword length)
  234. {
  235. FixedSizeSecBlock<byte, 256> buffer;
  236. while (length)
  237. {
  238. size_t len = UnsignedMin(buffer.size(), length);
  239. GenerateBlock(buffer, len);
  240. target.ChannelPut(channel, buffer, len);
  241. length -= len;
  242. }
  243. }
  244. //! see NullRNG()
  245. class ClassNullRNG : public RandomNumberGenerator
  246. {
  247. public:
  248. std::string AlgorithmName() const {return "NullRNG";}
  249. void GenerateBlock(byte *output, size_t size) {throw NotImplemented("NullRNG: NullRNG should only be passed to functions that don't need to generate random bytes");}
  250. };
  251. RandomNumberGenerator & NullRNG()
  252. {
  253. static ClassNullRNG s_nullRNG;
  254. return s_nullRNG;
  255. }
  256. bool HashTransformation::TruncatedVerify(const byte *digestIn, size_t digestLength)
  257. {
  258. ThrowIfInvalidTruncatedSize(digestLength);
  259. SecByteBlock digest(digestLength);
  260. TruncatedFinal(digest, digestLength);
  261. return VerifyBufsEqual(digest, digestIn, digestLength);
  262. }
  263. void HashTransformation::ThrowIfInvalidTruncatedSize(size_t size) const
  264. {
  265. if (size > DigestSize())
  266. throw InvalidArgument("HashTransformation: can't truncate a " + IntToString(DigestSize()) + " byte digest to " + IntToString(size) + " bytes");
  267. }
  268. unsigned int BufferedTransformation::GetMaxWaitObjectCount() const
  269. {
  270. const BufferedTransformation *t = AttachedTransformation();
  271. return t ? t->GetMaxWaitObjectCount() : 0;
  272. }
  273. void BufferedTransformation::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  274. {
  275. BufferedTransformation *t = AttachedTransformation();
  276. if (t)
  277. t->GetWaitObjects(container, callStack); // reduce clutter by not adding to stack here
  278. }
  279. void BufferedTransformation::Initialize(const NameValuePairs &parameters, int propagation)
  280. {
  281. assert(!AttachedTransformation());
  282. IsolatedInitialize(parameters);
  283. }
  284. bool BufferedTransformation::Flush(bool hardFlush, int propagation, bool blocking)
  285. {
  286. assert(!AttachedTransformation());
  287. return IsolatedFlush(hardFlush, blocking);
  288. }
  289. bool BufferedTransformation::MessageSeriesEnd(int propagation, bool blocking)
  290. {
  291. assert(!AttachedTransformation());
  292. return IsolatedMessageSeriesEnd(blocking);
  293. }
  294. byte * BufferedTransformation::ChannelCreatePutSpace(const std::string &channel, size_t &size)
  295. {
  296. if (channel.empty())
  297. return CreatePutSpace(size);
  298. else
  299. throw NoChannelSupport(AlgorithmName());
  300. }
  301. size_t BufferedTransformation::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking)
  302. {
  303. if (channel.empty())
  304. return Put2(begin, length, messageEnd, blocking);
  305. else
  306. throw NoChannelSupport(AlgorithmName());
  307. }
  308. size_t BufferedTransformation::ChannelPutModifiable2(const std::string &channel, byte *begin, size_t length, int messageEnd, bool blocking)
  309. {
  310. if (channel.empty())
  311. return PutModifiable2(begin, length, messageEnd, blocking);
  312. else
  313. return ChannelPut2(channel, begin, length, messageEnd, blocking);
  314. }
  315. bool BufferedTransformation::ChannelFlush(const std::string &channel, bool completeFlush, int propagation, bool blocking)
  316. {
  317. if (channel.empty())
  318. return Flush(completeFlush, propagation, blocking);
  319. else
  320. throw NoChannelSupport(AlgorithmName());
  321. }
  322. bool BufferedTransformation::ChannelMessageSeriesEnd(const std::string &channel, int propagation, bool blocking)
  323. {
  324. if (channel.empty())
  325. return MessageSeriesEnd(propagation, blocking);
  326. else
  327. throw NoChannelSupport(AlgorithmName());
  328. }
  329. lword BufferedTransformation::MaxRetrievable() const
  330. {
  331. if (AttachedTransformation())
  332. return AttachedTransformation()->MaxRetrievable();
  333. else
  334. return CopyTo(TheBitBucket());
  335. }
  336. bool BufferedTransformation::AnyRetrievable() const
  337. {
  338. if (AttachedTransformation())
  339. return AttachedTransformation()->AnyRetrievable();
  340. else
  341. {
  342. byte b;
  343. return Peek(b) != 0;
  344. }
  345. }
  346. size_t BufferedTransformation::Get(byte &outByte)
  347. {
  348. if (AttachedTransformation())
  349. return AttachedTransformation()->Get(outByte);
  350. else
  351. return Get(&outByte, 1);
  352. }
  353. size_t BufferedTransformation::Get(byte *outString, size_t getMax)
  354. {
  355. if (AttachedTransformation())
  356. return AttachedTransformation()->Get(outString, getMax);
  357. else
  358. {
  359. ArraySink arraySink(outString, getMax);
  360. return (size_t)TransferTo(arraySink, getMax);
  361. }
  362. }
  363. size_t BufferedTransformation::Peek(byte &outByte) const
  364. {
  365. if (AttachedTransformation())
  366. return AttachedTransformation()->Peek(outByte);
  367. else
  368. return Peek(&outByte, 1);
  369. }
  370. size_t BufferedTransformation::Peek(byte *outString, size_t peekMax) const
  371. {
  372. if (AttachedTransformation())
  373. return AttachedTransformation()->Peek(outString, peekMax);
  374. else
  375. {
  376. ArraySink arraySink(outString, peekMax);
  377. return (size_t)CopyTo(arraySink, peekMax);
  378. }
  379. }
  380. lword BufferedTransformation::Skip(lword skipMax)
  381. {
  382. if (AttachedTransformation())
  383. return AttachedTransformation()->Skip(skipMax);
  384. else
  385. return TransferTo(TheBitBucket(), skipMax);
  386. }
  387. lword BufferedTransformation::TotalBytesRetrievable() const
  388. {
  389. if (AttachedTransformation())
  390. return AttachedTransformation()->TotalBytesRetrievable();
  391. else
  392. return MaxRetrievable();
  393. }
  394. unsigned int BufferedTransformation::NumberOfMessages() const
  395. {
  396. if (AttachedTransformation())
  397. return AttachedTransformation()->NumberOfMessages();
  398. else
  399. return CopyMessagesTo(TheBitBucket());
  400. }
  401. bool BufferedTransformation::AnyMessages() const
  402. {
  403. if (AttachedTransformation())
  404. return AttachedTransformation()->AnyMessages();
  405. else
  406. return NumberOfMessages() != 0;
  407. }
  408. bool BufferedTransformation::GetNextMessage()
  409. {
  410. if (AttachedTransformation())
  411. return AttachedTransformation()->GetNextMessage();
  412. else
  413. {
  414. assert(!AnyMessages());
  415. return false;
  416. }
  417. }
  418. unsigned int BufferedTransformation::SkipMessages(unsigned int count)
  419. {
  420. if (AttachedTransformation())
  421. return AttachedTransformation()->SkipMessages(count);
  422. else
  423. return TransferMessagesTo(TheBitBucket(), count);
  424. }
  425. size_t BufferedTransformation::TransferMessagesTo2(BufferedTransformation &target, unsigned int &messageCount, const std::string &channel, bool blocking)
  426. {
  427. if (AttachedTransformation())
  428. return AttachedTransformation()->TransferMessagesTo2(target, messageCount, channel, blocking);
  429. else
  430. {
  431. unsigned int maxMessages = messageCount;
  432. for (messageCount=0; messageCount < maxMessages && AnyMessages(); messageCount++)
  433. {
  434. size_t blockedBytes;
  435. lword transferredBytes;
  436. while (AnyRetrievable())
  437. {
  438. transferredBytes = LWORD_MAX;
  439. blockedBytes = TransferTo2(target, transferredBytes, channel, blocking);
  440. if (blockedBytes > 0)
  441. return blockedBytes;
  442. }
  443. if (target.ChannelMessageEnd(channel, GetAutoSignalPropagation(), blocking))
  444. return 1;
  445. bool result = GetNextMessage();
  446. assert(result);
  447. }
  448. return 0;
  449. }
  450. }
  451. unsigned int BufferedTransformation::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const
  452. {
  453. if (AttachedTransformation())
  454. return AttachedTransformation()->CopyMessagesTo(target, count, channel);
  455. else
  456. return 0;
  457. }
  458. void BufferedTransformation::SkipAll()
  459. {
  460. if (AttachedTransformation())
  461. AttachedTransformation()->SkipAll();
  462. else
  463. {
  464. while (SkipMessages()) {}
  465. while (Skip()) {}
  466. }
  467. }
  468. size_t BufferedTransformation::TransferAllTo2(BufferedTransformation &target, const std::string &channel, bool blocking)
  469. {
  470. if (AttachedTransformation())
  471. return AttachedTransformation()->TransferAllTo2(target, channel, blocking);
  472. else
  473. {
  474. assert(!NumberOfMessageSeries());
  475. unsigned int messageCount;
  476. do
  477. {
  478. messageCount = UINT_MAX;
  479. size_t blockedBytes = TransferMessagesTo2(target, messageCount, channel, blocking);
  480. if (blockedBytes)
  481. return blockedBytes;
  482. }
  483. while (messageCount != 0);
  484. lword byteCount;
  485. do
  486. {
  487. byteCount = ULONG_MAX;
  488. size_t blockedBytes = TransferTo2(target, byteCount, channel, blocking);
  489. if (blockedBytes)
  490. return blockedBytes;
  491. }
  492. while (byteCount != 0);
  493. return 0;
  494. }
  495. }
  496. void BufferedTransformation::CopyAllTo(BufferedTransformation &target, const std::string &channel) const
  497. {
  498. if (AttachedTransformation())
  499. AttachedTransformation()->CopyAllTo(target, channel);
  500. else
  501. {
  502. assert(!NumberOfMessageSeries());
  503. while (CopyMessagesTo(target, UINT_MAX, channel)) {}
  504. }
  505. }
  506. void BufferedTransformation::SetRetrievalChannel(const std::string &channel)
  507. {
  508. if (AttachedTransformation())
  509. AttachedTransformation()->SetRetrievalChannel(channel);
  510. }
  511. size_t BufferedTransformation::ChannelPutWord16(const std::string &channel, word16 value, ByteOrder order, bool blocking)
  512. {
  513. PutWord(false, order, m_buf, value);
  514. return ChannelPut(channel, m_buf, 2, blocking);
  515. }
  516. size_t BufferedTransformation::ChannelPutWord32(const std::string &channel, word32 value, ByteOrder order, bool blocking)
  517. {
  518. PutWord(false, order, m_buf, value);
  519. return ChannelPut(channel, m_buf, 4, blocking);
  520. }
  521. size_t BufferedTransformation::PutWord16(word16 value, ByteOrder order, bool blocking)
  522. {
  523. return ChannelPutWord16(DEFAULT_CHANNEL, value, order, blocking);
  524. }
  525. size_t BufferedTransformation::PutWord32(word32 value, ByteOrder order, bool blocking)
  526. {
  527. return ChannelPutWord32(DEFAULT_CHANNEL, value, order, blocking);
  528. }
  529. size_t BufferedTransformation::PeekWord16(word16 &value, ByteOrder order) const
  530. {
  531. byte buf[2] = {0, 0};
  532. size_t len = Peek(buf, 2);
  533. if (order)
  534. value = (buf[0] << 8) | buf[1];
  535. else
  536. value = (buf[1] << 8) | buf[0];
  537. return len;
  538. }
  539. size_t BufferedTransformation::PeekWord32(word32 &value, ByteOrder order) const
  540. {
  541. byte buf[4] = {0, 0, 0, 0};
  542. size_t len = Peek(buf, 4);
  543. if (order)
  544. value = (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf [3];
  545. else
  546. value = (buf[3] << 24) | (buf[2] << 16) | (buf[1] << 8) | buf [0];
  547. return len;
  548. }
  549. size_t BufferedTransformation::GetWord16(word16 &value, ByteOrder order)
  550. {
  551. return (size_t)Skip(PeekWord16(value, order));
  552. }
  553. size_t BufferedTransformation::GetWord32(word32 &value, ByteOrder order)
  554. {
  555. return (size_t)Skip(PeekWord32(value, order));
  556. }
  557. void BufferedTransformation::Attach(BufferedTransformation *newOut)
  558. {
  559. if (AttachedTransformation() && AttachedTransformation()->Attachable())
  560. AttachedTransformation()->Attach(newOut);
  561. else
  562. Detach(newOut);
  563. }
  564. void GeneratableCryptoMaterial::GenerateRandomWithKeySize(RandomNumberGenerator &rng, unsigned int keySize)
  565. {
  566. GenerateRandom(rng, MakeParameters("KeySize", (int)keySize));
  567. }
// Default filter returned by PK_Encryptor::CreateEncryptionFilter.
// It buffers all plaintext written to it. When a message ends, it encrypts the
// whole buffered message with a single Encrypt() call and forwards the
// ciphertext to the attached transformation.
class PK_DefaultEncryptionFilter : public Unflushable<Filter>
{
public:
	// NOTE(review): rng, encryptor and parameters are all held by reference and
	// must outlive this filter. A temporary passed as `parameters` (e.g. a
	// MakeParameters(...) result) would dangle.
	PK_DefaultEncryptionFilter(RandomNumberGenerator &rng, const PK_Encryptor &encryptor, BufferedTransformation *attachment, const NameValuePairs &parameters)
		: m_rng(rng), m_encryptor(encryptor), m_parameters(parameters)
	{
		Detach(attachment);
	}

	// Queues input; output is produced only when messageEnd is nonzero.
	// NOTE(review): FILTER_BEGIN / FILTER_OUTPUT / FILTER_END_NO_MESSAGE_END
	// (fltrimpl.h) presumably form a resumable state machine for non-blocking
	// output, so statement order here is significant.
	size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
	{
		FILTER_BEGIN;
		m_plaintextQueue.Put(inString, length);

		if (messageEnd)
		{
			// The extra scope confines these locals. NOTE(review): presumably so
			// that resuming at FILTER_OUTPUT does not jump past their initialization.
			{
			size_t plaintextLength;
			// Throws if the queued size does not fit in size_t.
			if (!SafeConvert(m_plaintextQueue.CurrentSize(), plaintextLength))
				throw InvalidArgument("PK_DefaultEncryptionFilter: plaintext too long");
			size_t ciphertextLength = m_encryptor.CiphertextLength(plaintextLength);

			SecByteBlock plaintext(plaintextLength);
			m_plaintextQueue.Get(plaintext, plaintextLength);
			m_ciphertext.resize(ciphertextLength);
			m_encryptor.Encrypt(m_rng, plaintext, plaintextLength, m_ciphertext, m_parameters);
			}

			FILTER_OUTPUT(1, m_ciphertext, m_ciphertext.size(), messageEnd);
		}
		FILTER_END_NO_MESSAGE_END;
	}

	RandomNumberGenerator &m_rng;
	const PK_Encryptor &m_encryptor;
	const NameValuePairs &m_parameters;
	ByteQueue m_plaintextQueue;   // plaintext accumulated until the message ends
	SecByteBlock m_ciphertext;    // ciphertext of the last completed message; a member, presumably so it survives a blocked, resumed FILTER_OUTPUT
};
  602. BufferedTransformation * PK_Encryptor::CreateEncryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs &parameters) const
  603. {
  604. return new PK_DefaultEncryptionFilter(rng, *this, attachment, parameters);
  605. }
// Default filter returned by PK_Decryptor::CreateDecryptionFilter.
// It buffers all ciphertext written to it. When a message ends, it decrypts
// the whole buffered message with a single Decrypt() call. It throws
// InvalidCiphertext if the result is not a valid coding; otherwise it forwards
// m_result.messageLength plaintext bytes.
class PK_DefaultDecryptionFilter : public Unflushable<Filter>
{
public:
	// NOTE(review): rng, decryptor and parameters are all held by reference and
	// must outlive this filter. A temporary passed as `parameters` would dangle.
	PK_DefaultDecryptionFilter(RandomNumberGenerator &rng, const PK_Decryptor &decryptor, BufferedTransformation *attachment, const NameValuePairs &parameters)
		: m_rng(rng), m_decryptor(decryptor), m_parameters(parameters)
	{
		Detach(attachment);
	}

	// Queues input; output is produced only when messageEnd is nonzero.
	// NOTE(review): the FILTER_* macros (fltrimpl.h) presumably form a
	// resumable state machine, so statement order here is significant.
	size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
	{
		FILTER_BEGIN;
		m_ciphertextQueue.Put(inString, length);

		if (messageEnd)
		{
			// The extra scope confines these locals. NOTE(review): presumably so
			// that resuming at FILTER_OUTPUT does not jump past their initialization.
			{
			size_t ciphertextLength;
			// Throws if the queued size does not fit in size_t.
			if (!SafeConvert(m_ciphertextQueue.CurrentSize(), ciphertextLength))
				throw InvalidArgument("PK_DefaultDecryptionFilter: ciphertext too long");
			size_t maxPlaintextLength = m_decryptor.MaxPlaintextLength(ciphertextLength);

			SecByteBlock ciphertext(ciphertextLength);
			m_ciphertextQueue.Get(ciphertext, ciphertextLength);
			// Sized for the largest possible plaintext; only m_result.messageLength bytes are emitted.
			m_plaintext.resize(maxPlaintextLength);
			m_result = m_decryptor.Decrypt(m_rng, ciphertext, ciphertextLength, m_plaintext, m_parameters);
			if (!m_result.isValidCoding)
				throw InvalidCiphertext(m_decryptor.AlgorithmName() + ": invalid ciphertext");
			}

			FILTER_OUTPUT(1, m_plaintext, m_result.messageLength, messageEnd);
		}
		FILTER_END_NO_MESSAGE_END;
	}

	RandomNumberGenerator &m_rng;
	const PK_Decryptor &m_decryptor;
	const NameValuePairs &m_parameters;
	ByteQueue m_ciphertextQueue;  // ciphertext accumulated until the message ends
	SecByteBlock m_plaintext;     // decrypted output buffer (max plaintext size)
	DecodingResult m_result;      // validity flag and plaintext length of the last decryption
};
  643. BufferedTransformation * PK_Decryptor::CreateDecryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs &parameters) const
  644. {
  645. return new PK_DefaultDecryptionFilter(rng, *this, attachment, parameters);
  646. }
  647. size_t PK_Signer::Sign(RandomNumberGenerator &rng, PK_MessageAccumulator *messageAccumulator, byte *signature) const
  648. {
  649. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  650. return SignAndRestart(rng, *m, signature, false);
  651. }
  652. size_t PK_Signer::SignMessage(RandomNumberGenerator &rng, const byte *message, size_t messageLen, byte *signature) const
  653. {
  654. std::auto_ptr<PK_MessageAccumulator> m(NewSignatureAccumulator(rng));
  655. m->Update(message, messageLen);
  656. return SignAndRestart(rng, *m, signature, false);
  657. }
  658. size_t PK_Signer::SignMessageWithRecovery(RandomNumberGenerator &rng, const byte *recoverableMessage, size_t recoverableMessageLength,
  659. const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength, byte *signature) const
  660. {
  661. std::auto_ptr<PK_MessageAccumulator> m(NewSignatureAccumulator(rng));
  662. InputRecoverableMessage(*m, recoverableMessage, recoverableMessageLength);
  663. m->Update(nonrecoverableMessage, nonrecoverableMessageLength);
  664. return SignAndRestart(rng, *m, signature, false);
  665. }
  666. bool PK_Verifier::Verify(PK_MessageAccumulator *messageAccumulator) const
  667. {
  668. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  669. return VerifyAndRestart(*m);
  670. }
  671. bool PK_Verifier::VerifyMessage(const byte *message, size_t messageLen, const byte *signature, size_t signatureLength) const
  672. {
  673. std::auto_ptr<PK_MessageAccumulator> m(NewVerificationAccumulator());
  674. InputSignature(*m, signature, signatureLength);
  675. m->Update(message, messageLen);
  676. return VerifyAndRestart(*m);
  677. }
  678. DecodingResult PK_Verifier::Recover(byte *recoveredMessage, PK_MessageAccumulator *messageAccumulator) const
  679. {
  680. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  681. return RecoverAndRestart(recoveredMessage, *m);
  682. }
  683. DecodingResult PK_Verifier::RecoverMessage(byte *recoveredMessage,
  684. const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength,
  685. const byte *signature, size_t signatureLength) const
  686. {
  687. std::auto_ptr<PK_MessageAccumulator> m(NewVerificationAccumulator());
  688. InputSignature(*m, signature, signatureLength);
  689. m->Update(nonrecoverableMessage, nonrecoverableMessageLength);
  690. return RecoverAndRestart(recoveredMessage, *m);
  691. }
  692. void SimpleKeyAgreementDomain::GenerateKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  693. {
  694. GeneratePrivateKey(rng, privateKey);
  695. GeneratePublicKey(rng, privateKey, publicKey);
  696. }
  697. void AuthenticatedKeyAgreementDomain::GenerateStaticKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  698. {
  699. GenerateStaticPrivateKey(rng, privateKey);
  700. GenerateStaticPublicKey(rng, privateKey, publicKey);
  701. }
  702. void AuthenticatedKeyAgreementDomain::GenerateEphemeralKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  703. {
  704. GenerateEphemeralPrivateKey(rng, privateKey);
  705. GenerateEphemeralPublicKey(rng, privateKey, publicKey);
  706. }
  707. NAMESPACE_END
  708. #endif