Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

473 lines
11 KiB

  1. // ecp.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #ifndef CRYPTOPP_IMPORTS
  4. #include "ecp.h"
  5. #include "asn.h"
  6. #include "nbtheory.h"
  7. #include "algebra.cpp"
  8. NAMESPACE_BEGIN(CryptoPP)
  9. ANONYMOUS_NAMESPACE_BEGIN
  10. static inline ECP::Point ToMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
  11. {
  12. return P.identity ? P : ECP::Point(mr.ConvertIn(P.x), mr.ConvertIn(P.y));
  13. }
  14. static inline ECP::Point FromMontgomery(const ModularArithmetic &mr, const ECP::Point &P)
  15. {
  16. return P.identity ? P : ECP::Point(mr.ConvertOut(P.x), mr.ConvertOut(P.y));
  17. }
  18. NAMESPACE_END
  19. ECP::ECP(const ECP &ecp, bool convertToMontgomeryRepresentation)
  20. {
  21. if (convertToMontgomeryRepresentation && !ecp.GetField().IsMontgomeryRepresentation())
  22. {
  23. m_fieldPtr.reset(new MontgomeryRepresentation(ecp.GetField().GetModulus()));
  24. m_a = GetField().ConvertIn(ecp.m_a);
  25. m_b = GetField().ConvertIn(ecp.m_b);
  26. }
  27. else
  28. operator=(ecp);
  29. }
  30. ECP::ECP(BufferedTransformation &bt)
  31. : m_fieldPtr(new Field(bt))
  32. {
  33. BERSequenceDecoder seq(bt);
  34. GetField().BERDecodeElement(seq, m_a);
  35. GetField().BERDecodeElement(seq, m_b);
  36. // skip optional seed
  37. if (!seq.EndReached())
  38. {
  39. SecByteBlock seed;
  40. unsigned int unused;
  41. BERDecodeBitString(seq, seed, unused);
  42. }
  43. seq.MessageEnd();
  44. }
  45. void ECP::DEREncode(BufferedTransformation &bt) const
  46. {
  47. GetField().DEREncode(bt);
  48. DERSequenceEncoder seq(bt);
  49. GetField().DEREncodeElement(seq, m_a);
  50. GetField().DEREncodeElement(seq, m_b);
  51. seq.MessageEnd();
  52. }
  53. bool ECP::DecodePoint(ECP::Point &P, const byte *encodedPoint, size_t encodedPointLen) const
  54. {
  55. StringStore store(encodedPoint, encodedPointLen);
  56. return DecodePoint(P, store, encodedPointLen);
  57. }
  58. bool ECP::DecodePoint(ECP::Point &P, BufferedTransformation &bt, size_t encodedPointLen) const
  59. {
  60. byte type;
  61. if (encodedPointLen < 1 || !bt.Get(type))
  62. return false;
  63. switch (type)
  64. {
  65. case 0:
  66. P.identity = true;
  67. return true;
  68. case 2:
  69. case 3:
  70. {
  71. if (encodedPointLen != EncodedPointSize(true))
  72. return false;
  73. Integer p = FieldSize();
  74. P.identity = false;
  75. P.x.Decode(bt, GetField().MaxElementByteLength());
  76. P.y = ((P.x*P.x+m_a)*P.x+m_b) % p;
  77. if (Jacobi(P.y, p) !=1)
  78. return false;
  79. P.y = ModularSquareRoot(P.y, p);
  80. if ((type & 1) != P.y.GetBit(0))
  81. P.y = p-P.y;
  82. return true;
  83. }
  84. case 4:
  85. {
  86. if (encodedPointLen != EncodedPointSize(false))
  87. return false;
  88. unsigned int len = GetField().MaxElementByteLength();
  89. P.identity = false;
  90. P.x.Decode(bt, len);
  91. P.y.Decode(bt, len);
  92. return true;
  93. }
  94. default:
  95. return false;
  96. }
  97. }
  98. void ECP::EncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
  99. {
  100. if (P.identity)
  101. NullStore().TransferTo(bt, EncodedPointSize(compressed));
  102. else if (compressed)
  103. {
  104. bt.Put(2 + P.y.GetBit(0));
  105. P.x.Encode(bt, GetField().MaxElementByteLength());
  106. }
  107. else
  108. {
  109. unsigned int len = GetField().MaxElementByteLength();
  110. bt.Put(4); // uncompressed
  111. P.x.Encode(bt, len);
  112. P.y.Encode(bt, len);
  113. }
  114. }
  115. void ECP::EncodePoint(byte *encodedPoint, const Point &P, bool compressed) const
  116. {
  117. ArraySink sink(encodedPoint, EncodedPointSize(compressed));
  118. EncodePoint(sink, P, compressed);
  119. assert(sink.TotalPutLength() == EncodedPointSize(compressed));
  120. }
  121. ECP::Point ECP::BERDecodePoint(BufferedTransformation &bt) const
  122. {
  123. SecByteBlock str;
  124. BERDecodeOctetString(bt, str);
  125. Point P;
  126. if (!DecodePoint(P, str, str.size()))
  127. BERDecodeError();
  128. return P;
  129. }
  130. void ECP::DEREncodePoint(BufferedTransformation &bt, const Point &P, bool compressed) const
  131. {
  132. SecByteBlock str(EncodedPointSize(compressed));
  133. EncodePoint(str, P, compressed);
  134. DEREncodeOctetString(bt, str);
  135. }
  136. bool ECP::ValidateParameters(RandomNumberGenerator &rng, unsigned int level) const
  137. {
  138. Integer p = FieldSize();
  139. bool pass = p.IsOdd();
  140. pass = pass && !m_a.IsNegative() && m_a<p && !m_b.IsNegative() && m_b<p;
  141. if (level >= 1)
  142. pass = pass && ((4*m_a*m_a*m_a+27*m_b*m_b)%p).IsPositive();
  143. if (level >= 2)
  144. pass = pass && VerifyPrime(rng, p);
  145. return pass;
  146. }
  147. bool ECP::VerifyPoint(const Point &P) const
  148. {
  149. const FieldElement &x = P.x, &y = P.y;
  150. Integer p = FieldSize();
  151. return P.identity ||
  152. (!x.IsNegative() && x<p && !y.IsNegative() && y<p
  153. && !(((x*x+m_a)*x+m_b-y*y)%p));
  154. }
  155. bool ECP::Equal(const Point &P, const Point &Q) const
  156. {
  157. if (P.identity && Q.identity)
  158. return true;
  159. if (P.identity && !Q.identity)
  160. return false;
  161. if (!P.identity && Q.identity)
  162. return false;
  163. return (GetField().Equal(P.x,Q.x) && GetField().Equal(P.y,Q.y));
  164. }
  165. const ECP::Point& ECP::Identity() const
  166. {
  167. return Singleton<Point>().Ref();
  168. }
  169. const ECP::Point& ECP::Inverse(const Point &P) const
  170. {
  171. if (P.identity)
  172. return P;
  173. else
  174. {
  175. m_R.identity = false;
  176. m_R.x = P.x;
  177. m_R.y = GetField().Inverse(P.y);
  178. return m_R;
  179. }
  180. }
  181. const ECP::Point& ECP::Add(const Point &P, const Point &Q) const
  182. {
  183. if (P.identity) return Q;
  184. if (Q.identity) return P;
  185. if (GetField().Equal(P.x, Q.x))
  186. return GetField().Equal(P.y, Q.y) ? Double(P) : Identity();
  187. FieldElement t = GetField().Subtract(Q.y, P.y);
  188. t = GetField().Divide(t, GetField().Subtract(Q.x, P.x));
  189. FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), Q.x);
  190. m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
  191. m_R.x.swap(x);
  192. m_R.identity = false;
  193. return m_R;
  194. }
  195. const ECP::Point& ECP::Double(const Point &P) const
  196. {
  197. if (P.identity || P.y==GetField().Identity()) return Identity();
  198. FieldElement t = GetField().Square(P.x);
  199. t = GetField().Add(GetField().Add(GetField().Double(t), t), m_a);
  200. t = GetField().Divide(t, GetField().Double(P.y));
  201. FieldElement x = GetField().Subtract(GetField().Subtract(GetField().Square(t), P.x), P.x);
  202. m_R.y = GetField().Subtract(GetField().Multiply(t, GetField().Subtract(P.x, x)), P.y);
  203. m_R.x.swap(x);
  204. m_R.identity = false;
  205. return m_R;
  206. }
  207. template <class T, class Iterator> void ParallelInvert(const AbstractRing<T> &ring, Iterator begin, Iterator end)
  208. {
  209. size_t n = end-begin;
  210. if (n == 1)
  211. *begin = ring.MultiplicativeInverse(*begin);
  212. else if (n > 1)
  213. {
  214. std::vector<T> vec((n+1)/2);
  215. unsigned int i;
  216. Iterator it;
  217. for (i=0, it=begin; i<n/2; i++, it+=2)
  218. vec[i] = ring.Multiply(*it, *(it+1));
  219. if (n%2 == 1)
  220. vec[n/2] = *it;
  221. ParallelInvert(ring, vec.begin(), vec.end());
  222. for (i=0, it=begin; i<n/2; i++, it+=2)
  223. {
  224. if (!vec[i])
  225. {
  226. *it = ring.MultiplicativeInverse(*it);
  227. *(it+1) = ring.MultiplicativeInverse(*(it+1));
  228. }
  229. else
  230. {
  231. std::swap(*it, *(it+1));
  232. *it = ring.Multiply(*it, vec[i]);
  233. *(it+1) = ring.Multiply(*(it+1), vec[i]);
  234. }
  235. }
  236. if (n%2 == 1)
  237. *it = vec[n/2];
  238. }
  239. }
  240. struct ProjectivePoint
  241. {
  242. ProjectivePoint() {}
  243. ProjectivePoint(const Integer &x, const Integer &y, const Integer &z)
  244. : x(x), y(y), z(z) {}
  245. Integer x,y,z;
  246. };
  247. class ProjectiveDoubling
  248. {
  249. public:
  250. ProjectiveDoubling(const ModularArithmetic &mr, const Integer &m_a, const Integer &m_b, const ECPPoint &Q)
  251. : mr(mr), firstDoubling(true), negated(false)
  252. {
  253. if (Q.identity)
  254. {
  255. sixteenY4 = P.x = P.y = mr.MultiplicativeIdentity();
  256. aZ4 = P.z = mr.Identity();
  257. }
  258. else
  259. {
  260. P.x = Q.x;
  261. P.y = Q.y;
  262. sixteenY4 = P.z = mr.MultiplicativeIdentity();
  263. aZ4 = m_a;
  264. }
  265. }
  266. void Double()
  267. {
  268. twoY = mr.Double(P.y);
  269. P.z = mr.Multiply(P.z, twoY);
  270. fourY2 = mr.Square(twoY);
  271. S = mr.Multiply(fourY2, P.x);
  272. aZ4 = mr.Multiply(aZ4, sixteenY4);
  273. M = mr.Square(P.x);
  274. M = mr.Add(mr.Add(mr.Double(M), M), aZ4);
  275. P.x = mr.Square(M);
  276. mr.Reduce(P.x, S);
  277. mr.Reduce(P.x, S);
  278. mr.Reduce(S, P.x);
  279. P.y = mr.Multiply(M, S);
  280. sixteenY4 = mr.Square(fourY2);
  281. mr.Reduce(P.y, mr.Half(sixteenY4));
  282. }
  283. const ModularArithmetic &mr;
  284. ProjectivePoint P;
  285. bool firstDoubling, negated;
  286. Integer sixteenY4, aZ4, twoY, fourY2, S, M;
  287. };
  288. struct ZIterator
  289. {
  290. ZIterator() {}
  291. ZIterator(std::vector<ProjectivePoint>::iterator it) : it(it) {}
  292. Integer& operator*() {return it->z;}
  293. int operator-(ZIterator it2) {return int(it-it2.it);}
  294. ZIterator operator+(int i) {return ZIterator(it+i);}
  295. ZIterator& operator+=(int i) {it+=i; return *this;}
  296. std::vector<ProjectivePoint>::iterator it;
  297. };
  298. ECP::Point ECP::ScalarMultiply(const Point &P, const Integer &k) const
  299. {
  300. Element result;
  301. if (k.BitCount() <= 5)
  302. AbstractGroup<ECPPoint>::SimultaneousMultiply(&result, P, &k, 1);
  303. else
  304. ECP::SimultaneousMultiply(&result, P, &k, 1);
  305. return result;
  306. }
  307. void ECP::SimultaneousMultiply(ECP::Point *results, const ECP::Point &P, const Integer *expBegin, unsigned int expCount) const
  308. {
  309. if (!GetField().IsMontgomeryRepresentation())
  310. {
  311. ECP ecpmr(*this, true);
  312. const ModularArithmetic &mr = ecpmr.GetField();
  313. ecpmr.SimultaneousMultiply(results, ToMontgomery(mr, P), expBegin, expCount);
  314. for (unsigned int i=0; i<expCount; i++)
  315. results[i] = FromMontgomery(mr, results[i]);
  316. return;
  317. }
  318. ProjectiveDoubling rd(GetField(), m_a, m_b, P);
  319. std::vector<ProjectivePoint> bases;
  320. std::vector<WindowSlider> exponents;
  321. exponents.reserve(expCount);
  322. std::vector<std::vector<word32> > baseIndices(expCount);
  323. std::vector<std::vector<bool> > negateBase(expCount);
  324. std::vector<std::vector<word32> > exponentWindows(expCount);
  325. unsigned int i;
  326. for (i=0; i<expCount; i++)
  327. {
  328. assert(expBegin->NotNegative());
  329. exponents.push_back(WindowSlider(*expBegin++, InversionIsFast(), 5));
  330. exponents[i].FindNextWindow();
  331. }
  332. unsigned int expBitPosition = 0;
  333. bool notDone = true;
  334. while (notDone)
  335. {
  336. notDone = false;
  337. bool baseAdded = false;
  338. for (i=0; i<expCount; i++)
  339. {
  340. if (!exponents[i].finished && expBitPosition == exponents[i].windowBegin)
  341. {
  342. if (!baseAdded)
  343. {
  344. bases.push_back(rd.P);
  345. baseAdded =true;
  346. }
  347. exponentWindows[i].push_back(exponents[i].expWindow);
  348. baseIndices[i].push_back((word32)bases.size()-1);
  349. negateBase[i].push_back(exponents[i].negateNext);
  350. exponents[i].FindNextWindow();
  351. }
  352. notDone = notDone || !exponents[i].finished;
  353. }
  354. if (notDone)
  355. {
  356. rd.Double();
  357. expBitPosition++;
  358. }
  359. }
  360. // convert from projective to affine coordinates
  361. ParallelInvert(GetField(), ZIterator(bases.begin()), ZIterator(bases.end()));
  362. for (i=0; i<bases.size(); i++)
  363. {
  364. if (bases[i].z.NotZero())
  365. {
  366. bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
  367. bases[i].z = GetField().Square(bases[i].z);
  368. bases[i].x = GetField().Multiply(bases[i].x, bases[i].z);
  369. bases[i].y = GetField().Multiply(bases[i].y, bases[i].z);
  370. }
  371. }
  372. std::vector<BaseAndExponent<Point, Integer> > finalCascade;
  373. for (i=0; i<expCount; i++)
  374. {
  375. finalCascade.resize(baseIndices[i].size());
  376. for (unsigned int j=0; j<baseIndices[i].size(); j++)
  377. {
  378. ProjectivePoint &base = bases[baseIndices[i][j]];
  379. if (base.z.IsZero())
  380. finalCascade[j].base.identity = true;
  381. else
  382. {
  383. finalCascade[j].base.identity = false;
  384. finalCascade[j].base.x = base.x;
  385. if (negateBase[i][j])
  386. finalCascade[j].base.y = GetField().Inverse(base.y);
  387. else
  388. finalCascade[j].base.y = base.y;
  389. }
  390. finalCascade[j].exponent = Integer(Integer::POSITIVE, 0, exponentWindows[i][j]);
  391. }
  392. results[i] = GeneralCascadeMultiplication(*this, finalCascade.begin(), finalCascade.end());
  393. }
  394. }
  395. ECP::Point ECP::CascadeScalarMultiply(const Point &P, const Integer &k1, const Point &Q, const Integer &k2) const
  396. {
  397. if (!GetField().IsMontgomeryRepresentation())
  398. {
  399. ECP ecpmr(*this, true);
  400. const ModularArithmetic &mr = ecpmr.GetField();
  401. return FromMontgomery(mr, ecpmr.CascadeScalarMultiply(ToMontgomery(mr, P), k1, ToMontgomery(mr, Q), k2));
  402. }
  403. else
  404. return AbstractGroup<Point>::CascadeScalarMultiply(P, k1, Q, k2);
  405. }
  406. NAMESPACE_END
  407. #endif