Counter Strike : Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

616 lines
19 KiB

  1. //===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // Heuristic PBQP solver. This solver is able to perform optimal reductions for
  11. // nodes of degree 0, 1 or 2. For nodes of degree >2 a pluggable heuristic is
  12. // used to select a node for reduction.
  13. //
  14. //===----------------------------------------------------------------------===//
  15. #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
  16. #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
  17. #include "Graph.h"
  18. #include "Solution.h"
  19. #include <limits>
  20. #include <vector>
  21. namespace PBQP {
  22. /// \brief Heuristic PBQP solver implementation.
  23. ///
  24. /// This class should usually be created (and destroyed) indirectly via a call
  25. /// to HeuristicSolver<HImpl>::solve(Graph&).
  26. /// See the comments for HeuristicSolver.
  27. ///
  28. /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules,
  29. /// backpropagation phase, and maintains the internal copy of the graph on
  30. /// which the reduction is carried out (the original being kept to facilitate
  31. /// backpropagation).
  32. template <typename HImpl>
  33. class HeuristicSolverImpl {
  34. private:
  35. typedef typename HImpl::NodeData HeuristicNodeData;
  36. typedef typename HImpl::EdgeData HeuristicEdgeData;
  37. typedef std::list<Graph::EdgeItr> SolverEdges;
  38. public:
  39. /// \brief Iterator type for edges in the solver graph.
  40. typedef SolverEdges::iterator SolverEdgeItr;
  41. private:
  42. class NodeData {
  43. public:
  44. NodeData() : solverDegree(0) {}
  45. HeuristicNodeData& getHeuristicData() { return hData; }
  46. SolverEdgeItr addSolverEdge(Graph::EdgeItr eItr) {
  47. ++solverDegree;
  48. return solverEdges.insert(solverEdges.end(), eItr);
  49. }
  50. void removeSolverEdge(SolverEdgeItr seItr) {
  51. --solverDegree;
  52. solverEdges.erase(seItr);
  53. }
  54. SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); }
  55. SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); }
  56. unsigned getSolverDegree() const { return solverDegree; }
  57. void clearSolverEdges() {
  58. solverDegree = 0;
  59. solverEdges.clear();
  60. }
  61. private:
  62. HeuristicNodeData hData;
  63. unsigned solverDegree;
  64. SolverEdges solverEdges;
  65. };
  66. class EdgeData {
  67. public:
  68. HeuristicEdgeData& getHeuristicData() { return hData; }
  69. void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) {
  70. this->n1SolverEdgeItr = n1SolverEdgeItr;
  71. }
  72. SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; }
  73. void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){
  74. this->n2SolverEdgeItr = n2SolverEdgeItr;
  75. }
  76. SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; }
  77. private:
  78. HeuristicEdgeData hData;
  79. SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr;
  80. };
  81. Graph &g;
  82. HImpl h;
  83. Solution s;
  84. std::vector<Graph::NodeItr> stack;
  85. typedef std::list<NodeData> NodeDataList;
  86. NodeDataList nodeDataList;
  87. typedef std::list<EdgeData> EdgeDataList;
  88. EdgeDataList edgeDataList;
  89. public:
  90. /// \brief Construct a heuristic solver implementation to solve the given
  91. /// graph.
  92. /// @param g The graph representing the problem instance to be solved.
  93. HeuristicSolverImpl(Graph &g) : g(g), h(*this) {}
  94. /// \brief Get the graph being solved by this solver.
  95. /// @return The graph representing the problem instance being solved by this
  96. /// solver.
  97. Graph& getGraph() { return g; }
  98. /// \brief Get the heuristic data attached to the given node.
  99. /// @param nItr Node iterator.
  100. /// @return The heuristic data attached to the given node.
  101. HeuristicNodeData& getHeuristicNodeData(Graph::NodeItr nItr) {
  102. return getSolverNodeData(nItr).getHeuristicData();
  103. }
  104. /// \brief Get the heuristic data attached to the given edge.
  105. /// @param eItr Edge iterator.
  106. /// @return The heuristic data attached to the given node.
  107. HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeItr eItr) {
  108. return getSolverEdgeData(eItr).getHeuristicData();
  109. }
  110. /// \brief Begin iterator for the set of edges adjacent to the given node in
  111. /// the solver graph.
  112. /// @param nItr Node iterator.
  113. /// @return Begin iterator for the set of edges adjacent to the given node
  114. /// in the solver graph.
  115. SolverEdgeItr solverEdgesBegin(Graph::NodeItr nItr) {
  116. return getSolverNodeData(nItr).solverEdgesBegin();
  117. }
  118. /// \brief End iterator for the set of edges adjacent to the given node in
  119. /// the solver graph.
  120. /// @param nItr Node iterator.
  121. /// @return End iterator for the set of edges adjacent to the given node in
  122. /// the solver graph.
  123. SolverEdgeItr solverEdgesEnd(Graph::NodeItr nItr) {
  124. return getSolverNodeData(nItr).solverEdgesEnd();
  125. }
  126. /// \brief Remove a node from the solver graph.
  127. /// @param eItr Edge iterator for edge to be removed.
  128. ///
  129. /// Does <i>not</i> notify the heuristic of the removal. That should be
  130. /// done manually if necessary.
  131. void removeSolverEdge(Graph::EdgeItr eItr) {
  132. EdgeData &eData = getSolverEdgeData(eItr);
  133. NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
  134. &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
  135. n1Data.removeSolverEdge(eData.getN1SolverEdgeItr());
  136. n2Data.removeSolverEdge(eData.getN2SolverEdgeItr());
  137. }
  138. /// \brief Compute a solution to the PBQP problem instance with which this
  139. /// heuristic solver was constructed.
  140. /// @return A solution to the PBQP problem.
  141. ///
  142. /// Performs the full PBQP heuristic solver algorithm, including setup,
  143. /// calls to the heuristic (which will call back to the reduction rules in
  144. /// this class), and cleanup.
  145. Solution computeSolution() {
  146. setup();
  147. h.setup();
  148. h.reduce();
  149. backpropagate();
  150. h.cleanup();
  151. cleanup();
  152. return s;
  153. }
  154. /// \brief Add to the end of the stack.
  155. /// @param nItr Node iterator to add to the reduction stack.
  156. void pushToStack(Graph::NodeItr nItr) {
  157. getSolverNodeData(nItr).clearSolverEdges();
  158. stack.push_back(nItr);
  159. }
  160. /// \brief Returns the solver degree of the given node.
  161. /// @param nItr Node iterator for which degree is requested.
  162. /// @return Node degree in the <i>solver</i> graph (not the original graph).
  163. unsigned getSolverDegree(Graph::NodeItr nItr) {
  164. return getSolverNodeData(nItr).getSolverDegree();
  165. }
  166. /// \brief Set the solution of the given node.
  167. /// @param nItr Node iterator to set solution for.
  168. /// @param selection Selection for node.
  169. void setSolution(const Graph::NodeItr &nItr, unsigned selection) {
  170. s.setSelection(nItr, selection);
  171. for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
  172. aeEnd = g.adjEdgesEnd(nItr);
  173. aeItr != aeEnd; ++aeItr) {
  174. Graph::EdgeItr eItr(*aeItr);
  175. Graph::NodeItr anItr(g.getEdgeOtherNode(eItr, nItr));
  176. getSolverNodeData(anItr).addSolverEdge(eItr);
  177. }
  178. }
  179. /// \brief Apply rule R0.
  180. /// @param nItr Node iterator for node to apply R0 to.
  181. ///
  182. /// Node will be automatically pushed to the solver stack.
  183. void applyR0(Graph::NodeItr nItr) {
  184. assert(getSolverNodeData(nItr).getSolverDegree() == 0 &&
  185. "R0 applied to node with degree != 0.");
  186. // Nothing to do. Just push the node onto the reduction stack.
  187. pushToStack(nItr);
  188. s.recordR0();
  189. }
  190. /// \brief Apply rule R1.
  191. /// @param xnItr Node iterator for node to apply R1 to.
  192. ///
  193. /// Node will be automatically pushed to the solver stack.
  194. void applyR1(Graph::NodeItr xnItr) {
  195. NodeData &nd = getSolverNodeData(xnItr);
  196. assert(nd.getSolverDegree() == 1 &&
  197. "R1 applied to node with degree != 1.");
  198. Graph::EdgeItr eItr = *nd.solverEdgesBegin();
  199. const Matrix &eCosts = g.getEdgeCosts(eItr);
  200. const Vector &xCosts = g.getNodeCosts(xnItr);
  201. // Duplicate a little to avoid transposing matrices.
  202. if (xnItr == g.getEdgeNode1(eItr)) {
  203. Graph::NodeItr ynItr = g.getEdgeNode2(eItr);
  204. Vector &yCosts = g.getNodeCosts(ynItr);
  205. for (unsigned j = 0; j < yCosts.getLength(); ++j) {
  206. PBQPNum min = eCosts[0][j] + xCosts[0];
  207. for (unsigned i = 1; i < xCosts.getLength(); ++i) {
  208. PBQPNum c = eCosts[i][j] + xCosts[i];
  209. if (c < min)
  210. min = c;
  211. }
  212. yCosts[j] += min;
  213. }
  214. h.handleRemoveEdge(eItr, ynItr);
  215. } else {
  216. Graph::NodeItr ynItr = g.getEdgeNode1(eItr);
  217. Vector &yCosts = g.getNodeCosts(ynItr);
  218. for (unsigned i = 0; i < yCosts.getLength(); ++i) {
  219. PBQPNum min = eCosts[i][0] + xCosts[0];
  220. for (unsigned j = 1; j < xCosts.getLength(); ++j) {
  221. PBQPNum c = eCosts[i][j] + xCosts[j];
  222. if (c < min)
  223. min = c;
  224. }
  225. yCosts[i] += min;
  226. }
  227. h.handleRemoveEdge(eItr, ynItr);
  228. }
  229. removeSolverEdge(eItr);
  230. assert(nd.getSolverDegree() == 0 &&
  231. "Degree 1 with edge removed should be 0.");
  232. pushToStack(xnItr);
  233. s.recordR1();
  234. }
  235. /// \brief Apply rule R2.
  236. /// @param xnItr Node iterator for node to apply R2 to.
  237. ///
  238. /// Node will be automatically pushed to the solver stack.
  239. void applyR2(Graph::NodeItr xnItr) {
  240. assert(getSolverNodeData(xnItr).getSolverDegree() == 2 &&
  241. "R2 applied to node with degree != 2.");
  242. NodeData &nd = getSolverNodeData(xnItr);
  243. const Vector &xCosts = g.getNodeCosts(xnItr);
  244. SolverEdgeItr aeItr = nd.solverEdgesBegin();
  245. Graph::EdgeItr yxeItr = *aeItr,
  246. zxeItr = *(++aeItr);
  247. Graph::NodeItr ynItr = g.getEdgeOtherNode(yxeItr, xnItr),
  248. znItr = g.getEdgeOtherNode(zxeItr, xnItr);
  249. bool flipEdge1 = (g.getEdgeNode1(yxeItr) == xnItr),
  250. flipEdge2 = (g.getEdgeNode1(zxeItr) == xnItr);
  251. const Matrix *yxeCosts = flipEdge1 ?
  252. new Matrix(g.getEdgeCosts(yxeItr).transpose()) :
  253. &g.getEdgeCosts(yxeItr);
  254. const Matrix *zxeCosts = flipEdge2 ?
  255. new Matrix(g.getEdgeCosts(zxeItr).transpose()) :
  256. &g.getEdgeCosts(zxeItr);
  257. unsigned xLen = xCosts.getLength(),
  258. yLen = yxeCosts->getRows(),
  259. zLen = zxeCosts->getRows();
  260. Matrix delta(yLen, zLen);
  261. for (unsigned i = 0; i < yLen; ++i) {
  262. for (unsigned j = 0; j < zLen; ++j) {
  263. PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0];
  264. for (unsigned k = 1; k < xLen; ++k) {
  265. PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k];
  266. if (c < min) {
  267. min = c;
  268. }
  269. }
  270. delta[i][j] = min;
  271. }
  272. }
  273. if (flipEdge1)
  274. delete yxeCosts;
  275. if (flipEdge2)
  276. delete zxeCosts;
  277. Graph::EdgeItr yzeItr = g.findEdge(ynItr, znItr);
  278. bool addedEdge = false;
  279. if (yzeItr == g.edgesEnd()) {
  280. yzeItr = g.addEdge(ynItr, znItr, delta);
  281. addedEdge = true;
  282. } else {
  283. Matrix &yzeCosts = g.getEdgeCosts(yzeItr);
  284. h.preUpdateEdgeCosts(yzeItr);
  285. if (ynItr == g.getEdgeNode1(yzeItr)) {
  286. yzeCosts += delta;
  287. } else {
  288. yzeCosts += delta.transpose();
  289. }
  290. }
  291. bool nullCostEdge = tryNormaliseEdgeMatrix(yzeItr);
  292. if (!addedEdge) {
  293. // If we modified the edge costs let the heuristic know.
  294. h.postUpdateEdgeCosts(yzeItr);
  295. }
  296. if (nullCostEdge) {
  297. // If this edge ended up null remove it.
  298. if (!addedEdge) {
  299. // We didn't just add it, so we need to notify the heuristic
  300. // and remove it from the solver.
  301. h.handleRemoveEdge(yzeItr, ynItr);
  302. h.handleRemoveEdge(yzeItr, znItr);
  303. removeSolverEdge(yzeItr);
  304. }
  305. g.removeEdge(yzeItr);
  306. } else if (addedEdge) {
  307. // If the edge was added, and non-null, finish setting it up, add it to
  308. // the solver & notify heuristic.
  309. edgeDataList.push_back(EdgeData());
  310. g.setEdgeData(yzeItr, &edgeDataList.back());
  311. addSolverEdge(yzeItr);
  312. h.handleAddEdge(yzeItr);
  313. }
  314. h.handleRemoveEdge(yxeItr, ynItr);
  315. removeSolverEdge(yxeItr);
  316. h.handleRemoveEdge(zxeItr, znItr);
  317. removeSolverEdge(zxeItr);
  318. pushToStack(xnItr);
  319. s.recordR2();
  320. }
  321. /// \brief Record an application of the RN rule.
  322. ///
  323. /// For use by the HeuristicBase.
  324. void recordRN() { s.recordRN(); }
  325. private:
  326. NodeData& getSolverNodeData(Graph::NodeItr nItr) {
  327. return *static_cast<NodeData*>(g.getNodeData(nItr));
  328. }
  329. EdgeData& getSolverEdgeData(Graph::EdgeItr eItr) {
  330. return *static_cast<EdgeData*>(g.getEdgeData(eItr));
  331. }
  332. void addSolverEdge(Graph::EdgeItr eItr) {
  333. EdgeData &eData = getSolverEdgeData(eItr);
  334. NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
  335. &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
  336. eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eItr));
  337. eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eItr));
  338. }
  339. void setup() {
  340. if (h.solverRunSimplify()) {
  341. simplify();
  342. }
  343. // Create node data objects.
  344. for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
  345. nItr != nEnd; ++nItr) {
  346. nodeDataList.push_back(NodeData());
  347. g.setNodeData(nItr, &nodeDataList.back());
  348. }
  349. // Create edge data objects.
  350. for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
  351. eItr != eEnd; ++eItr) {
  352. edgeDataList.push_back(EdgeData());
  353. g.setEdgeData(eItr, &edgeDataList.back());
  354. addSolverEdge(eItr);
  355. }
  356. }
  357. void simplify() {
  358. disconnectTrivialNodes();
  359. eliminateIndependentEdges();
  360. }
  361. // Eliminate trivial nodes.
  362. void disconnectTrivialNodes() {
  363. unsigned numDisconnected = 0;
  364. for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
  365. nItr != nEnd; ++nItr) {
  366. if (g.getNodeCosts(nItr).getLength() == 1) {
  367. std::vector<Graph::EdgeItr> edgesToRemove;
  368. for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
  369. aeEnd = g.adjEdgesEnd(nItr);
  370. aeItr != aeEnd; ++aeItr) {
  371. Graph::EdgeItr eItr = *aeItr;
  372. if (g.getEdgeNode1(eItr) == nItr) {
  373. Graph::NodeItr otherNodeItr = g.getEdgeNode2(eItr);
  374. g.getNodeCosts(otherNodeItr) +=
  375. g.getEdgeCosts(eItr).getRowAsVector(0);
  376. }
  377. else {
  378. Graph::NodeItr otherNodeItr = g.getEdgeNode1(eItr);
  379. g.getNodeCosts(otherNodeItr) +=
  380. g.getEdgeCosts(eItr).getColAsVector(0);
  381. }
  382. edgesToRemove.push_back(eItr);
  383. }
  384. if (!edgesToRemove.empty())
  385. ++numDisconnected;
  386. while (!edgesToRemove.empty()) {
  387. g.removeEdge(edgesToRemove.back());
  388. edgesToRemove.pop_back();
  389. }
  390. }
  391. }
  392. }
  393. void eliminateIndependentEdges() {
  394. std::vector<Graph::EdgeItr> edgesToProcess;
  395. unsigned numEliminated = 0;
  396. for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
  397. eItr != eEnd; ++eItr) {
  398. edgesToProcess.push_back(eItr);
  399. }
  400. while (!edgesToProcess.empty()) {
  401. if (tryToEliminateEdge(edgesToProcess.back()))
  402. ++numEliminated;
  403. edgesToProcess.pop_back();
  404. }
  405. }
  406. bool tryToEliminateEdge(Graph::EdgeItr eItr) {
  407. if (tryNormaliseEdgeMatrix(eItr)) {
  408. g.removeEdge(eItr);
  409. return true;
  410. }
  411. return false;
  412. }
  413. bool tryNormaliseEdgeMatrix(Graph::EdgeItr &eItr) {
  414. const PBQPNum infinity = std::numeric_limits<PBQPNum>::infinity();
  415. Matrix &edgeCosts = g.getEdgeCosts(eItr);
  416. Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eItr)),
  417. &vCosts = g.getNodeCosts(g.getEdgeNode2(eItr));
  418. for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
  419. PBQPNum rowMin = infinity;
  420. for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
  421. if (vCosts[c] != infinity && edgeCosts[r][c] < rowMin)
  422. rowMin = edgeCosts[r][c];
  423. }
  424. uCosts[r] += rowMin;
  425. if (rowMin != infinity) {
  426. edgeCosts.subFromRow(r, rowMin);
  427. }
  428. else {
  429. edgeCosts.setRow(r, 0);
  430. }
  431. }
  432. for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
  433. PBQPNum colMin = infinity;
  434. for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
  435. if (uCosts[r] != infinity && edgeCosts[r][c] < colMin)
  436. colMin = edgeCosts[r][c];
  437. }
  438. vCosts[c] += colMin;
  439. if (colMin != infinity) {
  440. edgeCosts.subFromCol(c, colMin);
  441. }
  442. else {
  443. edgeCosts.setCol(c, 0);
  444. }
  445. }
  446. return edgeCosts.isZero();
  447. }
  448. void backpropagate() {
  449. while (!stack.empty()) {
  450. computeSolution(stack.back());
  451. stack.pop_back();
  452. }
  453. }
  454. void computeSolution(Graph::NodeItr nItr) {
  455. NodeData &nodeData = getSolverNodeData(nItr);
  456. Vector v(g.getNodeCosts(nItr));
  457. // Solve based on existing solved edges.
  458. for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(),
  459. solvedEdgeEnd = nodeData.solverEdgesEnd();
  460. solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) {
  461. Graph::EdgeItr eItr(*solvedEdgeItr);
  462. Matrix &edgeCosts = g.getEdgeCosts(eItr);
  463. if (nItr == g.getEdgeNode1(eItr)) {
  464. Graph::NodeItr adjNode(g.getEdgeNode2(eItr));
  465. unsigned adjSolution = s.getSelection(adjNode);
  466. v += edgeCosts.getColAsVector(adjSolution);
  467. }
  468. else {
  469. Graph::NodeItr adjNode(g.getEdgeNode1(eItr));
  470. unsigned adjSolution = s.getSelection(adjNode);
  471. v += edgeCosts.getRowAsVector(adjSolution);
  472. }
  473. }
  474. setSolution(nItr, v.minIndex());
  475. }
  476. void cleanup() {
  477. h.cleanup();
  478. nodeDataList.clear();
  479. edgeDataList.clear();
  480. }
  481. };
  482. /// \brief PBQP heuristic solver class.
  483. ///
  484. /// Given a PBQP Graph g representing a PBQP problem, you can find a solution
  485. /// by calling
  486. /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt>
  487. ///
  488. /// The choice of heuristic for the H parameter will affect both the solver
  489. /// speed and solution quality. The heuristic should be chosen based on the
  490. /// nature of the problem being solved.
  491. /// Currently the only solver included with LLVM is the Briggs heuristic for
  492. /// register allocation.
  493. template <typename HImpl>
  494. class HeuristicSolver {
  495. public:
  496. static Solution solve(Graph &g) {
  497. HeuristicSolverImpl<HImpl> hs(g);
  498. return hs.computeSolution();
  499. }
  500. };
  501. }
  502. #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H