Team Fortress 2 Source Code as on 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

834 lines
24 KiB

  1. // cryptlib.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #ifndef CRYPTOPP_IMPORTS
  4. #include "cryptlib.h"
  5. #include "misc.h"
  6. #include "filters.h"
  7. #include "algparam.h"
  8. #include "fips140.h"
  9. #include "argnames.h"
  10. #include "fltrimpl.h"
  11. #include "trdlocal.h"
  12. #include "osrng.h"
  13. #include <memory>
  14. NAMESPACE_BEGIN(CryptoPP)
  // Compile-time checks that the fixed-width integer typedefs have exactly the sizes
  // the rest of the library relies on.
  15. CRYPTOPP_COMPILE_ASSERT(sizeof(byte) == 1);
  16. CRYPTOPP_COMPILE_ASSERT(sizeof(word16) == 2);
  17. CRYPTOPP_COMPILE_ASSERT(sizeof(word32) == 4);
  18. CRYPTOPP_COMPILE_ASSERT(sizeof(word64) == 8);
  // When a native double-width word type exists, it must be exactly twice as wide as word.
  19. #ifdef CRYPTOPP_NATIVE_DWORD_AVAILABLE
  20. CRYPTOPP_COMPILE_ASSERT(sizeof(dword) == 2*sizeof(word));
  21. #endif
  // Channel names used by the BufferedTransformation channel API. The default channel is
  // the empty string (the base-class Channel* functions below test channel.empty()).
  22. // VALVE: changed DEFAULT_CHANNEL from a std::string to a plain const char *.
  // NOTE(review): only the pointee is const; the pointer itself is reassignable. It presumably
  // has to match the extern declaration in cryptlib.h, so confirm there before making it
  // `const char * const`.
  23. const char * DEFAULT_CHANNEL = "";
  // Named channel "AAD". NOTE(review): presumably carries additional authenticated data for
  // authenticated-encryption filters — confirm against its users.
  24. const std::string AAD_CHANNEL = "AAD";
  // NULL_CHANNEL binds to a temporary std::string constructed from DEFAULT_CHANNEL; the
  // temporary's lifetime is extended to that of the reference. Like AAD_CHANNEL, it is
  // dynamically initialized, so reading it from another translation unit's static
  // initializers depends on initialization order.
  25. const std::string &BufferedTransformation::NULL_CHANNEL = DEFAULT_CHANNEL;
  26. class NullNameValuePairs : public NameValuePairs
  27. {
  28. public:
  29. NullNameValuePairs() {}
  30. bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const {return false;}
  31. };
  32. // VALVE: Our debug allocator doesn't like heap-allocated globals such as the commented-out simple_ptr
  33. // version below; it reports them as memory leaks during validation. So use const objects with static storage instead.
  34. //simple_ptr<NullNameValuePairs> s_pNullNameValuePairs(new NullNameValuePairs);
  35. //const NameValuePairs &g_nullNameValuePairs = *s_pNullNameValuePairs.m_p;
  // The single empty parameter set; g_nullNameValuePairs is a reference to it.
  36. const NullNameValuePairs s_NullNameValuePairs;
  37. const NameValuePairs &g_nullNameValuePairs = s_NullNameValuePairs;
  38. BufferedTransformation & TheBitBucket()
  39. {
  40. static BitBucket bitBucket;
  41. return bitBucket;
  42. }
  43. Algorithm::Algorithm(bool checkSelfTestStatus)
  44. {
  45. if (checkSelfTestStatus && FIPS_140_2_ComplianceEnabled())
  46. {
  47. if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_NOT_DONE && !PowerUpSelfTestInProgressOnThisThread())
  48. throw SelfTestFailure("Cryptographic algorithms are disabled before the power-up self tests are performed.");
  49. if (GetPowerUpSelfTestStatus() == POWER_UP_SELF_TEST_FAILED)
  50. throw SelfTestFailure("Cryptographic algorithms are disabled after a power-up self test failed.");
  51. }
  52. }
  53. void SimpleKeyingInterface::SetKey(const byte *key, size_t length, const NameValuePairs &params)
  54. {
  55. this->ThrowIfInvalidKeyLength(length);
  56. this->UncheckedSetKey(key, (unsigned int)length, params);
  57. }
  58. void SimpleKeyingInterface::SetKeyWithRounds(const byte *key, size_t length, int rounds)
  59. {
  60. SetKey(key, length, MakeParameters(Name::Rounds(), rounds));
  61. }
  62. void SimpleKeyingInterface::SetKeyWithIV(const byte *key, size_t length, const byte *iv, size_t ivLength)
  63. {
  64. SetKey(key, length, MakeParameters(Name::IV(), ConstByteArrayParameter(iv, ivLength)));
  65. }
  66. void SimpleKeyingInterface::ThrowIfInvalidKeyLength(size_t length)
  67. {
  68. if (!IsValidKeyLength(length))
  69. throw InvalidKeyLength(GetAlgorithm().AlgorithmName(), length);
  70. }
  71. void SimpleKeyingInterface::ThrowIfResynchronizable()
  72. {
  73. if (IsResynchronizable())
  74. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object requires an IV");
  75. }
  76. void SimpleKeyingInterface::ThrowIfInvalidIV(const byte *iv)
  77. {
  78. if (!iv && IVRequirement() == UNPREDICTABLE_RANDOM_IV)
  79. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": this object cannot use a null IV");
  80. }
  81. size_t SimpleKeyingInterface::ThrowIfInvalidIVLength(int size)
  82. {
  83. if (size < 0)
  84. return IVSize();
  85. else if ((size_t)size < MinIVLength())
  86. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " is less than the minimum of " + IntToString(MinIVLength()));
  87. else if ((size_t)size > MaxIVLength())
  88. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": IV length " + IntToString(size) + " exceeds the maximum of " + IntToString(MaxIVLength()));
  89. else
  90. return size;
  91. }
  92. const byte * SimpleKeyingInterface::GetIVAndThrowIfInvalid(const NameValuePairs &params, size_t &size)
  93. {
  94. ConstByteArrayParameter ivWithLength;
  95. const byte *iv;
  96. bool found = false;
  97. try {found = params.GetValue(Name::IV(), ivWithLength);}
  98. catch (const NameValuePairs::ValueTypeMismatch &) {}
  99. if (found)
  100. {
  101. iv = ivWithLength.begin();
  102. ThrowIfInvalidIV(iv);
  103. size = ThrowIfInvalidIVLength((int)ivWithLength.size());
  104. return iv;
  105. }
  106. else if (params.GetValue(Name::IV(), iv))
  107. {
  108. ThrowIfInvalidIV(iv);
  109. size = IVSize();
  110. return iv;
  111. }
  112. else
  113. {
  114. ThrowIfResynchronizable();
  115. size = 0;
  116. return NULL;
  117. }
  118. }
  119. void SimpleKeyingInterface::GetNextIV(RandomNumberGenerator &rng, byte *IV)
  120. {
  121. rng.GenerateBlock(IV, IVSize());
  122. }
  123. size_t BlockTransformation::AdvancedProcessBlocks(const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) const
  124. {
  125. size_t blockSize = BlockSize();
  126. size_t inIncrement = (flags & (BT_InBlockIsCounter|BT_DontIncrementInOutPointers)) ? 0 : blockSize;
  127. size_t xorIncrement = xorBlocks ? blockSize : 0;
  128. size_t outIncrement = (flags & BT_DontIncrementInOutPointers) ? 0 : blockSize;
  129. if (flags & BT_ReverseDirection)
  130. {
  131. assert(length % blockSize == 0);
  132. inBlocks += length - blockSize;
  133. xorBlocks += length - blockSize;
  134. outBlocks += length - blockSize;
  135. inIncrement = 0-inIncrement;
  136. xorIncrement = 0-xorIncrement;
  137. outIncrement = 0-outIncrement;
  138. }
  139. while (length >= blockSize)
  140. {
  141. if (flags & BT_XorInput)
  142. {
  143. xorbuf(outBlocks, xorBlocks, inBlocks, blockSize);
  144. ProcessBlock(outBlocks);
  145. }
  146. else
  147. ProcessAndXorBlock(inBlocks, xorBlocks, outBlocks);
  148. if (flags & BT_InBlockIsCounter)
  149. const_cast<byte *>(inBlocks)[blockSize-1]++;
  150. inBlocks += inIncrement;
  151. outBlocks += outIncrement;
  152. xorBlocks += xorIncrement;
  153. length -= blockSize;
  154. }
  155. return length;
  156. }
  157. unsigned int BlockTransformation::OptimalDataAlignment() const
  158. {
  159. return GetAlignmentOf<word32>();
  160. }
  161. unsigned int StreamTransformation::OptimalDataAlignment() const
  162. {
  163. return GetAlignmentOf<word32>();
  164. }
  165. unsigned int HashTransformation::OptimalDataAlignment() const
  166. {
  167. return GetAlignmentOf<word32>();
  168. }
  169. void StreamTransformation::ProcessLastBlock(byte *outString, const byte *inString, size_t length)
  170. {
  171. assert(MinLastBlockSize() == 0); // this function should be overriden otherwise
  172. if (length == MandatoryBlockSize())
  173. ProcessData(outString, inString, length);
  174. else if (length != 0)
  175. throw NotImplemented(AlgorithmName() + ": this object does't support a special last block");
  176. }
  177. void AuthenticatedSymmetricCipher::SpecifyDataLengths(lword headerLength, lword messageLength, lword footerLength)
  178. {
  179. if (headerLength > MaxHeaderLength())
  180. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": header length " + IntToString(headerLength) + " exceeds the maximum of " + IntToString(MaxHeaderLength()));
  181. if (messageLength > MaxMessageLength())
  182. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": message length " + IntToString(messageLength) + " exceeds the maximum of " + IntToString(MaxMessageLength()));
  183. if (footerLength > MaxFooterLength())
  184. throw InvalidArgument(GetAlgorithm().AlgorithmName() + ": footer length " + IntToString(footerLength) + " exceeds the maximum of " + IntToString(MaxFooterLength()));
  185. UncheckedSpecifyDataLengths(headerLength, messageLength, footerLength);
  186. }
  187. void AuthenticatedSymmetricCipher::EncryptAndAuthenticate(byte *ciphertext, byte *mac, size_t macSize, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *message, size_t messageLength)
  188. {
  189. Resynchronize(iv, ivLength);
  190. SpecifyDataLengths(headerLength, messageLength);
  191. Update(header, headerLength);
  192. ProcessString(ciphertext, message, messageLength);
  193. TruncatedFinal(mac, macSize);
  194. }
  195. bool AuthenticatedSymmetricCipher::DecryptAndVerify(byte *message, const byte *mac, size_t macLength, const byte *iv, int ivLength, const byte *header, size_t headerLength, const byte *ciphertext, size_t ciphertextLength)
  196. {
  197. Resynchronize(iv, ivLength);
  198. SpecifyDataLengths(headerLength, ciphertextLength);
  199. Update(header, headerLength);
  200. ProcessString(message, ciphertext, ciphertextLength);
  201. return TruncatedVerify(mac, macLength);
  202. }
  203. unsigned int RandomNumberGenerator::GenerateBit()
  204. {
  205. return GenerateByte() & 1;
  206. }
  207. byte RandomNumberGenerator::GenerateByte()
  208. {
  209. byte b;
  210. GenerateBlock(&b, 1);
  211. return b;
  212. }
  213. word32 RandomNumberGenerator::GenerateWord32(word32 min, word32 max)
  214. {
  215. word32 range = max-min;
  216. const int maxBits = BitPrecision(range);
  217. word32 value;
  218. do
  219. {
  220. GenerateBlock((byte *)&value, sizeof(value));
  221. value = Crop(value, maxBits);
  222. } while (value > range);
  223. return value+min;
  224. }
  225. void RandomNumberGenerator::GenerateBlock(byte *output, size_t size)
  226. {
  227. ArraySink s(output, size);
  228. GenerateIntoBufferedTransformation(s, DEFAULT_CHANNEL, size);
  229. }
  230. void RandomNumberGenerator::DiscardBytes(size_t n)
  231. {
  232. GenerateIntoBufferedTransformation(TheBitBucket(), DEFAULT_CHANNEL, n);
  233. }
  234. void RandomNumberGenerator::GenerateIntoBufferedTransformation(BufferedTransformation &target, const std::string &channel, lword length)
  235. {
  236. FixedSizeSecBlock<byte, 256> buffer;
  237. while (length)
  238. {
  239. size_t len = UnsignedMin(buffer.size(), length);
  240. GenerateBlock(buffer, len);
  241. target.ChannelPut(channel, buffer, len);
  242. length -= len;
  243. }
  244. }
  245. //! see NullRNG()
  246. class ClassNullRNG : public RandomNumberGenerator
  247. {
  248. public:
  249. std::string AlgorithmName() const {return "NullRNG";}
  250. void GenerateBlock(byte *output, size_t size) {throw NotImplemented("NullRNG: NullRNG should only be passed to functions that don't need to generate random bytes");}
  251. };
  252. RandomNumberGenerator & NullRNG()
  253. {
  254. static ClassNullRNG s_nullRNG;
  255. return s_nullRNG;
  256. }
  257. bool HashTransformation::TruncatedVerify(const byte *digestIn, size_t digestLength)
  258. {
  259. ThrowIfInvalidTruncatedSize(digestLength);
  260. SecByteBlock digest(digestLength);
  261. TruncatedFinal(digest, digestLength);
  262. return VerifyBufsEqual(digest, digestIn, digestLength);
  263. }
  264. void HashTransformation::ThrowIfInvalidTruncatedSize(size_t size) const
  265. {
  266. if (size > DigestSize())
  267. throw InvalidArgument("HashTransformation: can't truncate a " + IntToString(DigestSize()) + " byte digest to " + IntToString(size) + " bytes");
  268. }
  269. unsigned int BufferedTransformation::GetMaxWaitObjectCount() const
  270. {
  271. const BufferedTransformation *t = AttachedTransformation();
  272. return t ? t->GetMaxWaitObjectCount() : 0;
  273. }
  274. void BufferedTransformation::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  275. {
  276. BufferedTransformation *t = AttachedTransformation();
  277. if (t)
  278. t->GetWaitObjects(container, callStack); // reduce clutter by not adding to stack here
  279. }
  280. void BufferedTransformation::Initialize(const NameValuePairs &parameters, int propagation)
  281. {
  282. assert(!AttachedTransformation());
  283. IsolatedInitialize(parameters);
  284. }
  285. bool BufferedTransformation::Flush(bool hardFlush, int propagation, bool blocking)
  286. {
  287. assert(!AttachedTransformation());
  288. return IsolatedFlush(hardFlush, blocking);
  289. }
  290. bool BufferedTransformation::MessageSeriesEnd(int propagation, bool blocking)
  291. {
  292. assert(!AttachedTransformation());
  293. return IsolatedMessageSeriesEnd(blocking);
  294. }
  295. byte * BufferedTransformation::ChannelCreatePutSpace(const std::string &channel, size_t &size)
  296. {
  297. if (channel.empty())
  298. return CreatePutSpace(size);
  299. else
  300. throw NoChannelSupport(AlgorithmName());
  301. }
  302. size_t BufferedTransformation::ChannelPut2(const std::string &channel, const byte *begin, size_t length, int messageEnd, bool blocking)
  303. {
  304. if (channel.empty())
  305. return Put2(begin, length, messageEnd, blocking);
  306. else
  307. throw NoChannelSupport(AlgorithmName());
  308. }
  309. size_t BufferedTransformation::ChannelPutModifiable2(const std::string &channel, byte *begin, size_t length, int messageEnd, bool blocking)
  310. {
  311. if (channel.empty())
  312. return PutModifiable2(begin, length, messageEnd, blocking);
  313. else
  314. return ChannelPut2(channel, begin, length, messageEnd, blocking);
  315. }
  316. bool BufferedTransformation::ChannelFlush(const std::string &channel, bool completeFlush, int propagation, bool blocking)
  317. {
  318. if (channel.empty())
  319. return Flush(completeFlush, propagation, blocking);
  320. else
  321. throw NoChannelSupport(AlgorithmName());
  322. }
  323. bool BufferedTransformation::ChannelMessageSeriesEnd(const std::string &channel, int propagation, bool blocking)
  324. {
  325. if (channel.empty())
  326. return MessageSeriesEnd(propagation, blocking);
  327. else
  328. throw NoChannelSupport(AlgorithmName());
  329. }
  330. lword BufferedTransformation::MaxRetrievable() const
  331. {
  332. if (AttachedTransformation())
  333. return AttachedTransformation()->MaxRetrievable();
  334. else
  335. return CopyTo(TheBitBucket());
  336. }
  337. bool BufferedTransformation::AnyRetrievable() const
  338. {
  339. if (AttachedTransformation())
  340. return AttachedTransformation()->AnyRetrievable();
  341. else
  342. {
  343. byte b;
  344. return Peek(b) != 0;
  345. }
  346. }
  347. size_t BufferedTransformation::Get(byte &outByte)
  348. {
  349. if (AttachedTransformation())
  350. return AttachedTransformation()->Get(outByte);
  351. else
  352. return Get(&outByte, 1);
  353. }
  354. size_t BufferedTransformation::Get(byte *outString, size_t getMax)
  355. {
  356. if (AttachedTransformation())
  357. return AttachedTransformation()->Get(outString, getMax);
  358. else
  359. {
  360. ArraySink arraySink(outString, getMax);
  361. return (size_t)TransferTo(arraySink, getMax);
  362. }
  363. }
  364. size_t BufferedTransformation::Peek(byte &outByte) const
  365. {
  366. if (AttachedTransformation())
  367. return AttachedTransformation()->Peek(outByte);
  368. else
  369. return Peek(&outByte, 1);
  370. }
  371. size_t BufferedTransformation::Peek(byte *outString, size_t peekMax) const
  372. {
  373. if (AttachedTransformation())
  374. return AttachedTransformation()->Peek(outString, peekMax);
  375. else
  376. {
  377. ArraySink arraySink(outString, peekMax);
  378. return (size_t)CopyTo(arraySink, peekMax);
  379. }
  380. }
  381. lword BufferedTransformation::Skip(lword skipMax)
  382. {
  383. if (AttachedTransformation())
  384. return AttachedTransformation()->Skip(skipMax);
  385. else
  386. return TransferTo(TheBitBucket(), skipMax);
  387. }
  388. lword BufferedTransformation::TotalBytesRetrievable() const
  389. {
  390. if (AttachedTransformation())
  391. return AttachedTransformation()->TotalBytesRetrievable();
  392. else
  393. return MaxRetrievable();
  394. }
  395. unsigned int BufferedTransformation::NumberOfMessages() const
  396. {
  397. if (AttachedTransformation())
  398. return AttachedTransformation()->NumberOfMessages();
  399. else
  400. return CopyMessagesTo(TheBitBucket());
  401. }
  402. bool BufferedTransformation::AnyMessages() const
  403. {
  404. if (AttachedTransformation())
  405. return AttachedTransformation()->AnyMessages();
  406. else
  407. return NumberOfMessages() != 0;
  408. }
  409. bool BufferedTransformation::GetNextMessage()
  410. {
  411. if (AttachedTransformation())
  412. return AttachedTransformation()->GetNextMessage();
  413. else
  414. {
  415. assert(!AnyMessages());
  416. return false;
  417. }
  418. }
  419. unsigned int BufferedTransformation::SkipMessages(unsigned int count)
  420. {
  421. if (AttachedTransformation())
  422. return AttachedTransformation()->SkipMessages(count);
  423. else
  424. return TransferMessagesTo(TheBitBucket(), count);
  425. }
  426. size_t BufferedTransformation::TransferMessagesTo2(BufferedTransformation &target, unsigned int &messageCount, const std::string &channel, bool blocking)
  427. {
  428. if (AttachedTransformation())
  429. return AttachedTransformation()->TransferMessagesTo2(target, messageCount, channel, blocking);
  430. else
  431. {
  432. unsigned int maxMessages = messageCount;
  433. for (messageCount=0; messageCount < maxMessages && AnyMessages(); messageCount++)
  434. {
  435. size_t blockedBytes;
  436. lword transferredBytes;
  437. while (AnyRetrievable())
  438. {
  439. transferredBytes = LWORD_MAX;
  440. blockedBytes = TransferTo2(target, transferredBytes, channel, blocking);
  441. if (blockedBytes > 0)
  442. return blockedBytes;
  443. }
  444. if (target.ChannelMessageEnd(channel, GetAutoSignalPropagation(), blocking))
  445. return 1;
  446. bool result = GetNextMessage();
  447. assert(result);
  448. }
  449. return 0;
  450. }
  451. }
  452. unsigned int BufferedTransformation::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const
  453. {
  454. if (AttachedTransformation())
  455. return AttachedTransformation()->CopyMessagesTo(target, count, channel);
  456. else
  457. return 0;
  458. }
  459. void BufferedTransformation::SkipAll()
  460. {
  461. if (AttachedTransformation())
  462. AttachedTransformation()->SkipAll();
  463. else
  464. {
  465. while (SkipMessages()) {}
  466. while (Skip()) {}
  467. }
  468. }
  469. size_t BufferedTransformation::TransferAllTo2(BufferedTransformation &target, const std::string &channel, bool blocking)
  470. {
  471. if (AttachedTransformation())
  472. return AttachedTransformation()->TransferAllTo2(target, channel, blocking);
  473. else
  474. {
  475. assert(!NumberOfMessageSeries());
  476. unsigned int messageCount;
  477. do
  478. {
  479. messageCount = UINT_MAX;
  480. size_t blockedBytes = TransferMessagesTo2(target, messageCount, channel, blocking);
  481. if (blockedBytes)
  482. return blockedBytes;
  483. }
  484. while (messageCount != 0);
  485. lword byteCount;
  486. do
  487. {
  488. byteCount = ULONG_MAX;
  489. size_t blockedBytes = TransferTo2(target, byteCount, channel, blocking);
  490. if (blockedBytes)
  491. return blockedBytes;
  492. }
  493. while (byteCount != 0);
  494. return 0;
  495. }
  496. }
  497. void BufferedTransformation::CopyAllTo(BufferedTransformation &target, const std::string &channel) const
  498. {
  499. if (AttachedTransformation())
  500. AttachedTransformation()->CopyAllTo(target, channel);
  501. else
  502. {
  503. assert(!NumberOfMessageSeries());
  504. while (CopyMessagesTo(target, UINT_MAX, channel)) {}
  505. }
  506. }
  507. void BufferedTransformation::SetRetrievalChannel(const std::string &channel)
  508. {
  509. if (AttachedTransformation())
  510. AttachedTransformation()->SetRetrievalChannel(channel);
  511. }
  512. size_t BufferedTransformation::ChannelPutWord16(const std::string &channel, word16 value, ByteOrder order, bool blocking)
  513. {
  514. PutWord(false, order, m_buf, value);
  515. return ChannelPut(channel, m_buf, 2, blocking);
  516. }
  517. size_t BufferedTransformation::ChannelPutWord32(const std::string &channel, word32 value, ByteOrder order, bool blocking)
  518. {
  519. PutWord(false, order, m_buf, value);
  520. return ChannelPut(channel, m_buf, 4, blocking);
  521. }
  522. size_t BufferedTransformation::PutWord16(word16 value, ByteOrder order, bool blocking)
  523. {
  524. return ChannelPutWord16(DEFAULT_CHANNEL, value, order, blocking);
  525. }
  526. size_t BufferedTransformation::PutWord32(word32 value, ByteOrder order, bool blocking)
  527. {
  528. return ChannelPutWord32(DEFAULT_CHANNEL, value, order, blocking);
  529. }
  530. size_t BufferedTransformation::PeekWord16(word16 &value, ByteOrder order) const
  531. {
  532. byte buf[2] = {0, 0};
  533. size_t len = Peek(buf, 2);
  534. if (order)
  535. value = (buf[0] << 8) | buf[1];
  536. else
  537. value = (buf[1] << 8) | buf[0];
  538. return len;
  539. }
  540. size_t BufferedTransformation::PeekWord32(word32 &value, ByteOrder order) const
  541. {
  542. byte buf[4] = {0, 0, 0, 0};
  543. size_t len = Peek(buf, 4);
  544. if (order)
  545. value = (buf[0] << 24) | (buf[1] << 16) | (buf[2] << 8) | buf [3];
  546. else
  547. value = (buf[3] << 24) | (buf[2] << 16) | (buf[1] << 8) | buf [0];
  548. return len;
  549. }
  550. size_t BufferedTransformation::GetWord16(word16 &value, ByteOrder order)
  551. {
  552. return (size_t)Skip(PeekWord16(value, order));
  553. }
  554. size_t BufferedTransformation::GetWord32(word32 &value, ByteOrder order)
  555. {
  556. return (size_t)Skip(PeekWord32(value, order));
  557. }
  558. void BufferedTransformation::Attach(BufferedTransformation *newOut)
  559. {
  560. if (AttachedTransformation() && AttachedTransformation()->Attachable())
  561. AttachedTransformation()->Attach(newOut);
  562. else
  563. Detach(newOut);
  564. }
  565. void GeneratableCryptoMaterial::GenerateRandomWithKeySize(RandomNumberGenerator &rng, unsigned int keySize)
  566. {
  567. GenerateRandom(rng, MakeParameters("KeySize", (int)keySize));
  568. }
  // Filter returned by PK_Encryptor::CreateEncryptionFilter().
  // It queues all input in m_plaintextQueue. When a message end is signalled, it
  // encrypts the whole queued plaintext as a single message and outputs the ciphertext.
  // NOTE(review): m_rng, m_encryptor and m_parameters are stored by reference, so the
  // objects passed to the constructor must outlive this filter.
  569. class PK_DefaultEncryptionFilter : public Unflushable<Filter>
  570. {
  571. public:
  // 'attachment' is handed to Detach(); presumably it becomes this filter's output — confirm.
  572. PK_DefaultEncryptionFilter(RandomNumberGenerator &rng, const PK_Encryptor &encryptor, BufferedTransformation *attachment, const NameValuePairs &parameters)
  573. : m_rng(rng), m_encryptor(encryptor), m_parameters(parameters)
  574. {
  575. Detach(attachment);
  576. }
  // NOTE(review): FILTER_BEGIN / FILTER_OUTPUT / FILTER_END_NO_MESSAGE_END (fltrimpl.h)
  // presumably implement a resumable switch, so a blocked output is retried at site 1 on
  // the next call. The inner braces keep the locals out of scope at that resume point,
  // and the ciphertext lives in a member so it survives until it is output.
  577. size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
  578. {
  579. FILTER_BEGIN;
  580. m_plaintextQueue.Put(inString, length);
  581. if (messageEnd)
  582. {
  583. {
  // The queue size is an lword; SafeConvert fails if it does not fit in a size_t.
  584. size_t plaintextLength;
  585. if (!SafeConvert(m_plaintextQueue.CurrentSize(), plaintextLength))
  586. throw InvalidArgument("PK_DefaultEncryptionFilter: plaintext too long");
  587. size_t ciphertextLength = m_encryptor.CiphertextLength(plaintextLength);
  // Drain the queue into a secure buffer and encrypt it in one call.
  588. SecByteBlock plaintext(plaintextLength);
  589. m_plaintextQueue.Get(plaintext, plaintextLength);
  590. m_ciphertext.resize(ciphertextLength);
  591. m_encryptor.Encrypt(m_rng, plaintext, plaintextLength, m_ciphertext, m_parameters);
  592. }
  593. FILTER_OUTPUT(1, m_ciphertext, m_ciphertext.size(), messageEnd);
  594. }
  595. FILTER_END_NO_MESSAGE_END;
  596. }
  597. RandomNumberGenerator &m_rng;
  598. const PK_Encryptor &m_encryptor;
  599. const NameValuePairs &m_parameters;
  600. ByteQueue m_plaintextQueue;
  601. SecByteBlock m_ciphertext;
  602. };
  603. BufferedTransformation * PK_Encryptor::CreateEncryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs &parameters) const
  604. {
  605. return new PK_DefaultEncryptionFilter(rng, *this, attachment, parameters);
  606. }
  // Filter returned by PK_Decryptor::CreateDecryptionFilter().
  // It queues all input in m_ciphertextQueue and decrypts the whole message when a
  // message end is signalled. It throws InvalidCiphertext if the decryptor reports
  // invalid coding; otherwise it outputs m_result.messageLength plaintext bytes.
  // NOTE(review): m_rng, m_decryptor and m_parameters are stored by reference, so the
  // objects passed to the constructor must outlive this filter.
  607. class PK_DefaultDecryptionFilter : public Unflushable<Filter>
  608. {
  609. public:
  // 'attachment' is handed to Detach(); presumably it becomes this filter's output — confirm.
  610. PK_DefaultDecryptionFilter(RandomNumberGenerator &rng, const PK_Decryptor &decryptor, BufferedTransformation *attachment, const NameValuePairs &parameters)
  611. : m_rng(rng), m_decryptor(decryptor), m_parameters(parameters)
  612. {
  613. Detach(attachment);
  614. }
  // NOTE(review): same FILTER_* macro structure as PK_DefaultEncryptionFilter::Put2. A blocked
  // output is presumably retried at site 1, which is why the plaintext and the
  // DecodingResult (its messageLength) are members rather than locals.
  615. size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
  616. {
  617. FILTER_BEGIN;
  618. m_ciphertextQueue.Put(inString, length);
  619. if (messageEnd)
  620. {
  621. {
  // The queue size is an lword; SafeConvert fails if it does not fit in a size_t.
  622. size_t ciphertextLength;
  623. if (!SafeConvert(m_ciphertextQueue.CurrentSize(), ciphertextLength))
  624. throw InvalidArgument("PK_DefaultDecryptionFilter: ciphertext too long");
  // The output buffer is sized for the largest possible plaintext; only
  // m_result.messageLength bytes of it are output.
  625. size_t maxPlaintextLength = m_decryptor.MaxPlaintextLength(ciphertextLength);
  626. SecByteBlock ciphertext(ciphertextLength);
  627. m_ciphertextQueue.Get(ciphertext, ciphertextLength);
  628. m_plaintext.resize(maxPlaintextLength);
  629. m_result = m_decryptor.Decrypt(m_rng, ciphertext, ciphertextLength, m_plaintext, m_parameters);
  630. if (!m_result.isValidCoding)
  631. throw InvalidCiphertext(m_decryptor.AlgorithmName() + ": invalid ciphertext");
  632. }
  633. FILTER_OUTPUT(1, m_plaintext, m_result.messageLength, messageEnd);
  634. }
  635. FILTER_END_NO_MESSAGE_END;
  636. }
  637. RandomNumberGenerator &m_rng;
  638. const PK_Decryptor &m_decryptor;
  639. const NameValuePairs &m_parameters;
  640. ByteQueue m_ciphertextQueue;
  641. SecByteBlock m_plaintext;
  642. DecodingResult m_result;
  643. };
  644. BufferedTransformation * PK_Decryptor::CreateDecryptionFilter(RandomNumberGenerator &rng, BufferedTransformation *attachment, const NameValuePairs &parameters) const
  645. {
  646. return new PK_DefaultDecryptionFilter(rng, *this, attachment, parameters);
  647. }
  648. size_t PK_Signer::Sign(RandomNumberGenerator &rng, PK_MessageAccumulator *messageAccumulator, byte *signature) const
  649. {
  650. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  651. return SignAndRestart(rng, *m, signature, false);
  652. }
  653. size_t PK_Signer::SignMessage(RandomNumberGenerator &rng, const byte *message, size_t messageLen, byte *signature) const
  654. {
  655. std::auto_ptr<PK_MessageAccumulator> m(NewSignatureAccumulator(rng));
  656. m->Update(message, messageLen);
  657. return SignAndRestart(rng, *m, signature, false);
  658. }
  659. size_t PK_Signer::SignMessageWithRecovery(RandomNumberGenerator &rng, const byte *recoverableMessage, size_t recoverableMessageLength,
  660. const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength, byte *signature) const
  661. {
  662. std::auto_ptr<PK_MessageAccumulator> m(NewSignatureAccumulator(rng));
  663. InputRecoverableMessage(*m, recoverableMessage, recoverableMessageLength);
  664. m->Update(nonrecoverableMessage, nonrecoverableMessageLength);
  665. return SignAndRestart(rng, *m, signature, false);
  666. }
  667. bool PK_Verifier::Verify(PK_MessageAccumulator *messageAccumulator) const
  668. {
  669. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  670. return VerifyAndRestart(*m);
  671. }
  672. bool PK_Verifier::VerifyMessage(const byte *message, size_t messageLen, const byte *signature, size_t signatureLength) const
  673. {
  674. std::auto_ptr<PK_MessageAccumulator> m(NewVerificationAccumulator());
  675. InputSignature(*m, signature, signatureLength);
  676. m->Update(message, messageLen);
  677. return VerifyAndRestart(*m);
  678. }
  679. DecodingResult PK_Verifier::Recover(byte *recoveredMessage, PK_MessageAccumulator *messageAccumulator) const
  680. {
  681. std::auto_ptr<PK_MessageAccumulator> m(messageAccumulator);
  682. return RecoverAndRestart(recoveredMessage, *m);
  683. }
  684. DecodingResult PK_Verifier::RecoverMessage(byte *recoveredMessage,
  685. const byte *nonrecoverableMessage, size_t nonrecoverableMessageLength,
  686. const byte *signature, size_t signatureLength) const
  687. {
  688. std::auto_ptr<PK_MessageAccumulator> m(NewVerificationAccumulator());
  689. InputSignature(*m, signature, signatureLength);
  690. m->Update(nonrecoverableMessage, nonrecoverableMessageLength);
  691. return RecoverAndRestart(recoveredMessage, *m);
  692. }
  693. void SimpleKeyAgreementDomain::GenerateKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  694. {
  695. GeneratePrivateKey(rng, privateKey);
  696. GeneratePublicKey(rng, privateKey, publicKey);
  697. }
  698. void AuthenticatedKeyAgreementDomain::GenerateStaticKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  699. {
  700. GenerateStaticPrivateKey(rng, privateKey);
  701. GenerateStaticPublicKey(rng, privateKey, publicKey);
  702. }
  703. void AuthenticatedKeyAgreementDomain::GenerateEphemeralKeyPair(RandomNumberGenerator &rng, byte *privateKey, byte *publicKey) const
  704. {
  705. GenerateEphemeralPrivateKey(rng, privateKey);
  706. GenerateEphemeralPublicKey(rng, privateKey, publicKey);
  707. }
  708. NAMESPACE_END
  709. #endif