Counter Strike : Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

421 lines
11 KiB

  1. // ida.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #include "ida.h"
  4. #include "algebra.h"
  5. #include "gf2_32.h"
  6. #include "polynomi.h"
  7. #include <functional>
  8. #include "polynomi.cpp"
// File-local GF2_32 field object (presumably arithmetic over GF(2^32) — confirm in
// gf2_32.h). It is passed as the field argument to every polynomial-interpolation
// helper called in this file, and is shared by all RawIDA instances.
ANONYMOUS_NAMESPACE_BEGIN
static const CryptoPP::GF2_32 field;
NAMESPACE_END
  12. using namespace std;
  13. NAMESPACE_BEGIN(CryptoPP)
  14. void RawIDA::IsolatedInitialize(const NameValuePairs &parameters)
  15. {
  16. if (!parameters.GetIntValue("RecoveryThreshold", m_threshold))
  17. throw InvalidArgument("RawIDA: missing RecoveryThreshold argument");
  18. if (m_threshold <= 0)
  19. throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0");
  20. m_lastMapPosition = m_inputChannelMap.end();
  21. m_channelsReady = 0;
  22. m_channelsFinished = 0;
  23. m_w.New(m_threshold);
  24. m_y.New(m_threshold);
  25. m_inputQueues.reserve(m_threshold);
  26. m_outputChannelIds.clear();
  27. m_outputChannelIdStrings.clear();
  28. m_outputQueues.clear();
  29. word32 outputChannelID;
  30. if (parameters.GetValue("OutputChannelID", outputChannelID))
  31. AddOutputChannel(outputChannelID);
  32. else
  33. {
  34. int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold);
  35. for (int i=0; i<nShares; i++)
  36. AddOutputChannel(i);
  37. }
  38. }
  39. unsigned int RawIDA::InsertInputChannel(word32 channelId)
  40. {
  41. if (m_lastMapPosition != m_inputChannelMap.end())
  42. {
  43. if (m_lastMapPosition->first == channelId)
  44. goto skipFind;
  45. ++m_lastMapPosition;
  46. if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId)
  47. goto skipFind;
  48. }
  49. m_lastMapPosition = m_inputChannelMap.find(channelId);
  50. skipFind:
  51. if (m_lastMapPosition == m_inputChannelMap.end())
  52. {
  53. if (m_inputChannelIds.size() == m_threshold)
  54. return m_threshold;
  55. m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first;
  56. m_inputQueues.push_back(MessageQueue());
  57. m_inputChannelIds.push_back(channelId);
  58. if (m_inputChannelIds.size() == m_threshold)
  59. PrepareInterpolation();
  60. }
  61. return m_lastMapPosition->second;
  62. }
  63. unsigned int RawIDA::LookupInputChannel(word32 channelId) const
  64. {
  65. map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId);
  66. if (it == m_inputChannelMap.end())
  67. return m_threshold;
  68. else
  69. return it->second;
  70. }
  71. void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd)
  72. {
  73. int i = InsertInputChannel(channelId);
  74. if (i < m_threshold)
  75. {
  76. lword size = m_inputQueues[i].MaxRetrievable();
  77. m_inputQueues[i].Put(inString, length);
  78. if (size < 4 && size + length >= 4)
  79. {
  80. m_channelsReady++;
  81. if (m_channelsReady == m_threshold)
  82. ProcessInputQueues();
  83. }
  84. if (messageEnd)
  85. {
  86. m_inputQueues[i].MessageEnd();
  87. if (m_inputQueues[i].NumberOfMessages() == 1)
  88. {
  89. m_channelsFinished++;
  90. if (m_channelsFinished == m_threshold)
  91. {
  92. m_channelsReady = 0;
  93. for (i=0; i<m_threshold; i++)
  94. m_channelsReady += m_inputQueues[i].AnyRetrievable();
  95. ProcessInputQueues();
  96. }
  97. }
  98. }
  99. }
  100. }
  101. lword RawIDA::InputBuffered(word32 channelId) const
  102. {
  103. int i = LookupInputChannel(channelId);
  104. return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0;
  105. }
  106. void RawIDA::ComputeV(unsigned int i)
  107. {
  108. if (i >= m_v.size())
  109. {
  110. m_v.resize(i+1);
  111. m_outputToInput.resize(i+1);
  112. }
  113. m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]);
  114. if (m_outputToInput[i] == m_threshold && i * m_threshold <= 1000*1000)
  115. {
  116. m_v[i].resize(m_threshold);
  117. PrepareBulkPolynomialInterpolationAt(field, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
  118. }
  119. }
  120. void RawIDA::AddOutputChannel(word32 channelId)
  121. {
  122. m_outputChannelIds.push_back(channelId);
  123. m_outputChannelIdStrings.push_back(WordToString(channelId));
  124. m_outputQueues.push_back(ByteQueue());
  125. if (m_inputChannelIds.size() == m_threshold)
  126. ComputeV((unsigned int)m_outputChannelIds.size() - 1);
  127. }
  128. void RawIDA::PrepareInterpolation()
  129. {
  130. assert(m_inputChannelIds.size() == m_threshold);
  131. PrepareBulkPolynomialInterpolation(field, m_w.begin(), &(m_inputChannelIds[0]), m_threshold);
  132. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  133. ComputeV(i);
  134. }
