Team Fortress 2 Source Code as on 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

476 lines
11 KiB

  1. // ecp.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #ifndef CRYPTOPP_IMPORTS
  4. #include "ecp.h"
  5. #include "asn.h"
  6. #include "integer.h"
  7. #include "nbtheory.h"
  8. #include "modarith.h"
  9. #include "filters.h"
  10. #include "algebra.cpp"
  11. NAMESPACE_BEGIN(CryptoPP)
  12. ANONYMOUS_NAMESPACE_BEGIN
  13. static inline ECP::Point ToMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
  14. {
  15. return P.identity ? P : ECP::Point(mr.ConvertIn(P.x), mr.ConvertIn(P.y));
  16. }
  17. static inline ECP::Point FromMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
  18. {
  19. return P.identity ? P : ECP::Point(mr.ConvertOut(P.x), mr.ConvertOut(P.y));
  20. }
  21. NAMESPACE_END
  22. ECP::ECP(const ECP &ecp, bool convertToMontgomeryRepresentation)
  23. {
  24. if (convertToMontgomeryRepresentation && !ecp.GetField().IsMontgomeryRepresentation())
  25. {
  26. m_fieldPtr.reset(new MontgomeryRepresentation(ecp.GetField().GetModulus()));
  27. m_a = GetField().ConvertIn(ecp.m_a);
  28. m_b = GetField().ConvertIn(ecp.m_b);
  29. }
  30. else
  31. operator=(ecp);
  32. }
  33. ECP::ECP(BufferedTransformation &bt)
  34. : m_fieldPtr(new Field(bt))
  35. {
  36. BERSequenceDecoder seq(bt);
  37. GetField().BERDecodeElement(seq, m_a);
  38. GetField().BERDecodeElement(seq, m_b);
  39. // skip optional seed
  40. if (!seq.EndReached())
  41. {
  42. SecByteBlock seed;
  43. unsigned int unused;
  44. BERDecodeBitString(seq, seed, unused);
  45. }
  46. seq.MessageEnd();
  47. }
  48. void ECP::DEREncode(BufferedTransformation &bt) const
  49. {
  50. GetField().DEREncode(bt);
  51. DERSequenceEncoder seq(bt);
  52. GetField().DEREncodeElement(seq, m_a);
  53. GetField().DEREncodeElement(seq, m_b);
  54. seq.MessageEnd();
  55. }
  56. bool ECP::DecodePoint(ECP::Point &P, const byte *encodedPoint, size_t encodedPointLen) const
  57. {
  58. StringStore store(encodedPoint, encodedPointLen);
  59. return DecodePoint(P, store, encodedPointLen);
  60. }
  61. bool ECP::DecodePoint(ECP::Point &P, BufferedTransformation &bt, size_t encodedPointLen) const
  62. {
  63. byte type;
  64. if (encodedPointLen < 1 || !bt.Get(type))
  65. return false;
  66. switch (type)
  67. {
  68. case 0:
  69. P.identity = true;
  70. return true;
  71. case 2:
  72. case 3:
  73. {
  74. if (encodedPointLen != EncodedPointSize(true))
  75. return false;
  76. Integer p = FieldSize();
  77. P.identity = false;
  78. P.x.Decode(bt, GetField().MaxElementByteLength());
  79. P.y = ((P.x*P.x+m_a)*P.x+m_b) % p;
  80. if (Jacobi(P.y, p) !=1)
  81. return false;
  82. P.y = ModularSquareRoot(P.y, p);
  83. if ((type & 1) != P.y.GetBit(0))
  84. P.y = p-P.y;
  85. return true;
  86. }
  87. case 4:
  88. {
  89. if (encodedPointLen != EncodedPointSize(false))
  90. return false;
  91. unsigned int len = GetField().MaxElementByteLength();
  92. P.identity = false;
  93. P.x.Decode(bt, len);
  94. P.y.Decode(bt, len);
  95. return true;
  96. }
  97. default:
  98. return false;
  99. }
  100. }
  101. void ECP::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
  102. {
  103. if (P.identity)
  104. NullStore().TransferTo(bt, EncodedPointSize(compressed));
  105. else if (compressed)
  106. {
  107. bt.Put(2 + P.y.GetBit(0));
  108. P.x.Encode(bt, GetField().MaxElementByteLength());
  109. }
  110. else
  111. {
  112. unsigned int len = GetField().MaxElementByteLength();
  113. bt.Put(4); // uncompressed
  114. P.x.Encode(bt, len);
  115. P.y.Encode(bt, len);
  116. }
  117. }
  118. void ECP::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
  119. {
  120. ArraySink sink(encodedPoint, EncodedPointSize(compressed));
  121. EncodePoint(sink, P, compressed);
  122. assert(sink.TotalPutLength() == EncodedPointSize(compressed));
  123. }
  124. ECP::Point ECP::BERDecodePoint(BufferedTransformation &bt) const
  125. {
  126. SecByteBlock str;
  127. BERDecodeOctetString(bt, str);
  128. Point P;
  129. if (!DecodePoint(P, str, str.size()))
  130. BERDecodeError();
  131. return P;
  132. }
  133. void ECP::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
  134. {
  135. SecByteBlock str(EncodedPointSize(compressed));
  136. EncodePoint(str, P, compressed);
  137. DEREncodeOctetString(bt, str);
  138. }
  139. bool ECP::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
  140. {
  141. Integer p = FieldSize();
  142. bool pass = p.IsOdd();
  143. pass = pass && !m_a.IsNegative() && m_a<p && !m_b.IsNegative() && m_b<p;
  144. if (level >= 1)
  145. pass = pass && ((4*m_a*m_a*m_a+27*m_b*m_b)%p).IsPositive();
  146. if (level >= 2)
  147. pass = pass && VerifyPrime(rng, p);
  148. return pass;
  149. }
  150. bool ECP::VerifyPoint(const Point &P) const
  151. {
  152. const FieldElement &x = P.x, &y = P.y;
  153. Integer p = FieldSize();
  154. return P.identity ||
  155. (!x.IsNegative() && x<p && !y.IsNegative() && y<p
  156. && !(((x*x+m_a)*x+m_b-y*y)%p));
  157. }
  158. bool ECP::Equal(const Point &P, const Point &Q) const
  159. {
  160. if (P.identity && Q.identity)
  161. return true;
  162. if (P.identity && !Q.identity)
  163. return false;
  164. if (!P.identity && Q.identity)
  165. return false;
  166. return (GetField().Equal(P.x,Q.x) && GetField().Equal(P.y,Q.y));
  167. }
  168. const ECP::Point& ECP::Identity() const
  169. {
  170. return Singleton<Point>().Ref();
  171. }
  172. const ECP::Point& ECP::Inverse(const Point &P) const
  173. {
  174. if (P.identity)
  175. return P;
  176. else
  177. {
  178. m_R.identity = false;
  179. m_R.x = P.x;
  180. m_R.y = GetField().Inverse(P.y);
  181. return m_R;
  182. }
  183. }
  184. const ECP::Point& ECP::Add(const Point &P, const Point &Q) const
  185. {
  186. if (P.identity) return Q;
  187. if (Q.identity) return P;
  188. if (GetField().Equal(P.x, Q.x))
  189. return GetField().Equal(P.y, Q.y) ? Double(P) : Identity();
  190. FieldElement t = GetField().Subtract(Q.y, P.y);
  191. t = GetField().Divide(t, GetField().Subtract(Q.x, P.x));
  192. FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), Q.x);
  193. m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
  194. m_R.x.swap(x);
  195. m_R.identity = false;
  196. return m_R;
  197. }
  198. const ECP::Point& ECP::Double(const Point &P) const
  199. {
  200. if (P.identity || P.y==GetField().Identity()) return Identity();
  201. FieldElement t = GetField().Square(P.x);
  202. t = GetField().Add(GetField().Add(GetField().Double(t), t), m_a);
  203. t = GetField().Divide(t, GetField().Double(P.y));
  204. FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), P.x);
  205. m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
  206. m_R.x.swap(x);
  207. m_R.identity = false;
  208. return m_R;
  209. }
