Team Fortress 2 Source Code as of 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

423 lines
11 KiB

  1. // ida.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #include "config.h"
  4. #include "ida.h"
  5. #include "algebra.h"
  6. #include "gf2_32.h"
  7. #include "polynomi.h"
  8. #include "polynomi.cpp"
  9. #include <functional>
  10. ANONYMOUS_NAMESPACE_BEGIN
// File-local, default-constructed GF2_32 object. It is passed as the `field`
// argument to every PrepareBulkPolynomialInterpolation*/BulkPolynomialInterpolateAt
// call in this file.
  11. static const CryptoPP::GF2_32 field;
  12. NAMESPACE_END
  13. NAMESPACE_BEGIN(CryptoPP)
  14. void RawIDA::IsolatedInitialize(const NameValuePairs &parameters)
  15. {
  16. if (!parameters.GetIntValue("RecoveryThreshold", m_threshold))
  17. throw InvalidArgument("RawIDA: missing RecoveryThreshold argument");
  18. assert(m_threshold > 0);
  19. if (m_threshold <= 0)
  20. throw InvalidArgument("RawIDA: RecoveryThreshold must be greater than 0");
  21. m_lastMapPosition = m_inputChannelMap.end();
  22. m_channelsReady = 0;
  23. m_channelsFinished = 0;
  24. m_w.New(m_threshold);
  25. m_y.New(m_threshold);
  26. m_inputQueues.reserve(m_threshold);
  27. m_outputChannelIds.clear();
  28. m_outputChannelIdStrings.clear();
  29. m_outputQueues.clear();
  30. word32 outputChannelID;
  31. if (parameters.GetValue("OutputChannelID", outputChannelID))
  32. AddOutputChannel(outputChannelID);
  33. else
  34. {
  35. int nShares = parameters.GetIntValueWithDefault("NumberOfShares", m_threshold);
  36. assert(nShares > 0);
  37. if (nShares <= 0) {nShares = m_threshold;}
  38. for (unsigned int i=0; i< (unsigned int)(nShares); i++)
  39. AddOutputChannel(i);
  40. }
  41. }
  42. unsigned int RawIDA::InsertInputChannel(word32 channelId)
  43. {
  44. if (m_lastMapPosition != m_inputChannelMap.end())
  45. {
  46. if (m_lastMapPosition->first == channelId)
  47. goto skipFind;
  48. ++m_lastMapPosition;
  49. if (m_lastMapPosition != m_inputChannelMap.end() && m_lastMapPosition->first == channelId)
  50. goto skipFind;
  51. }
  52. m_lastMapPosition = m_inputChannelMap.find(channelId);
  53. skipFind:
  54. if (m_lastMapPosition == m_inputChannelMap.end())
  55. {
  56. if (m_inputChannelIds.size() == size_t(m_threshold))
  57. return m_threshold;
  58. m_lastMapPosition = m_inputChannelMap.insert(InputChannelMap::value_type(channelId, (unsigned int)m_inputChannelIds.size())).first;
  59. m_inputQueues.push_back(MessageQueue());
  60. m_inputChannelIds.push_back(channelId);
  61. if (m_inputChannelIds.size() == size_t(m_threshold))
  62. PrepareInterpolation();
  63. }
  64. return m_lastMapPosition->second;
  65. }
  66. unsigned int RawIDA::LookupInputChannel(word32 channelId) const
  67. {
  68. std::map<word32, unsigned int>::const_iterator it = m_inputChannelMap.find(channelId);
  69. if (it == m_inputChannelMap.end())
  70. return m_threshold;
  71. else
  72. return it->second;
  73. }
  74. void RawIDA::ChannelData(word32 channelId, const byte *inString, size_t length, bool messageEnd)
  75. {
  76. int i = InsertInputChannel(channelId);
  77. if (i < m_threshold)
  78. {
  79. lword size = m_inputQueues[i].MaxRetrievable();
  80. m_inputQueues[i].Put(inString, length);
  81. if (size < 4 && size + length >= 4)
  82. {
  83. m_channelsReady++;
  84. if (m_channelsReady == size_t(m_threshold))
  85. ProcessInputQueues();
  86. }
  87. if (messageEnd)
  88. {
  89. m_inputQueues[i].MessageEnd();
  90. if (m_inputQueues[i].NumberOfMessages() == 1)
  91. {
  92. m_channelsFinished++;
  93. if (m_channelsFinished == size_t(m_threshold))
  94. {
  95. m_channelsReady = 0;
  96. for (i=0; i<m_threshold; i++)
  97. m_channelsReady += m_inputQueues[i].AnyRetrievable();
  98. ProcessInputQueues();
  99. }
  100. }
  101. }
  102. }
  103. }
  104. lword RawIDA::InputBuffered(word32 channelId) const
  105. {
  106. int i = LookupInputChannel(channelId);
  107. return i < m_threshold ? m_inputQueues[i].MaxRetrievable() : 0;
  108. }
  109. void RawIDA::ComputeV(unsigned int i)
  110. {
  111. if (i >= m_v.size())
  112. {
  113. m_v.resize(i+1);
  114. m_outputToInput.resize(i+1);
  115. }
  116. m_outputToInput[i] = LookupInputChannel(m_outputChannelIds[i]);
  117. if (m_outputToInput[i] == size_t(m_threshold) && i * size_t(m_threshold) <= 1000*1000)
  118. {
  119. m_v[i].resize(m_threshold);
  120. PrepareBulkPolynomialInterpolationAt(field, m_v[i].begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
  121. }
  122. }
  123. void RawIDA::AddOutputChannel(word32 channelId)
  124. {
  125. m_outputChannelIds.push_back(channelId);
  126. m_outputChannelIdStrings.push_back(WordToString(channelId));
  127. m_outputQueues.push_back(ByteQueue());
  128. if (m_inputChannelIds.size() == size_t(m_threshold))
  129. ComputeV((unsigned int)m_outputChannelIds.size() - 1);
  130. }
  131. void RawIDA::PrepareInterpolation()
  132. {
  133. assert(m_inputChannelIds.size() == size_t(m_threshold));
  134. PrepareBulkPolynomialInterpolation(field, m_w.begin(), &(m_inputChannelIds[0]), (unsigned int)(m_threshold));
  135. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  136. ComputeV(i);
  137. }
  138. void RawIDA::ProcessInputQueues()
// Consumes queued input one word per input channel per round and produces
// one word per output channel per round.
// Before every input channel has finished, a round runs only while all
// m_threshold channels are "ready" (queued message or >= 4 bytes). After
// every channel has signalled MessageEnd ("finished"), rounds continue while
// any input channel still has bytes. Message ends are then emitted and the
// per-message input state is reset.
  139. {
  140. bool finished = (m_channelsFinished == size_t(m_threshold));
  141. unsigned int i;
  142. while (finished ? m_channelsReady > 0 : m_channelsReady == size_t(m_threshold))
  143. {
// The readiness count is rebuilt from scratch during this round.
  144. m_channelsReady = 0;
  145. for (i=0; i<size_t(m_threshold); i++)
  146. {
  147. MessageQueue &queue = m_inputQueues[i];
// NOTE(review): in finished mode a queue may hold fewer than 4 bytes here;
// confirm what GetWord32 stores in m_y[i] for such a short read.
  148. queue.GetWord32(m_y[i]);
  149. if (finished)
  150. m_channelsReady += queue.AnyRetrievable();
  151. else
  152. m_channelsReady += queue.NumberOfMessages() > 0 || queue.MaxRetrievable() >= 4;
  153. }
// For each output: copy the word of the input with the same id; otherwise
// interpolate using the cached coefficients m_v[i]. If none are cached,
// compute them into the scratch vector m_u first.
  154. for (i=0; (unsigned int)i<m_outputChannelIds.size(); i++)
  155. {
  156. if (m_outputToInput[i] != size_t(m_threshold))
  157. m_outputQueues[i].PutWord32(m_y[m_outputToInput[i]]);
  158. else if (m_v[i].size() == size_t(m_threshold))
  159. m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_v[i].begin(), m_threshold));
  160. else
  161. {
  162. m_u.resize(m_threshold);
  163. PrepareBulkPolynomialInterpolationAt(field, m_u.begin(), m_outputChannelIds[i], &(m_inputChannelIds[0]), m_w.begin(), m_threshold);
  164. m_outputQueues[i].PutWord32(BulkPolynomialInterpolateAt(field, m_y.begin(), m_u.begin(), m_threshold));
  165. }
  166. }
  167. }
  168. if (m_outputChannelIds.size() > 0 && m_outputQueues[0].AnyRetrievable())
  169. FlushOutputQueues();
  170. if (finished)
  171. {
  172. OutputMessageEnds();
// Reset per-message state. Queues and ids are swapped into locals so the
// members start empty for the next message while leftover data is forwarded.
  173. m_channelsReady = 0;
  174. m_channelsFinished = 0;
  175. m_v.clear();
  176. std::vector<MessageQueue> inputQueues;
  177. std::vector<word32> inputChannelIds;
  178. inputQueues.swap(m_inputQueues);
  179. inputChannelIds.swap(m_inputChannelIds);
  180. m_inputChannelMap.clear();
  181. m_lastMapPosition = m_inputChannelMap.end();
// Presumably GetNextMessage() moves past the finished message. Any bytes
// queued after it are then sent to the attached transformation under the
// input channel's id string.
// NOTE(review): verify that forwarding to AttachedTransformation() rather than
// re-feeding this object is intended.
  182. for (i=0; i<size_t(m_threshold); i++)
  183. {
  184. inputQueues[i].GetNextMessage();
  185. inputQueues[i].TransferAllTo(*AttachedTransformation(), WordToString(inputChannelIds[i]));
  186. }
  187. }
  188. }
  189. void RawIDA::FlushOutputQueues()
  190. {
  191. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  192. m_outputQueues[i].TransferAllTo(*AttachedTransformation(), m_outputChannelIdStrings[i]);
  193. }
  194. void RawIDA::OutputMessageEnds()
  195. {
  196. if (GetAutoSignalPropagation() != 0)
  197. {
  198. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  199. AttachedTransformation()->ChannelMessageEnd(m_outputChannelIdStrings[i], GetAutoSignalPropagation()-1);
  200. }
  201. }
  202. // ****************************************************************
  203. void SecretSharing::IsolatedInitialize(const NameValuePairs &parameters)
  204. {
  205. m_pad = parameters.GetValueWithDefault("AddPadding", true);
  206. m_ida.IsolatedInitialize(parameters);
  207. }
  208. size_t SecretSharing::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  209. {
  210. if (!blocking)
  211. throw BlockingInputOnly("SecretSharing");
  212. SecByteBlock buf(UnsignedMin(256, length));
  213. unsigned int threshold = m_ida.GetThreshold();
  214. while (length > 0)
  215. {
  216. size_t len = STDMIN(length, buf.size());
  217. m_ida.ChannelData(0xffffffff, begin, len, false);
  218. for (unsigned int i=0; i<threshold-1; i++)
  219. {
  220. m_rng.GenerateBlock(buf, len);
  221. m_ida.ChannelData(i, buf, len, false);
  222. }
  223. length -= len;
  224. begin += len;
  225. }
  226. if (messageEnd)
  227. {
  228. m_ida.SetAutoSignalPropagation(messageEnd-1);
  229. if (m_pad)
  230. {
  231. SecretSharing::Put(1);
  232. while (m_ida.InputBuffered(0xffffffff) > 0)
  233. SecretSharing::Put(0);
  234. }
  235. m_ida.ChannelData(0xffffffff, NULL, 0, true);
  236. for (unsigned int i=0; i<m_ida.GetThreshold()-1; i++)
  237. m_ida.ChannelData(i, NULL, 0, true);
  238. }
  239. return 0;
  240. }
  241. void SecretRecovery::IsolatedInitialize(const NameValuePairs &parameters)
  242. {
  243. m_pad = parameters.GetValueWithDefault("RemovePadding", true);
  244. RawIDA::IsolatedInitialize(CombinedNameValuePairs(parameters, MakeParameters("OutputChannelID", (word32)0xffffffff)));
  245. }
  246. void SecretRecovery::FlushOutputQueues()
  247. {
  248. if (m_pad)
  249. m_outputQueues[0].TransferTo(*AttachedTransformation(), m_outputQueues[0].MaxRetrievable()-4);
  250. else
  251. m_outputQueues[0].TransferTo(*AttachedTransformation());
  252. }
  253. void SecretRecovery::OutputMessageEnds()
  254. {
  255. if (m_pad)
  256. {
  257. PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
  258. m_outputQueues[0].TransferAllTo(paddingRemover);
  259. }
  260. if (GetAutoSignalPropagation() != 0)
  261. AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
  262. }
  263. // ****************************************************************
  264. void InformationDispersal::IsolatedInitialize(const NameValuePairs &parameters)
  265. {
  266. m_nextChannel = 0;
  267. m_pad = parameters.GetValueWithDefault("AddPadding", true);
  268. m_ida.IsolatedInitialize(parameters);
  269. }
  270. size_t InformationDispersal::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  271. {
  272. if (!blocking)
  273. throw BlockingInputOnly("InformationDispersal");
  274. while (length--)
  275. {
  276. m_ida.ChannelData(m_nextChannel, begin, 1, false);
  277. begin++;
  278. m_nextChannel++;
  279. if (m_nextChannel == m_ida.GetThreshold())
  280. m_nextChannel = 0;
  281. }
  282. if (messageEnd)
  283. {
  284. m_ida.SetAutoSignalPropagation(messageEnd-1);
  285. if (m_pad)
  286. InformationDispersal::Put(1);
  287. for (word32 i=0; i<m_ida.GetThreshold(); i++)
  288. m_ida.ChannelData(i, NULL, 0, true);
  289. }
  290. return 0;
  291. }
  292. void InformationRecovery::IsolatedInitialize(const NameValuePairs &parameters)
  293. {
  294. m_pad = parameters.GetValueWithDefault("RemovePadding", true);
  295. RawIDA::IsolatedInitialize(parameters);
  296. }
  297. void InformationRecovery::FlushOutputQueues()
  298. {
  299. while (m_outputQueues[0].AnyRetrievable())
  300. {
  301. for (unsigned int i=0; i<m_outputChannelIds.size(); i++)
  302. m_outputQueues[i].TransferTo(m_queue, 1);
  303. }
  304. if (m_pad)
  305. m_queue.TransferTo(*AttachedTransformation(), m_queue.MaxRetrievable()-4*m_threshold);
  306. else
  307. m_queue.TransferTo(*AttachedTransformation());
  308. }
  309. void InformationRecovery::OutputMessageEnds()
  310. {
  311. if (m_pad)
  312. {
  313. PaddingRemover paddingRemover(new Redirector(*AttachedTransformation()));
  314. m_queue.TransferAllTo(paddingRemover);
  315. }
  316. if (GetAutoSignalPropagation() != 0)
  317. AttachedTransformation()->MessageEnd(GetAutoSignalPropagation()-1);
  318. }
  319. size_t PaddingRemover::Put2(const byte *begin, size_t length, int messageEnd, bool blocking)
  320. {
  321. if (!blocking)
  322. throw BlockingInputOnly("PaddingRemover");
  323. const byte *const end = begin + length;
  324. if (m_possiblePadding)
  325. {
  326. size_t len = std::find_if(begin, end, std::bind2nd(std::not_equal_to<byte>(), byte(0))) - begin;
  327. m_zeroCount += len;
  328. begin += len;
  329. if (begin == end)
  330. return 0;
  331. AttachedTransformation()->Put(1);
  332. while (m_zeroCount--)
  333. AttachedTransformation()->Put(0);
  334. AttachedTransformation()->Put(*begin++);
  335. m_possiblePadding = false;
  336. }
  337. #if defined(_MSC_VER) && !defined(__MWERKS__) && (_MSC_VER <= 1300)
  338. // VC60 and VC7 workaround: built-in reverse_iterator has two template parameters, Dinkumware only has one
  339. typedef std::reverse_bidirectional_iterator<const byte *, const byte> RevIt;
  340. #elif defined(_RWSTD_NO_CLASS_PARTIAL_SPEC)
  341. typedef std::reverse_iterator<const byte *, random_access_iterator_tag, const byte> RevIt;
  342. #else
  343. typedef std::reverse_iterator<const byte *> RevIt;
  344. #endif
  345. const byte *x = std::find_if(RevIt(end), RevIt(begin), bind2nd(std::not_equal_to<byte>(), byte(0))).base();
  346. if (x != begin && *(x-1) == 1)
  347. {
  348. AttachedTransformation()->Put(begin, x-begin-1);
  349. m_possiblePadding = true;
  350. m_zeroCount = end - x;
  351. }
  352. else
  353. AttachedTransformation()->Put(begin, end-begin);
  354. if (messageEnd)
  355. {
  356. m_possiblePadding = false;
  357. Output(0, begin, length, messageEnd, blocking);
  358. }
  359. return 0;
  360. }
  361. NAMESPACE_END