// Core processing loop. Each round reads one 32-bit word from every input queue
// into m_y[0..m_threshold-1] and writes one 32-bit word to every output queue.
// Before all inputs have finished their message, rounds run only while every input
// still has a complete word or a completed message queued. Once all inputs have
// finished, rounds continue while any input still has data. In that final phase the
// message is also closed and the per-message input state is reset.
void RawIDA::ProcessInputQueues()
{
	bool finished = (m_channelsFinished == m_threshold);
	int i;
	while (finished ? m_channelsReady > 0 : m_channelsReady == m_threshold)
	{
		// Consume one word per input and recount how many inputs can supply another.
		// NOTE(review): in the finished phase a queue may hold fewer than 4 bytes; what
		// GetWord32 then leaves in m_y[i] depends on MessageQueue::GetWord32 — confirm.
		m_channelsReady = 0;
		for (i=0; i<m_threshold; i++)
		{
			MessageQueue &queue = m_inputQueues[i];
			queue.GetWord32(m_y[i]);
			if (finished)
				m_channelsReady += queue.AnyRetrievable();
			else
				m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4;
		}
		// Produce one word per output channel:
		//  - outputs whose ID matches an input channel copy that input's word directly;
		//  - otherwise interpolate with the coefficients cached by ComputeV (m_v[i]),
		//    or, if none were cached, compute them on the fly into the scratch m_u.
		for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++)
		{
			if (m_outputToInput[i] != m_threshold)
				m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]);
			else if (m_v[i].size() == m_threshold)
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_v[i].begin(), m_threshold));
			else
			{
				m_u.resize(m_threshold);
				PrepareBulkPolynomialInterpolationAt(field, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
				m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_u.begin(), m_threshold));
			}
		}
	}
	// All output queues receive the same number of words, so checking queue 0 suffices.
	if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable())
		FlushOutputQueues();
	if (finished)
	{
		// End the message downstream and reset the per-message input state. The next
		// message re-registers its input channels and recomputes interpolation data.
		OutputMessageEnds();
		m_channelsReady = 0;
		m_channelsFinished = 0;
		m_v.clear();
		vector<MessageQueue> inputQueues;
		vector<word32> inputChannelIds;
		inputQueues.swap(m_inputQueues);
		inputChannelIds.swap(m_inputChannelIds);
		m_inputChannelMap.clear();
		m_lastMapPosition = m_inputChannelMap.end();
		// Drop the finished message from each old input queue and pass whatever is
		// still queued to the attached transformation, on that input's channel name.
		// NOTE(review): this sends leftover input downstream rather than back into this
		// filter — confirm that is intended.
		for (i=0; i<m_threshold; i++)
		{
			inputQueues[i].GetNextMessage();
			inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i]));
		}
	}
}
  186. void RawIDA::FlushOutputQueues()
  187. {
  188. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  189. m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]);
  190. }
  191. void RawIDA::OutputMessageEnds()
  192. {
  193. if (GetAutoSignalPropagation() != 0)
  194. {
  195. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  196. AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1);
  197. }
  198. }
  199. // ****************************************************************
  200. void SecretSharing::IsolatedInitialize(const NameValuePairs &parameters)
  201. {
  202. m_pad = parameters.GetValueWithDefault("AddPadding", true);
  203. m_ida.IsolatedInitialize(parameters);
  204. }
  205. size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  206. {
  207. if (!blocking)
  208. throw BlockingInputOnly("SecretSharing");
  209. SecByteBlock buf(UnsignedMin(256, length));
  210. unsigned int threshold = m_ida.GetThreshold();
  211. while (length > 0)
  212. {
  213. size_t len = STDMIN(length, buf.size());
  214. m_ida.ChannelData(0xffffffff, begin, len, false);
  215. for (unsigned int i=0; i<threshold-1; i++)
  216. {
  217. m_rng.GenerateBlock(buf, len);
  218. m_ida.ChannelData(i, buf, len, false);
  219. }
  220. length -= len;
  221. begin += len;
  222. }
  223. if (messageEnd)
  224. {
  225. m_ida.SetAutoSignalPropagation(messageEnd-1);
  226. if (m_pad)
  227. {
  228. SecretSharing::Put(1);
  229. while (m_ida.InputBuffered(0xffffffff) > 0)
  230. SecretSharing::Put(0);
  231. }
  232. m_ida.ChannelData(0xffffffff, NULL, 0, true);
  233. for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++)
  234. m_ida.ChannelData(i, NULL, 0, true);
  235. }
  236. return 0;
  237. }
  238. void SecretRecovery::IsolatedInitialize(const NameValuePairs &parameters)
  239. {
  240. m_pad = parameters.GetValueWithDefault("RemovePadding", true);
  241. RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff)));
  242. }
  243. void SecretRecovery::FlushOutputQueues()
  244. {
  245. if (m_pad)
  246. m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4);
  247. else
  248. m_outputQueues[0].TransferTo(*AttachedTransformation());
  249. }
  250. void SecretRecovery::OutputMessageEnds()
  251. {
  252. if (m_pad)
  253. {
  254. PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
  255. m_outputQueues[0].TransferAllTo(paddingRemover);
  256. }
  257. if (GetAutoSignalPropagation() != 0)
  258. AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
  259. }
  260. // ****************************************************************
  261. void InformationDispersal::IsolatedInitialize(const NameValuePairs &parameters)
  262. {
  263. m_nextChannel = 0;
  264. m_pad = parameters.GetValueWithDefault("AddPadding", true);
  265. m_ida.IsolatedInitialize(parameters);
  266. }
  267. size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  268. {
  269. if (!blocking)
  270. throw BlockingInputOnly("InformationDispersal");
  271. while (length--)
  272. {
  273. m_ida.ChannelData(m_nextChannel, begin, 1, false);
  274. begin++;
  275. m_nextChannel++;
  276. if (m_nextChannel == m_ida.GetThreshold())
  277. m_nextChannel = 0;
  278. }
  279. if (messageEnd)
  280. {
  281. m_ida.SetAutoSignalPropagation(messageEnd-1);
  282. if (m_pad)
  283. InformationDispersal::Put(1);
  284. for (word32 i=0; i<m_ida.GetThreshold(); i++)
  285. m_ida.ChannelData(i, NULL, 0, true);
  286. }
  287. return 0;
  288. }
  289. void InformationRecovery::IsolatedInitialize(const NameValuePairs &parameters)
  290. {
  291. m_pad = parameters.GetValueWithDefault("RemovePadding", true);
  292. RawIDA::IsolatedInitialize(parameters);
  293. }
  294. void InformationRecovery::FlushOutputQueues()
  295. {
  296. while (m_outputQueues[0].AnyRetrievable())
  297. {
  298. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  299. m_outputQueues[i].TransferTo(m_queue, 1);
  300. }
  301. if (m_pad)
  302. m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold);
  303. else
  304. m_queue.TransferTo(*AttachedTransformation());
  305. }
  306. void InformationRecovery::OutputMessageEnds()
  307. {
  308. if (m_pad)
  309. {
  310. PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
  311. m_queue.TransferAllTo(paddingRemover);
  312. }
  313. if (GetAutoSignalPropagation() != 0)
  314. AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
  315. }
  316. size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  317. {
  318. if (!blocking)
  319. throw BlockingInputOnly("PaddingRemover");
  320. const byte *const end = begin + length;
  321. if (m_possiblePadding)
  322. {
  323. size_t len = find_if(begin, end, bind2nd(not_equal_to<byte>(), 0)) - begin;
  324. m_zeroCount += len;
  325. begin += len;
  326. if (begin == end)
  327. return 0;
  328. AttachedTransformation()->Put(1);
  329. while (m_zeroCount--)
  330. AttachedTransformation()->Put(0);
  331. AttachedTransformation()->Put(*begin++);
  332. m_possiblePadding = false;
  333. }
  334. #if defined(_MSC_VER) && !defined(__MWERKS__) && (_MSC_VER <= 1300)
  335. // VC60 and VC7 workaround: built-in reverse_iterator has two template parameters, Dinkumware only has one
  336. typedef reverse_bidirectional_iterator<const byte *, const byte> RevIt;
  337. #elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC)
  338. typedef reverse_iterator<const byte *, random_access_iterator_tag, const byte> RevIt;
  339. #else
  340. typedef reverse_iterator<const byte *> RevIt;
  341. #endif
  342. const byte *x = find_if(RevIt(end), RevIt(begin), bind2nd(not_equal_to<byte>(), 0)).base();
  343. if (x != begin && *(x-1) == 1)
  344. {
  345. AttachedTransformation()->Put(begin, x-begin-1);
  346. m_possiblePadding = true;
  347. m_zeroCount = end - x;
  348. }
  349. else
  350. AttachedTransformation()->Put(begin, end-begin);
  351. if (messageEnd)
  352. {
  353. m_possiblePadding = false;
  354. Output(0, begin, length, messageEnd, blocking);
  355. }
  356. return 0;
  357. }
  358. NAMESPACE_END