Team Fortress 2 Source Code as of 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

813 lines
23 KiB

  1. // datatest.cpp - written and placed in the public domain by Wei Dai
  2. #define CRYPTOPP_DEFAULT_NO_DLL
  3. #define CRYPTOPP_ENABLE_NAMESPACE_WEAK 1
  4. #include "cryptlib.h"
  5. #include "factory.h"
  6. #include "integer.h"
  7. #include "filters.h"
  8. #include "hex.h"
  9. #include "randpool.h"
  10. #include "files.h"
  11. #include "trunhash.h"
  12. #include "queue.h"
  13. #include "smartptr.h"
  14. #include "validate.h"
  15. #include "hkdf.h"
  16. #include "stdcpp.h"
  17. #include <iostream>
  18. // Aggressive stack checking with VS2005 SP1 and above.
  19. #if (CRYPTOPP_MSC_VERSION >= 1410)
  20. # pragma strict_gs_check (on)
  21. #endif
  22. #if defined(__COVERITY__)
  23. extern "C" void __coverity_tainted_data_sanitize__(void *);
  24. #endif
  25. USING_NAMESPACE(CryptoPP)
  26. USING_NAMESPACE(std)
  27. typedef std::map<std::string, std::string> TestData;
  28. static bool s_thorough = false;
  29. class TestFailure : public Exception
  30. {
  31. public:
  32. TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
  33. };
  34. static const TestData *s_currentTestData = NULL;
  35. static void OutputTestData(const TestData &v)
  36. {
  37. for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
  38. {
  39. cerr << i->first << ": " << i->second << endl;
  40. }
  41. }
  42. static void SignalTestFailure()
  43. {
  44. OutputTestData(*s_currentTestData);
  45. throw TestFailure();
  46. }
  47. static void SignalTestError()
  48. {
  49. OutputTestData(*s_currentTestData);
  50. throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
  51. }
  52. bool DataExists(const TestData &data, const char *name)
  53. {
  54. TestData::const_iterator i = data.find(name);
  55. return (i != data.end());
  56. }
  57. const std::string & GetRequiredDatum(const TestData &data, const char *name)
  58. {
  59. TestData::const_iterator i = data.find(name);
  60. if (i == data.end())
  61. SignalTestError();
  62. return i->second;
  63. }
  64. void RandomizedTransfer(BufferedTransformation &source, BufferedTransformation &target, bool finish, const std::string &channel=DEFAULT_CHANNEL)
  65. {
  66. while (source.MaxRetrievable() > (finish ? 0 : 4096))
  67. {
  68. byte buf[4096+64];
  69. size_t start = GlobalRNG().GenerateWord32(0, 63);
  70. size_t len = GlobalRNG().GenerateWord32(1, UnsignedMin(4096U, 3*source.MaxRetrievable()/2));
  71. len = source.Get(buf+start, len);
  72. target.ChannelPut(channel, buf+start, len);
  73. }
  74. }
  75. void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
  76. {
  77. std::string s1 = GetRequiredDatum(data, name), s2;
  78. ByteQueue q;
  79. while (!s1.empty())
  80. {
  81. while (s1[0] == ' ')
  82. {
  83. s1 = s1.substr(1);
  84. if (s1.empty())
  85. goto end; // avoid invalid read if s1 is empty
  86. }
  87. int repeat = 1;
  88. if (s1[0] == 'r')
  89. {
  90. repeat = atoi(s1.c_str()+1);
  91. s1 = s1.substr(s1.find(' ')+1);
  92. }
  93. s2 = ""; // MSVC 6 doesn't have clear();
  94. if (s1[0] == '\"')
  95. {
  96. s2 = s1.substr(1, s1.find('\"', 1)-1);
  97. s1 = s1.substr(s2.length() + 2);
  98. }
  99. else if (s1.substr(0, 2) == "0x")
  100. {
  101. StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
  102. s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
  103. }
  104. else
  105. {
  106. StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
  107. s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
  108. }
  109. while (repeat--)
  110. {
  111. q.Put((const byte *)s2.data(), s2.size());
  112. RandomizedTransfer(q, target, false);
  113. }
  114. }
  115. end:
  116. RandomizedTransfer(q, target, true);
  117. }
  118. std::string GetDecodedDatum(const TestData &data, const char *name)
  119. {
  120. std::string s;
  121. PutDecodedDatumInto(data, name, StringSink(s).Ref());
  122. return s;
  123. }
  124. std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
  125. {
  126. std::string s;
  127. if (DataExists(data, name))
  128. PutDecodedDatumInto(data, name, StringSink(s).Ref());
  129. return s;
  130. }
  131. class TestDataNameValuePairs : public NameValuePairs
  132. {
  133. public:
  134. TestDataNameValuePairs(const TestData &data) : m_data(data) {}
  135. virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
  136. {
  137. TestData::const_iterator i = m_data.find(name);
  138. if (i == m_data.end())
  139. {
  140. if (std::string(name) == Name::DigestSize() && valueType == typeid(int))
  141. {
  142. i = m_data.find("MAC");
  143. if (i == m_data.end())
  144. i = m_data.find("Digest");
  145. if (i == m_data.end())
  146. return false;
  147. m_temp.resize(0);
  148. PutDecodedDatumInto(m_data, i->first.c_str(), StringSink(m_temp).Ref());
  149. *reinterpret_cast<int *>(pValue) = (int)m_temp.size();
  150. return true;
  151. }
  152. else
  153. return false;
  154. }
  155. const std::string &value = i->second;
  156. if (valueType == typeid(int))
  157. *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
  158. else if (valueType == typeid(Integer))
  159. *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
  160. else if (valueType == typeid(ConstByteArrayParameter))
  161. {
  162. m_temp.resize(0);
  163. PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
  164. reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), false);
  165. }
  166. else
  167. throw ValueTypeMismatch(name, typeid(std::string), valueType);
  168. return true;
  169. }
  170. private:
  171. const TestData &m_data;
  172. mutable std::string m_temp;
  173. };
  174. void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
  175. {
  176. // "!!" converts between bool <-> integral.
  177. if (!pub.Validate(GlobalRNG(), 2U+!!s_thorough))
  178. SignalTestFailure();
  179. if (!priv.Validate(GlobalRNG(), 2U+!!s_thorough))
  180. SignalTestFailure();
  181. ByteQueue bq1, bq2;
  182. pub.Save(bq1);
  183. pub.AssignFrom(priv);
  184. pub.Save(bq2);
  185. if (bq1 != bq2)
  186. SignalTestFailure();
  187. }
  188. void TestSignatureScheme(TestData &v)
  189. {
  190. std::string name = GetRequiredDatum(v, "Name");
  191. std::string test = GetRequiredDatum(v, "Test");
  192. member_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
  193. member_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
  194. TestDataNameValuePairs pairs(v);
  195. if (test == "GenerateKey")
  196. {
  197. signer->AccessPrivateKey().GenerateRandom(GlobalRNG(), pairs);
  198. verifier->AccessPublicKey().AssignFrom(signer->AccessPrivateKey());
  199. }
  200. else
  201. {
  202. std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
  203. if (keyFormat == "DER")
  204. verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
  205. else if (keyFormat == "Component")
  206. verifier->AccessMaterial().AssignFrom(pairs);
  207. if (test == "Verify" || test == "NotVerify")
  208. {
  209. VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
  210. PutDecodedDatumInto(v, "Signature", verifierFilter);
  211. PutDecodedDatumInto(v, "Message", verifierFilter);
  212. verifierFilter.MessageEnd();
  213. if (verifierFilter.GetLastResult() == (test == "NotVerify"))
  214. SignalTestFailure();
  215. return;
  216. }
  217. else if (test == "PublicKeyValid")
  218. {
  219. if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
  220. SignalTestFailure();
  221. return;
  222. }
  223. if (keyFormat == "DER")
  224. signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
  225. else if (keyFormat == "Component")
  226. signer->AccessMaterial().AssignFrom(pairs);
  227. }
  228. if (test == "GenerateKey" || test == "KeyPairValidAndConsistent")
  229. {
  230. TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
  231. VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::THROW_EXCEPTION);
  232. verifierFilter.Put((const byte *)"abc", 3);
  233. StringSource ss("abc", true, new SignerFilter(GlobalRNG(), *signer, new Redirector(verifierFilter)));
  234. }
  235. else if (test == "Sign")
  236. {
  237. SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
  238. StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
  239. SignalTestFailure();
  240. }
  241. else if (test == "DeterministicSign")
  242. {
  243. SignalTestError();
  244. assert(false); // TODO: implement
  245. }
  246. else if (test == "RandomSign")
  247. {
  248. SignalTestError();
  249. assert(false); // TODO: implement
  250. }
  251. else
  252. {
  253. SignalTestError();
  254. assert(false);
  255. }
  256. }
  257. void TestAsymmetricCipher(TestData &v)
  258. {
  259. std::string name = GetRequiredDatum(v, "Name");
  260. std::string test = GetRequiredDatum(v, "Test");
  261. member_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
  262. member_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
  263. std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
  264. if (keyFormat == "DER")
  265. {
  266. decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
  267. encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
  268. }
  269. else if (keyFormat == "Component")
  270. {
  271. TestDataNameValuePairs pairs(v);
  272. decryptor->AccessMaterial().AssignFrom(pairs);
  273. encryptor->AccessMaterial().AssignFrom(pairs);
  274. }
  275. if (test == "DecryptMatch")
  276. {
  277. std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
  278. StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
  279. if (decrypted != expected)
  280. SignalTestFailure();
  281. }
  282. else if (test == "KeyPairValidAndConsistent")
  283. {
  284. TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
  285. }
  286. else
  287. {
  288. SignalTestError();
  289. assert(false);
  290. }
  291. }
  292. void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
  293. {
  294. std::string name = GetRequiredDatum(v, "Name");
  295. std::string test = GetRequiredDatum(v, "Test");
  296. std::string key = GetDecodedDatum(v, "Key");
  297. std::string plaintext = GetDecodedDatum(v, "Plaintext");
  298. TestDataNameValuePairs testDataPairs(v);
  299. CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
  300. if (test == "Encrypt" || test == "EncryptXorDigest" || test == "Resync" || test == "EncryptionMCT" || test == "DecryptionMCT")
  301. {
  302. static member_ptr<SymmetricCipher> encryptor, decryptor;
  303. static std::string lastName;
  304. if (name != lastName)
  305. {
  306. encryptor.reset(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
  307. decryptor.reset(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
  308. lastName = name;
  309. }
  310. ConstByteArrayParameter iv;
  311. if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
  312. SignalTestFailure();
  313. if (test == "Resync")
  314. {
  315. encryptor->Resynchronize(iv.begin(), (int)iv.size());
  316. decryptor->Resynchronize(iv.begin(), (int)iv.size());
  317. }
  318. else
  319. {
  320. encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
  321. decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
  322. }
  323. int seek = pairs.GetIntValueWithDefault("Seek", 0);
  324. if (seek)
  325. {
  326. encryptor->Seek(seek);
  327. decryptor->Seek(seek);
  328. }
  329. std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
  330. if (test == "EncryptionMCT" || test == "DecryptionMCT")
  331. {
  332. SymmetricCipher *cipher = encryptor.get();
  333. SecByteBlock buf((byte *)plaintext.data(), plaintext.size()), keybuf((byte *)key.data(), key.size());
  334. if (test == "DecryptionMCT")
  335. {
  336. cipher = decryptor.get();
  337. ciphertext = GetDecodedDatum(v, "Ciphertext");
  338. buf.Assign((byte *)ciphertext.data(), ciphertext.size());
  339. }
  340. for (int i=0; i<400; i++)
  341. {
  342. encrypted.reserve(10000 * plaintext.size());
  343. for (int j=0; j<10000; j++)
  344. {
  345. cipher->ProcessString(buf.begin(), buf.size());
  346. encrypted.append((char *)buf.begin(), buf.size());
  347. }
  348. encrypted.erase(0, encrypted.size() - keybuf.size());
  349. xorbuf(keybuf.begin(), (const byte *)encrypted.data(), keybuf.size());
  350. cipher->SetKey(keybuf, keybuf.size());
  351. }
  352. encrypted.assign((char *)buf.begin(), buf.size());
  353. ciphertext = GetDecodedDatum(v, test == "EncryptionMCT" ? "Ciphertext" : "Plaintext");
  354. if (encrypted != ciphertext)
  355. {
  356. std::cout << "incorrectly encrypted: ";
  357. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  358. xx.Pump(256); xx.Flush(false);
  359. std::cout << "\n";
  360. SignalTestFailure();
  361. }
  362. return;
  363. }
  364. StreamTransformationFilter encFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING);
  365. RandomizedTransfer(StringStore(plaintext).Ref(), encFilter, true);
  366. encFilter.MessageEnd();
  367. /*{
  368. std::string z;
  369. encryptor->Seek(seek);
  370. StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(z), StreamTransformationFilter::NO_PADDING));
  371. while (ss.Pump(64)) {}
  372. ss.PumpAll();
  373. for (int i=0; i<z.length(); i++)
  374. assert(encrypted[i] == z[i]);
  375. }*/
  376. if (test != "EncryptXorDigest")
  377. ciphertext = GetDecodedDatum(v, "Ciphertext");
  378. else
  379. {
  380. ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest");
  381. xorDigest.append(encrypted, 0, 64);
  382. for (size_t i=64; i<encrypted.size(); i++)
  383. xorDigest[i%64] ^= encrypted[i];
  384. }
  385. if (test != "EncryptXorDigest" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
  386. {
  387. std::cout << "incorrectly encrypted: ";
  388. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  389. xx.Pump(2048); xx.Flush(false);
  390. std::cout << "\n";
  391. SignalTestFailure();
  392. }
  393. std::string decrypted;
  394. StreamTransformationFilter decFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING);
  395. RandomizedTransfer(StringStore(encrypted).Ref(), decFilter, true);
  396. decFilter.MessageEnd();
  397. if (decrypted != plaintext)
  398. {
  399. std::cout << "incorrectly decrypted: ";
  400. StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
  401. xx.Pump(256); xx.Flush(false);
  402. std::cout << "\n";
  403. SignalTestFailure();
  404. }
  405. }
  406. else
  407. {
  408. std::cout << "unexpected test name\n";
  409. SignalTestError();
  410. }
  411. }
  412. void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
  413. {
  414. std::string type = GetRequiredDatum(v, "AlgorithmType");
  415. std::string name = GetRequiredDatum(v, "Name");
  416. std::string test = GetRequiredDatum(v, "Test");
  417. std::string key = GetDecodedDatum(v, "Key");
  418. std::string plaintext = GetOptionalDecodedDatum(v, "Plaintext");
  419. std::string ciphertext = GetOptionalDecodedDatum(v, "Ciphertext");
  420. std::string header = GetOptionalDecodedDatum(v, "Header");
  421. std::string footer = GetOptionalDecodedDatum(v, "Footer");
  422. std::string mac = GetOptionalDecodedDatum(v, "MAC");
  423. TestDataNameValuePairs testDataPairs(v);
  424. CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
  425. if (test == "Encrypt" || test == "EncryptXorDigest" || test == "NotVerify")
  426. {
  427. member_ptr<AuthenticatedSymmetricCipher> asc1, asc2;
  428. asc1.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
  429. asc2.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
  430. asc1->SetKey((const byte *)key.data(), key.size(), pairs);
  431. asc2->SetKey((const byte *)key.data(), key.size(), pairs);
  432. std::string encrypted, decrypted;
  433. AuthenticatedEncryptionFilter ef(*asc1, new StringSink(encrypted));
  434. bool macAtBegin = !mac.empty() && !GlobalRNG().GenerateBit(); // test both ways randomly
  435. AuthenticatedDecryptionFilter df(*asc2, new StringSink(decrypted), macAtBegin ? AuthenticatedDecryptionFilter::MAC_AT_BEGIN : 0);
  436. if (asc1->NeedsPrespecifiedDataLengths())
  437. {
  438. asc1->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
  439. asc2->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
  440. }
  441. StringStore sh(header), sp(plaintext), sc(ciphertext), sf(footer), sm(mac);
  442. if (macAtBegin)
  443. RandomizedTransfer(sm, df, true);
  444. sh.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
  445. RandomizedTransfer(sc, df, true);
  446. sf.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
  447. if (!macAtBegin)
  448. RandomizedTransfer(sm, df, true);
  449. df.MessageEnd();
  450. RandomizedTransfer(sh, ef, true, AAD_CHANNEL);
  451. RandomizedTransfer(sp, ef, true);
  452. RandomizedTransfer(sf, ef, true, AAD_CHANNEL);
  453. ef.MessageEnd();
  454. if (test == "Encrypt" && encrypted != ciphertext+mac)
  455. {
  456. std::cout << "incorrectly encrypted: ";
  457. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  458. xx.Pump(2048); xx.Flush(false);
  459. std::cout << "\n";
  460. SignalTestFailure();
  461. }
  462. if (test == "Encrypt" && decrypted != plaintext)
  463. {
  464. std::cout << "incorrectly decrypted: ";
  465. StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
  466. xx.Pump(256); xx.Flush(false);
  467. std::cout << "\n";
  468. SignalTestFailure();
  469. }
  470. if (ciphertext.size()+mac.size()-plaintext.size() != asc1->DigestSize())
  471. {
  472. std::cout << "bad MAC size\n";
  473. SignalTestFailure();
  474. }
  475. if (df.GetLastResult() != (test == "Encrypt"))
  476. {
  477. std::cout << "MAC incorrectly verified\n";
  478. SignalTestFailure();
  479. }
  480. }
  481. else
  482. {
  483. std::cout << "unexpected test name\n";
  484. SignalTestError();
  485. }
  486. }
  487. void TestDigestOrMAC(TestData &v, bool testDigest)
  488. {
  489. std::string name = GetRequiredDatum(v, "Name");
  490. std::string test = GetRequiredDatum(v, "Test");
  491. const char *digestName = testDigest ? "Digest" : "MAC";
  492. member_ptr<MessageAuthenticationCode> mac;
  493. member_ptr<HashTransformation> hash;
  494. HashTransformation *pHash = NULL;
  495. TestDataNameValuePairs pairs(v);
  496. if (testDigest)
  497. {
  498. hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
  499. pHash = hash.get();
  500. }
  501. else
  502. {
  503. mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
  504. pHash = mac.get();
  505. std::string key = GetDecodedDatum(v, "Key");
  506. mac->SetKey((const byte *)key.c_str(), key.size(), pairs);
  507. }
  508. if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
  509. {
  510. int digestSize = -1;
  511. if (test == "VerifyTruncated")
  512. digestSize = pairs.GetIntValueWithDefault(Name::DigestSize(), digestSize);
  513. HashVerificationFilter verifierFilter(*pHash, NULL, HashVerificationFilter::HASH_AT_BEGIN, digestSize);
  514. PutDecodedDatumInto(v, digestName, verifierFilter);
  515. PutDecodedDatumInto(v, "Message", verifierFilter);
  516. verifierFilter.MessageEnd();
  517. if (verifierFilter.GetLastResult() == (test == "NotVerify"))
  518. SignalTestFailure();
  519. }
  520. else
  521. {
  522. SignalTestError();
  523. assert(false);
  524. }
  525. }
  526. void TestKeyDerivationFunction(TestData &v)
  527. {
  528. std::string name = GetRequiredDatum(v, "Name");
  529. std::string test = GetRequiredDatum(v, "Test");
  530. if(test == "Skip") return;
  531. assert(test == "Verify");
  532. std::string key = GetDecodedDatum(v, "Key");
  533. std::string salt = GetDecodedDatum(v, "Salt");
  534. std::string info = GetDecodedDatum(v, "Info");
  535. std::string derived = GetDecodedDatum(v, "DerivedKey");
  536. std::string t = GetDecodedDatum(v, "DerivedKeyLength");
  537. TestDataNameValuePairs pairs(v);
  538. unsigned int length = pairs.GetIntValueWithDefault(Name::DerivedKeyLength(), (int)derived.size());
  539. member_ptr<KeyDerivationFunction> kdf;
  540. kdf.reset(ObjectFactoryRegistry<KeyDerivationFunction>::Registry().CreateObject(name.c_str()));
  541. std::string calc; calc.resize(length);
  542. unsigned int ret = kdf->DeriveKey(reinterpret_cast<byte*>(&calc[0]), calc.size(),
  543. reinterpret_cast<const byte*>(key.data()), key.size(),
  544. reinterpret_cast<const byte*>(salt.data()), salt.size(),
  545. reinterpret_cast<const byte*>(info.data()), info.size());
  546. if(calc != derived || ret != length)
  547. SignalTestFailure();
  548. }
  549. bool GetField(std::istream &is, std::string &name, std::string &value)
  550. {
  551. name.resize(0); // GCC workaround: 2.95.3 doesn't have clear()
  552. is >> name;
  553. #if defined(__COVERITY__)
  554. // The datafile being read is in /usr/share, and it protected by filesystem ACLs
  555. // __coverity_tainted_data_sanitize__(reinterpret_cast<void*>(&name));
  556. #endif
  557. if (name.empty())
  558. return false;
  559. if (name[name.size()-1] != ':')
  560. {
  561. char c;
  562. is >> skipws >> c;
  563. if (c != ':')
  564. SignalTestError();
  565. }
  566. else
  567. name.erase(name.size()-1);
  568. while (is.peek() == ' ')
  569. is.ignore(1);
  570. // VC60 workaround: getline bug
  571. char buffer[128];
  572. value.resize(0); // GCC workaround: 2.95.3 doesn't have clear()
  573. bool continueLine;
  574. do
  575. {
  576. do
  577. {
  578. is.get(buffer, sizeof(buffer));
  579. value += buffer;
  580. }
  581. while (buffer[0] != 0);
  582. is.clear();
  583. is.ignore();
  584. if (!value.empty() && value[value.size()-1] == '\r')
  585. value.resize(value.size()-1);
  586. if (!value.empty() && value[value.size()-1] == '\\')
  587. {
  588. value.resize(value.size()-1);
  589. continueLine = true;
  590. }
  591. else
  592. continueLine = false;
  593. std::string::size_type i = value.find('#');
  594. if (i != std::string::npos)
  595. value.erase(i);
  596. }
  597. while (continueLine);
  598. return true;
  599. }
  600. void OutputPair(const NameValuePairs &v, const char *name)
  601. {
  602. Integer x;
  603. bool b = v.GetValue(name, x);
  604. CRYPTOPP_UNUSED(b); assert(b);
  605. cout << name << ": \\\n ";
  606. x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize());
  607. cout << endl;
  608. }
  609. void OutputNameValuePairs(const NameValuePairs &v)
  610. {
  611. std::string names = v.GetValueNames();
  612. string::size_type i = 0;
  613. while (i < names.size())
  614. {
  615. string::size_type j = names.find_first_of (';', i);
  616. if (j == string::npos)
  617. return;
  618. else
  619. {
  620. std::string name = names.substr(i, j-i);
  621. if (name.find(':') == string::npos)
  622. OutputPair(v, name.c_str());
  623. }
  624. i = j + 1;
  625. }
  626. }
  627. void TestDataFile(const std::string &filename, const NameValuePairs &overrideParameters, unsigned int &totalTests, unsigned int &failedTests)
  628. {
  629. std::ifstream file(filename.c_str());
  630. if (!file.good())
  631. throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
  632. TestData v;
  633. s_currentTestData = &v;
  634. std::string name, value, lastAlgName;
  635. while (file)
  636. {
  637. while (file.peek() == '#')
  638. file.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
  639. if (file.peek() == '\n' || file.peek() == '\r')
  640. v.clear();
  641. if (!GetField(file, name, value))
  642. break;
  643. v[name] = value;
  644. if (name == "Test" && (s_thorough || v["SlowTest"] != "1"))
  645. {
  646. bool failed = true;
  647. std::string algType = GetRequiredDatum(v, "AlgorithmType");
  648. if (lastAlgName != GetRequiredDatum(v, "Name"))
  649. {
  650. lastAlgName = GetRequiredDatum(v, "Name");
  651. cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
  652. }
  653. try
  654. {
  655. if (algType == "Signature")
  656. TestSignatureScheme(v);
  657. else if (algType == "SymmetricCipher")
  658. TestSymmetricCipher(v, overrideParameters);
  659. else if (algType == "AuthenticatedSymmetricCipher")
  660. TestAuthenticatedSymmetricCipher(v, overrideParameters);
  661. else if (algType == "AsymmetricCipher")
  662. TestAsymmetricCipher(v);
  663. else if (algType == "MessageDigest")
  664. TestDigestOrMAC(v, true);
  665. else if (algType == "MAC")
  666. TestDigestOrMAC(v, false);
  667. else if (algType == "KDF")
  668. TestKeyDerivationFunction(v);
  669. else if (algType == "FileList")
  670. TestDataFile(GetRequiredDatum(v, "Test"), g_nullNameValuePairs, totalTests, failedTests);
  671. else
  672. SignalTestError();
  673. failed = false;
  674. }
  675. catch (TestFailure &)
  676. {
  677. cout << "\nTest failed.\n";
  678. }
  679. catch (CryptoPP::Exception &e)
  680. {
  681. cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
  682. }
  683. catch (std::exception &e)
  684. {
  685. cout << "\nstd::exception caught: " << e.what() << endl;
  686. }
  687. if (failed)
  688. {
  689. cout << "Skipping to next test.\n";
  690. failedTests++;
  691. }
  692. else
  693. cout << "." << flush;
  694. totalTests++;
  695. }
  696. }
  697. }
  698. bool RunTestDataFile(const char *filename, const NameValuePairs &overrideParameters, bool thorough)
  699. {
  700. s_thorough = thorough;
  701. unsigned int totalTests = 0, failedTests = 0;
  702. TestDataFile(filename, overrideParameters, totalTests, failedTests);
  703. cout << dec << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
  704. if (failedTests != 0)
  705. cout << "SOME TESTS FAILED!\n";
  706. return failedTests == 0;
  707. }