Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

752 lines
21 KiB

  1. #include "factory.h"
  2. #include "integer.h"
  3. #include "filters.h"
  4. #include "hex.h"
  5. #include "randpool.h"
  6. #include "files.h"
  7. #include "trunhash.h"
  8. #include "queue.h"
  9. #include "validate.h"
  10. #include <iostream>
  11. #include <memory>
  12. USING_NAMESPACE(CryptoPP)
  13. USING_NAMESPACE(std)
  14. typedef std::map<std::string, std::string> TestData;
  15. class TestFailure : public Exception
  16. {
  17. public:
  18. TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
  19. };
  20. static const TestData *s_currentTestData = NULL;
  21. static void OutputTestData(const TestData &v)
  22. {
  23. for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
  24. {
  25. cerr << i->first << ": " << i->second << endl;
  26. }
  27. }
  28. static void SignalTestFailure()
  29. {
  30. OutputTestData(*s_currentTestData);
  31. throw TestFailure();
  32. }
  33. static void SignalTestError()
  34. {
  35. OutputTestData(*s_currentTestData);
  36. throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test");
  37. }
  38. bool DataExists(const TestData &data, const char *name)
  39. {
  40. TestData::const_iterator i = data.find(name);
  41. return (i != data.end());
  42. }
  43. const std::string & GetRequiredDatum(const TestData &data, const char *name)
  44. {
  45. TestData::const_iterator i = data.find(name);
  46. if (i == data.end())
  47. SignalTestError();
  48. return i->second;
  49. }
  50. void RandomizedTransfer(BufferedTransformation &source, BufferedTransformation &target, bool finish, const std::string &channel=DEFAULT_CHANNEL)
  51. {
  52. while (source.MaxRetrievable() > (finish ? 0 : 4096))
  53. {
  54. byte buf[4096+64];
  55. size_t start = GlobalRNG().GenerateWord32(0, 63);
  56. size_t len = GlobalRNG().GenerateWord32(1, UnsignedMin(4096U, 3*source.MaxRetrievable()/2));
  57. len = source.Get(buf+start, len);
  58. target.ChannelPut(channel, buf+start, len);
  59. }
  60. }
  61. void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target)
  62. {
  63. std::string s1 = GetRequiredDatum(data, name), s2;
  64. ByteQueue q;
  65. while (!s1.empty())
  66. {
  67. while (s1[0] == ' ')
  68. {
  69. s1 = s1.substr(1);
  70. if (s1.empty())
  71. goto end; // avoid invalid read if s1 is empty
  72. }
  73. int repeat = 1;
  74. if (s1[0] == 'r')
  75. {
  76. repeat = atoi(s1.c_str()+1);
  77. s1 = s1.substr(s1.find(' ')+1);
  78. }
  79. s2 = ""; // MSVC 6 doesn't have clear();
  80. if (s1[0] == '\"')
  81. {
  82. s2 = s1.substr(1, s1.find('\"', 1)-1);
  83. s1 = s1.substr(s2.length() + 2);
  84. }
  85. else if (s1.substr(0, 2) == "0x")
  86. {
  87. StringSource(s1.substr(2, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
  88. s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
  89. }
  90. else
  91. {
  92. StringSource(s1.substr(0, s1.find(' ')), true, new HexDecoder(new StringSink(s2)));
  93. s1 = s1.substr(STDMIN(s1.find(' '), s1.length()));
  94. }
  95. while (repeat--)
  96. {
  97. q.Put((const byte *)s2.data(), s2.size());
  98. RandomizedTransfer(q, target, false);
  99. }
  100. }
  101. end:
  102. RandomizedTransfer(q, target, true);
  103. }
  104. std::string GetDecodedDatum(const TestData &data, const char *name)
  105. {
  106. std::string s;
  107. PutDecodedDatumInto(data, name, StringSink(s).Ref());
  108. return s;
  109. }
  110. std::string GetOptionalDecodedDatum(const TestData &data, const char *name)
  111. {
  112. std::string s;
  113. if (DataExists(data, name))
  114. PutDecodedDatumInto(data, name, StringSink(s).Ref());
  115. return s;
  116. }
  117. class TestDataNameValuePairs : public NameValuePairs
  118. {
  119. public:
  120. TestDataNameValuePairs(const TestData &data) : m_data(data) {}
  121. virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const
  122. {
  123. TestData::const_iterator i = m_data.find(name);
  124. if (i == m_data.end())
  125. {
  126. if (std::string(name) == Name::DigestSize() && valueType == typeid(int))
  127. {
  128. i = m_data.find("MAC");
  129. if (i == m_data.end())
  130. i = m_data.find("Digest");
  131. if (i == m_data.end())
  132. return false;
  133. m_temp.resize(0);
  134. PutDecodedDatumInto(m_data, i->first.c_str(), StringSink(m_temp).Ref());
  135. *reinterpret_cast<int *>(pValue) = (int)m_temp.size();
  136. return true;
  137. }
  138. else
  139. return false;
  140. }
  141. const std::string &value = i->second;
  142. if (valueType == typeid(int))
  143. *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
  144. else if (valueType == typeid(Integer))
  145. *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str());
  146. else if (valueType == typeid(ConstByteArrayParameter))
  147. {
  148. m_temp.resize(0);
  149. PutDecodedDatumInto(m_data, name, StringSink(m_temp).Ref());
  150. reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), false);
  151. }
  152. else
  153. throw ValueTypeMismatch(name, typeid(std::string), valueType);
  154. return true;
  155. }
  156. private:
  157. const TestData &m_data;
  158. mutable std::string m_temp;
  159. };
  160. void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
  161. {
  162. if (!pub.Validate(GlobalRNG(), 3))
  163. SignalTestFailure();
  164. if (!priv.Validate(GlobalRNG(), 3))
  165. SignalTestFailure();
  166. /* EqualityComparisonFilter comparison;
  167. pub.Save(ChannelSwitch(comparison, "0"));
  168. pub.AssignFrom(priv);
  169. pub.Save(ChannelSwitch(comparison, "1"));
  170. comparison.ChannelMessageSeriesEnd("0");
  171. comparison.ChannelMessageSeriesEnd("1");
  172. */
  173. }
  174. void TestSignatureScheme(TestData &v)
  175. {
  176. std::string name = GetRequiredDatum(v, "Name");
  177. std::string test = GetRequiredDatum(v, "Test");
  178. std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
  179. std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));
  180. TestDataNameValuePairs pairs(v);
  181. std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
  182. if (keyFormat == "DER")
  183. verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
  184. else if (keyFormat == "Component")
  185. verifier->AccessMaterial().AssignFrom(pairs);
  186. if (test == "Verify" || test == "NotVerify")
  187. {
  188. VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
  189. PutDecodedDatumInto(v, "Signature", verifierFilter);
  190. PutDecodedDatumInto(v, "Message", verifierFilter);
  191. verifierFilter.MessageEnd();
  192. if (verifierFilter.GetLastResult() == (test == "NotVerify"))
  193. SignalTestFailure();
  194. }
  195. else if (test == "PublicKeyValid")
  196. {
  197. if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
  198. SignalTestFailure();
  199. }
  200. else
  201. goto privateKeyTests;
  202. return;
  203. privateKeyTests:
  204. if (keyFormat == "DER")
  205. signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
  206. else if (keyFormat == "Component")
  207. signer->AccessMaterial().AssignFrom(pairs);
  208. if (test == "KeyPairValidAndConsistent")
  209. {
  210. TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
  211. }
  212. else if (test == "Sign")
  213. {
  214. SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
  215. StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
  216. SignalTestFailure();
  217. }
  218. else if (test == "DeterministicSign")
  219. {
  220. SignalTestError();
  221. assert(false); // TODO: implement
  222. }
  223. else if (test == "RandomSign")
  224. {
  225. SignalTestError();
  226. assert(false); // TODO: implement
  227. }
  228. else if (test == "GenerateKey")
  229. {
  230. SignalTestError();
  231. assert(false);
  232. }
  233. else
  234. {
  235. SignalTestError();
  236. assert(false);
  237. }
  238. }
  239. void TestAsymmetricCipher(TestData &v)
  240. {
  241. std::string name = GetRequiredDatum(v, "Name");
  242. std::string test = GetRequiredDatum(v, "Test");
  243. std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
  244. std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));
  245. std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
  246. if (keyFormat == "DER")
  247. {
  248. decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
  249. encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
  250. }
  251. else if (keyFormat == "Component")
  252. {
  253. TestDataNameValuePairs pairs(v);
  254. decryptor->AccessMaterial().AssignFrom(pairs);
  255. encryptor->AccessMaterial().AssignFrom(pairs);
  256. }
  257. if (test == "DecryptMatch")
  258. {
  259. std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
  260. StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
  261. if (decrypted != expected)
  262. SignalTestFailure();
  263. }
  264. else if (test == "KeyPairValidAndConsistent")
  265. {
  266. TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
  267. }
  268. else
  269. {
  270. SignalTestError();
  271. assert(false);
  272. }
  273. }
  274. void TestSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
  275. {
  276. std::string name = GetRequiredDatum(v, "Name");
  277. std::string test = GetRequiredDatum(v, "Test");
  278. std::string key = GetDecodedDatum(v, "Key");
  279. std::string plaintext = GetDecodedDatum(v, "Plaintext");
  280. TestDataNameValuePairs testDataPairs(v);
  281. CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
  282. if (test == "Encrypt" || test == "EncryptXorDigest" || test == "Resync" || test == "EncryptionMCT" || test == "DecryptionMCT")
  283. {
  284. static member_ptr<SymmetricCipher> encryptor, decryptor;
  285. static std::string lastName;
  286. if (name != lastName)
  287. {
  288. encryptor.reset(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
  289. decryptor.reset(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
  290. lastName = name;
  291. }
  292. ConstByteArrayParameter iv;
  293. if (pairs.GetValue(Name::IV(), iv) && iv.size() != encryptor->IVSize())
  294. SignalTestFailure();
  295. if (test == "Resync")
  296. {
  297. encryptor->Resynchronize(iv.begin(), (int)iv.size());
  298. decryptor->Resynchronize(iv.begin(), (int)iv.size());
  299. }
  300. else
  301. {
  302. encryptor->SetKey((const byte *)key.data(), key.size(), pairs);
  303. decryptor->SetKey((const byte *)key.data(), key.size(), pairs);
  304. }
  305. int seek = pairs.GetIntValueWithDefault("Seek", 0);
  306. if (seek)
  307. {
  308. encryptor->Seek(seek);
  309. decryptor->Seek(seek);
  310. }
  311. std::string encrypted, xorDigest, ciphertext, ciphertextXorDigest;
  312. if (test == "EncryptionMCT" || test == "DecryptionMCT")
  313. {
  314. SymmetricCipher *cipher = encryptor.get();
  315. SecByteBlock buf((byte *)plaintext.data(), plaintext.size()), keybuf((byte *)key.data(), key.size());
  316. if (test == "DecryptionMCT")
  317. {
  318. cipher = decryptor.get();
  319. ciphertext = GetDecodedDatum(v, "Ciphertext");
  320. buf.Assign((byte *)ciphertext.data(), ciphertext.size());
  321. }
  322. for (int i=0; i<400; i++)
  323. {
  324. encrypted.reserve(10000 * plaintext.size());
  325. for (int j=0; j<10000; j++)
  326. {
  327. cipher->ProcessString(buf.begin(), buf.size());
  328. encrypted.append((char *)buf.begin(), buf.size());
  329. }
  330. encrypted.erase(0, encrypted.size() - keybuf.size());
  331. xorbuf(keybuf.begin(), (const byte *)encrypted.data(), keybuf.size());
  332. cipher->SetKey(keybuf, keybuf.size());
  333. }
  334. encrypted.assign((char *)buf.begin(), buf.size());
  335. ciphertext = GetDecodedDatum(v, test == "EncryptionMCT" ? "Ciphertext" : "Plaintext");
  336. if (encrypted != ciphertext)
  337. {
  338. std::cout << "incorrectly encrypted: ";
  339. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  340. xx.Pump(256); xx.Flush(false);
  341. std::cout << "\n";
  342. SignalTestFailure();
  343. }
  344. return;
  345. }
  346. StreamTransformationFilter encFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING);
  347. RandomizedTransfer(StringStore(plaintext).Ref(), encFilter, true);
  348. encFilter.MessageEnd();
  349. /*{
  350. std::string z;
  351. encryptor->Seek(seek);
  352. StringSource ss(plaintext, false, new StreamTransformationFilter(*encryptor, new StringSink(z), StreamTransformationFilter::NO_PADDING));
  353. while (ss.Pump(64)) {}
  354. ss.PumpAll();
  355. for (int i=0; i<z.length(); i++)
  356. assert(encrypted[i] == z[i]);
  357. }*/
  358. if (test != "EncryptXorDigest")
  359. ciphertext = GetDecodedDatum(v, "Ciphertext");
  360. else
  361. {
  362. ciphertextXorDigest = GetDecodedDatum(v, "CiphertextXorDigest");
  363. xorDigest.append(encrypted, 0, 64);
  364. for (size_t i=64; i<encrypted.size(); i++)
  365. xorDigest[i%64] ^= encrypted[i];
  366. }
  367. if (test != "EncryptXorDigest" ? encrypted != ciphertext : xorDigest != ciphertextXorDigest)
  368. {
  369. std::cout << "incorrectly encrypted: ";
  370. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  371. xx.Pump(2048); xx.Flush(false);
  372. std::cout << "\n";
  373. SignalTestFailure();
  374. }
  375. std::string decrypted;
  376. StreamTransformationFilter decFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING);
  377. RandomizedTransfer(StringStore(encrypted).Ref(), decFilter, true);
  378. decFilter.MessageEnd();
  379. if (decrypted != plaintext)
  380. {
  381. std::cout << "incorrectly decrypted: ";
  382. StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
  383. xx.Pump(256); xx.Flush(false);
  384. std::cout << "\n";
  385. SignalTestFailure();
  386. }
  387. }
  388. else
  389. {
  390. std::cout << "unexpected test name\n";
  391. SignalTestError();
  392. }
  393. }
  394. void TestAuthenticatedSymmetricCipher(TestData &v, const NameValuePairs &overrideParameters)
  395. {
  396. std::string type = GetRequiredDatum(v, "AlgorithmType");
  397. std::string name = GetRequiredDatum(v, "Name");
  398. std::string test = GetRequiredDatum(v, "Test");
  399. std::string key = GetDecodedDatum(v, "Key");
  400. std::string plaintext = GetOptionalDecodedDatum(v, "Plaintext");
  401. std::string ciphertext = GetOptionalDecodedDatum(v, "Ciphertext");
  402. std::string header = GetOptionalDecodedDatum(v, "Header");
  403. std::string footer = GetOptionalDecodedDatum(v, "Footer");
  404. std::string mac = GetOptionalDecodedDatum(v, "MAC");
  405. TestDataNameValuePairs testDataPairs(v);
  406. CombinedNameValuePairs pairs(overrideParameters, testDataPairs);
  407. if (test == "Encrypt" || test == "EncryptXorDigest" || test == "NotVerify")
  408. {
  409. member_ptr<AuthenticatedSymmetricCipher> asc1, asc2;
  410. asc1.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str()));
  411. asc2.reset(ObjectFactoryRegistry<AuthenticatedSymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str()));
  412. asc1->SetKey((const byte *)key.data(), key.size(), pairs);
  413. asc2->SetKey((const byte *)key.data(), key.size(), pairs);
  414. std::string encrypted, decrypted;
  415. AuthenticatedEncryptionFilter ef(*asc1, new StringSink(encrypted));
  416. bool macAtBegin = !mac.empty() && !GlobalRNG().GenerateBit(); // test both ways randomly
  417. AuthenticatedDecryptionFilter df(*asc2, new StringSink(decrypted), macAtBegin ? AuthenticatedDecryptionFilter::MAC_AT_BEGIN : 0);
  418. if (asc1->NeedsPrespecifiedDataLengths())
  419. {
  420. asc1->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
  421. asc2->SpecifyDataLengths(header.size(), plaintext.size(), footer.size());
  422. }
  423. StringStore sh(header), sp(plaintext), sc(ciphertext), sf(footer), sm(mac);
  424. if (macAtBegin)
  425. RandomizedTransfer(sm, df, true);
  426. sh.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
  427. RandomizedTransfer(sc, df, true);
  428. sf.CopyTo(df, LWORD_MAX, AAD_CHANNEL);
  429. if (!macAtBegin)
  430. RandomizedTransfer(sm, df, true);
  431. df.MessageEnd();
  432. RandomizedTransfer(sh, ef, true, AAD_CHANNEL);
  433. RandomizedTransfer(sp, ef, true);
  434. RandomizedTransfer(sf, ef, true, AAD_CHANNEL);
  435. ef.MessageEnd();
  436. if (test == "Encrypt" && encrypted != ciphertext+mac)
  437. {
  438. std::cout << "incorrectly encrypted: ";
  439. StringSource xx(encrypted, false, new HexEncoder(new FileSink(std::cout)));
  440. xx.Pump(2048); xx.Flush(false);
  441. std::cout << "\n";
  442. SignalTestFailure();
  443. }
  444. if (test == "Encrypt" && decrypted != plaintext)
  445. {
  446. std::cout << "incorrectly decrypted: ";
  447. StringSource xx(decrypted, false, new HexEncoder(new FileSink(std::cout)));
  448. xx.Pump(256); xx.Flush(false);
  449. std::cout << "\n";
  450. SignalTestFailure();
  451. }
  452. if (ciphertext.size()+mac.size()-plaintext.size() != asc1->DigestSize())
  453. {
  454. std::cout << "bad MAC size\n";
  455. SignalTestFailure();
  456. }
  457. if (df.GetLastResult() != (test == "Encrypt"))
  458. {
  459. std::cout << "MAC incorrectly verified\n";
  460. SignalTestFailure();
  461. }
  462. }
  463. else
  464. {
  465. std::cout << "unexpected test name\n";
  466. SignalTestError();
  467. }
  468. }
  469. void TestDigestOrMAC(TestData &v, bool testDigest)
  470. {
  471. std::string name = GetRequiredDatum(v, "Name");
  472. std::string test = GetRequiredDatum(v, "Test");
  473. const char *digestName = testDigest ? "Digest" : "MAC";
  474. member_ptr<MessageAuthenticationCode> mac;
  475. member_ptr<HashTransformation> hash;
  476. HashTransformation *pHash = NULL;
  477. TestDataNameValuePairs pairs(v);
  478. if (testDigest)
  479. {
  480. hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
  481. pHash = hash.get();
  482. }
  483. else
  484. {
  485. mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
  486. pHash = mac.get();
  487. std::string key = GetDecodedDatum(v, "Key");
  488. mac->SetKey((const byte *)key.c_str(), key.size(), pairs);
  489. }
  490. if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify")
  491. {
  492. int digestSize = -1;
  493. if (test == "VerifyTruncated")
  494. pairs.GetIntValue(Name::DigestSize(), digestSize);
  495. HashVerificationFilter verifierFilter(*pHash, NULL, HashVerificationFilter::HASH_AT_BEGIN, digestSize);
  496. PutDecodedDatumInto(v, digestName, verifierFilter);
  497. PutDecodedDatumInto(v, "Message", verifierFilter);
  498. verifierFilter.MessageEnd();
  499. if (verifierFilter.GetLastResult() == (test == "NotVerify"))
  500. SignalTestFailure();
  501. }
  502. else
  503. {
  504. SignalTestError();
  505. assert(false);
  506. }
  507. }
  508. bool GetField(std::istream &is, std::string &name, std::string &value)
  509. {
  510. name.resize(0); // GCC workaround: 2.95.3 doesn't have clear()
  511. is >> name;
  512. if (name.empty())
  513. return false;
  514. if (name[name.size()-1] != ':')
  515. {
  516. char c;
  517. is >> skipws >> c;
  518. if (c != ':')
  519. SignalTestError();
  520. }
  521. else
  522. name.erase(name.size()-1);
  523. while (is.peek() == ' ')
  524. is.ignore(1);
  525. // VC60 workaround: getline bug
  526. char buffer[128];
  527. value.resize(0); // GCC workaround: 2.95.3 doesn't have clear()
  528. bool continueLine;
  529. do
  530. {
  531. do
  532. {
  533. is.get(buffer, sizeof(buffer));
  534. value += buffer;
  535. }
  536. while (buffer[0] != 0);
  537. is.clear();
  538. is.ignore();
  539. if (!value.empty() && value[value.size()-1] == '\r')
  540. value.resize(value.size()-1);
  541. if (!value.empty() && value[value.size()-1] == '\\')
  542. {
  543. value.resize(value.size()-1);
  544. continueLine = true;
  545. }
  546. else
  547. continueLine = false;
  548. std::string::size_type i = value.find('#');
  549. if (i != std::string::npos)
  550. value.erase(i);
  551. }
  552. while (continueLine);
  553. return true;
  554. }
  555. void OutputPair(const NameValuePairs &v, const char *name)
  556. {
  557. Integer x;
  558. bool b = v.GetValue(name, x);
  559. assert(b);
  560. cout << name << ": \\\n ";
  561. x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize());
  562. cout << endl;
  563. }
  564. void OutputNameValuePairs(const NameValuePairs &v)
  565. {
  566. std::string names = v.GetValueNames();
  567. string::size_type i = 0;
  568. while (i < names.size())
  569. {
  570. string::size_type j = names.find_first_of (';', i);
  571. if (j == string::npos)
  572. return;
  573. else
  574. {
  575. std::string name = names.substr(i, j-i);
  576. if (name.find(':') == string::npos)
  577. OutputPair(v, name.c_str());
  578. }
  579. i = j + 1;
  580. }
  581. }
  582. void TestDataFile(const std::string &filename, const NameValuePairs &overrideParameters, unsigned int &totalTests, unsigned int &failedTests)
  583. {
  584. std::ifstream file(filename.c_str());
  585. if (!file.good())
  586. throw Exception(Exception::OTHER_ERROR, "Can not open file " + filename + " for reading");
  587. TestData v;
  588. s_currentTestData = &v;
  589. std::string name, value, lastAlgName;
  590. while (file)
  591. {
  592. while (file.peek() == '#')
  593. file.ignore(INT_MAX, '\n');
  594. if (file.peek() == '\n' || file.peek() == '\r')
  595. v.clear();
  596. if (!GetField(file, name, value))
  597. break;
  598. v[name] = value;
  599. if (name == "Test")
  600. {
  601. bool failed = true;
  602. std::string algType = GetRequiredDatum(v, "AlgorithmType");
  603. if (lastAlgName != GetRequiredDatum(v, "Name"))
  604. {
  605. lastAlgName = GetRequiredDatum(v, "Name");
  606. cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n";
  607. }
  608. try
  609. {
  610. if (algType == "Signature")
  611. TestSignatureScheme(v);
  612. else if (algType == "SymmetricCipher")
  613. TestSymmetricCipher(v, overrideParameters);
  614. else if (algType == "AuthenticatedSymmetricCipher")
  615. TestAuthenticatedSymmetricCipher(v, overrideParameters);
  616. else if (algType == "AsymmetricCipher")
  617. TestAsymmetricCipher(v);
  618. else if (algType == "MessageDigest")
  619. TestDigestOrMAC(v, true);
  620. else if (algType == "MAC")
  621. TestDigestOrMAC(v, false);
  622. else if (algType == "FileList")
  623. TestDataFile(GetRequiredDatum(v, "Test"), g_nullNameValuePairs, totalTests, failedTests);
  624. else
  625. SignalTestError();
  626. failed = false;
  627. }
  628. catch (TestFailure &)
  629. {
  630. cout << "\nTest failed.\n";
  631. }
  632. catch (CryptoPP::Exception &e)
  633. {
  634. cout << "\nCryptoPP::Exception caught: " << e.what() << endl;
  635. }
  636. catch (std::exception &e)
  637. {
  638. cout << "\nstd::exception caught: " << e.what() << endl;
  639. }
  640. if (failed)
  641. {
  642. cout << "Skipping to next test.\n";
  643. failedTests++;
  644. }
  645. else
  646. cout << "." << flush;
  647. totalTests++;
  648. }
  649. }
  650. }
  651. bool RunTestDataFile(const char *filename, const NameValuePairs &overrideParameters)
  652. {
  653. unsigned int totalTests = 0, failedTests = 0;
  654. TestDataFile(filename, overrideParameters, totalTests, failedTests);
  655. cout << dec << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n";
  656. if (failedTests != 0)
  657. cout << "SOME TESTS FAILED!\n";
  658. return failedTests == 0;
  659. }