// Replaces every element in [begin, end) with its multiplicative inverse in ring,
// using recursive pairwise batching:
//   1. Adjacent pairs are multiplied together.
//   2. The half-size vector of products is inverted recursively.
//   3. Each element's own inverse is recovered from its pair's inverted product
//      with one multiplication.
// Only the bottom of the recursion calls ring.MultiplicativeInverse, apart from the
// per-element fallback used when an inverted pair product tests as zero.
// NOTE(review): that fallback assumes MultiplicativeInverse of a non-invertible
// value (e.g. 0) yields something that tests false, presumably 0. Confirm for the
// ring in use.
template <class T, class Iterator> void ParallelInvert(const AbstractRing<T> &ring, Iterator begin, Iterator end)
{
	size_t n = end-begin;
	if (n == 1)
		*begin = ring.MultiplicativeInverse(*begin);
	else if (n > 1)
	{
		// vec[i] = a[2i] * a[2i+1]. An odd trailing element is carried over unchanged.
		std::vector<T> vec((n+1)/2);
		unsigned int i;
		Iterator it;
		for (i=0, it=begin; i<n/2; i++, it+=2)
			vec[i] = ring.Multiply(*it, *(it+1));
		if (n%2 == 1)
			vec[n/2] = *it;
		ParallelInvert(ring, vec.begin(), vec.end());
		// Now vec[i] holds 1/(a*b) for the pair (a, b) at (it, it+1).
		for (i=0, it=begin; i<n/2; i++, it+=2)
		{
			if (!vec[i])
			{
				// Inverted product tests as zero (e.g. a zero in the pair):
				// invert each element on its own instead.
				*it = ring.MultiplicativeInverse(*it);
				*(it+1) = ring.MultiplicativeInverse(*(it+1));
			}
			else
			{
				// (a, b) -> (b, a), then multiply each by 1/(ab): gives (1/a, 1/b).
				std::swap(*it, *(it+1));
				*it = ring.Multiply(*it, vec[i]);
				*(it+1) = ring.Multiply(*(it+1), vec[i]);
			}
		}
		// After the loop, `it` points at the odd trailing element. Its inverse
		// came back from the recursion in vec[n/2].
		if (n%2 == 1)
			*it = vec[n/2];
	}
}
  243. struct ProjectivePoint
  244. {
  245. ProjectivePoint() {}
  246. ProjectivePoint(const Integer &x, const Integer &y, const Integer &z)
  247. : x(x), y(y), z(z) {}
  248. Integer x,y,z;
  249. };
// Repeatedly doubles a point P held in Jacobian projective coordinates (x, y, z),
// where the affine point is (x/z^2, y/z^3). Between steps it caches a*z^4 (aZ4) and
// 16*y^4 (sixteenY4), so Double() never raises z to a power explicitly.
// m_b is accepted but unused. firstDoubling and negated are initialised but not
// read anywhere in this file's visible code.
class ProjectiveDoubling
{
public:
	ProjectiveDoubling(const ModularArithmetic &mr, const Integer &m_a, const Integer &m_b, const ECPPoint &Q)
		: mr(mr), firstDoubling(true), negated(false)
	{
		CRYPTOPP_UNUSED(m_b);
		if (Q.identity)
		{
			// Identity: (one, one, zero) in field terms, and a*z^4 = 0.
			// Since Double() sets z' = z * 2y, z stays zero.
			sixteenY4 = P.x = P.y = mr.MultiplicativeIdentity();
			aZ4 = P.z = mr.Identity();
		}
		else
		{
			// Affine Q lifts to (x, y, one), so a*z^4 = a. sixteenY4 starts at one,
			// so the first Double() leaves aZ4 unchanged.
			P.x = Q.x;
			P.y = Q.y;
			sixteenY4 = P.z = mr.MultiplicativeIdentity();
			aZ4 = m_a;
		}
	}
	// One doubling step, P <- 2P.
	// Assumes mr.Reduce(a, b) computes a -= b (TODO confirm) and mr.Half halves
	// in the field. With (X, Y, Z) the current point:
	//   Z' = 2*Y*Z
	//   S  = 4*X*Y^2
	//   M  = 3*X^2 + a*Z^4
	//   X' = M^2 - 2*S
	//   Y' = M*(S - X') - 8*Y^4
	void Double()
	{
		twoY = mr.Double(P.y);
		P.z = mr.Multiply(P.z, twoY);
		fourY2 = mr.Square(twoY);
		S = mr.Multiply(fourY2, P.x);
		// sixteenY4 still holds 16*Y^4 of the point before the previous step, so this
		// turns a*Zprev^4 into a*(2*Yprev*Zprev)^4 = a*Z^4 for the current point.
		aZ4 = mr.Multiply(aZ4, sixteenY4);
		M = mr.Square(P.x);
		M = mr.Add(mr.Add(mr.Double(M), M), aZ4);
		P.x = mr.Square(M);
		mr.Reduce(P.x, S);
		mr.Reduce(P.x, S);
		mr.Reduce(S, P.x);
		P.y = mr.Multiply(M, S);
		// 16*Y^4 for this step. Half of it (8*Y^4) completes Y'; the full value is
		// kept for the next step's aZ4 update.
		sixteenY4 = mr.Square(fourY2);
		mr.Reduce(P.y, mr.Half(sixteenY4));
	}
	const ModularArithmetic &mr;
	ProjectivePoint P;
	bool firstDoubling, negated;
	Integer sixteenY4, aZ4, twoY, fourY2, S, M;
};
  292. struct ZIterator
  293. {
  294. ZIterator() {}
  295. ZIterator(std::vector<ProjectivePoint>::iterator it) : it(it) {}
  296. Integer& operator*() {return it->z;}
  297. int operator-(ZIterator it2) {return int(it-it2.it);}
  298. ZIterator operator+(int i) {return ZIterator(it+i);}
  299. ZIterator& operator+=(int i) {it+=i; return *this;}
  300. std::vector<ProjectivePoint>::iterator it;
  301. };
  302. ECP::Point ECP::ScalarMultiply(const Point &P, const Integer &k) const
  303. {
  304. Element result;
  305. if (k.BitCount() <= 5)
  306. AbstractGroup<ECPPoint>::SimultaneousMultiply(&result, P, &k, 1);
  307. else
  308. ECP::SimultaneousMultiply(&result, P, &k, 1);
  309. return result;
  310. }
// Computes results[i] = expBegin[i] * P for each of the expCount exponents. All
// exponents share one chain of doublings of P. Each exponent is asserted to be
// non-negative.
void ECP::SimultaneousMultiply(ECP::Point *results, const ECP::Point &P, const Integer *expBegin, unsigned int expCount) const
{
	// The code below requires the field in Montgomery representation. If it is not,
	// do the work on a Montgomery copy of the curve and convert each result back.
	if (!GetField().IsMontgomeryRepresentation())
	{
		ECP ecpmr(*this, true);
		const ModularArithmetic &mr = ecpmr.GetField();
		ecpmr.SimultaneousMultiply(results, ToMontgomery(mr, P), expBegin, expCount);
		for (unsigned int i=0; i<expCount; i++)
			results[i] = FromMontgomery(mr, results[i]);
		return;
	}
	// rd.P tracks 2^expBitPosition * P (projective) as the doubling proceeds.
	ProjectiveDoubling rd(GetField(), m_a, m_b, P);
	// Distinct multiples of P at which at least one exponent window starts.
	std::vector<ProjectivePoint> bases;
	std::vector<WindowSlider> exponents;
	exponents.reserve(expCount);
	// Per exponent i, one entry per window:
	//   baseIndices     - index into bases
	//   negateBase      - whether that base is negated
	//   exponentWindows - the window's value
	std::vector<std::vector<word32> > baseIndices(expCount);
	std::vector<std::vector<bool> > negateBase(expCount);
	std::vector<std::vector<word32> > exponentWindows(expCount);
	unsigned int i;
	// One WindowSlider per exponent with window size 5. InversionIsFast() is passed
	// through; given the negateNext flag used below, it presumably enables signed
	// windows.
	for (i=0; i<expCount; i++)
	{
		assert(expBegin->NotNegative());
		exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 5));
		exponents[i].FindNextWindow();
	}
	// Phase 1: walk bit positions upward from 0, doubling once per position. Whenever
	// some exponent's next window begins at the current position, record the current
	// multiple of P as a base. The base is shared by all exponents whose window
	// starts there.
	unsigned int expBitPosition = 0;
	bool notDone = true;
	while (notDone)
	{
		notDone = false;
		bool baseAdded = false;
		for (i=0; i<expCount; i++)
		{
			if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
			{
				if (!baseAdded)
				{
					bases.push_back(rd.P);
					baseAdded =true;
				}
				exponentWindows[i].push_back(exponents[i].expWindow);
				baseIndices[i].push_back((word32)bases.size()-1);
				negateBase[i].push_back(exponents[i].negateNext);
				exponents[i].FindNextWindow();
			}
			notDone = notDone || !exponents[i].finished;
		}
		if (notDone)
		{
			rd.Double();
			expBitPosition++;
		}
	}
	// convert from projective to affine coordinates
	// Phase 2: invert every z in one batch. Then, with zinv = 1/z:
	//   y = y*zinv, z = zinv^2, x = x*z, y = y*z
	// which gives x/z^2 and y/z^3. A zero z marks the identity and is left as-is.
	ParallelInvert(GetField(), ZIterator(bases.begin()), ZIterator(bases.end()));
	for (i=0; i<bases.size(); i++)
	{
		if (bases[i].z.NotZero())
		{
			bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
			bases[i].z = GetField().Square(bases[i].z);
			bases[i].x = GetField().Multiply(bases[i].x, bases[i].z);
			bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
		}
	}
	// Phase 3: for each exponent, pair every recorded affine base (negated where
	// requested) with its window value. GeneralCascadeMultiplication then combines
	// them into the final product.
	std::vector<BaseAndExponent<Point, Integer> > finalCascade;
	for (i=0; i<expCount; i++)
	{
		finalCascade.resize(baseIndices[i].size());
		for (unsigned int j=0; j<baseIndices[i].size(); j++)
		{
			ProjectivePoint &base = bases[baseIndices[i][j]];
			if (base.z.IsZero())
				finalCascade[j].base.identity = true;
			else
			{
				finalCascade[j].base.identity = false;
				finalCascade[j].base.x = base.x;
				if (negateBase[i][j])
					finalCascade[j].base.y = GetField().Inverse(base.y);
				else
					finalCascade[j].base.y = base.y;
			}
			// Window value as a positive Integer (constructed from a sign plus two words).
			finalCascade[j].exponent = Integer(Integer::POSITIVE, 0, exponentWindows[i][j]);
		}
		results[i] = GeneralCascadeMultiplication(*this, finalCascade.begin(), finalCascade.end());
	}
}
  399. ECP::Point ECP::CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
  400. {
  401. if (!GetField().IsMontgomeryRepresentation())
  402. {
  403. ECP ecpmr(*this, true);
  404. const ModularArithmetic &mr = ecpmr.GetField();
  405. return FromMontgomery(mr, ecpmr.CascadeScalarMultiply(ToMontgomery(mr, P), k1, ToMontgomery(mr, Q), k2));
  406. }
  407. else
  408. return AbstractGroup<Point>::CascadeScalarMultiply(P, k1, Q, k2);
  409. }
  410. NAMESPACE_END
  411. #endif