Source code of Windows XP (NT5)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

523 lines
12 KiB

  1. //+-------------------------------------------------------------------------
  2. //
  3. // Microsoft Windows
  4. //
  5. // Copyright (C) Microsoft Corporation, 1997 - 1998
  6. //
  7. // File: mbnet.cpp
  8. //
  9. //--------------------------------------------------------------------------
  10. //
  11. // mbnet.cpp: Belief network model member functions
  12. //
  13. #include <basetsd.h>
  14. #include "basics.h"
  15. #include "algos.h"
  16. #include "gmprop.h"
  17. #include "gmobj.h"
  18. #include "cliqset.h"
  19. #include "clique.h"
  20. #include "expand.h"
  21. MBNET :: MBNET ()
  22. :_inmFree(0),
  23. _iInferEngID(0)
  24. {
  25. }
  26. MBNET :: ~ MBNET ()
  27. {
  28. PopModifierStack( true ); // Clear all modifiers from the network
  29. // Clear the node-index-to-name information
  30. _inmFree = 0;
  31. _vzsrNames.clear();
  32. }
  33. //
  34. // Clone this belief network from another. Note that the contents
  35. // of the modifier stack (inference engines, expanders, etc.) are
  36. // NOT cloned.
  37. //
  38. void MBNET :: Clone ( MODEL & model )
  39. {
  40. // This must be a truly empty structure
  41. ASSERT_THROW( _vpModifiers.size() == 0 && _vzsrNames.size() == 0,
  42. EC_INVALID_CLONE,
  43. "cannot clone into non-empty structure" );
  44. MODEL::Clone( model );
  45. MBNET * pmbnet;
  46. DynCastThrow( & model, pmbnet );
  47. MBNET & mbnet = *pmbnet;
  48. {
  49. // Build the name table by iterating over the contents and
  50. // allocating a slot for each node
  51. GELEMLNK * pgelm;
  52. MODELENUM mdlenumNode( mbnet );
  53. while ( pgelm = mdlenumNode.PlnkelNext() )
  54. {
  55. // Check that it's a node (not an edge)
  56. if ( ! pgelm->BIsEType( GELEM::EGELM_NODE ) )
  57. continue;
  58. GOBJMBN * pgobjmbn;
  59. DynCastThrow( pgelm, pgobjmbn );
  60. _vzsrNames.push_back( pgobjmbn->ZsrefName() );
  61. }
  62. _inmFree = _vzsrNames.size();
  63. }
  64. // Clone the distribution map
  65. _mppd.Clone( _mpsymtbl, mbnet._mppd ) ;
  66. // Check the topology if it's supposed to be present
  67. #ifdef _DEBUG
  68. if ( mbnet.BFlag( EIBF_Topology ) )
  69. VerifyTopology();
  70. #endif
  71. }
  72. //
  73. // Iterate over the distributions, matching them to the nodes they belong to.
  74. //
  75. void MBNET :: VerifyTopology ()
  76. {
  77. for ( MPPD::iterator itpd = Mppd().begin();
  78. itpd != Mppd().end();
  79. itpd++ )
  80. {
  81. const VTKNPD & vtknpd = (*itpd).first;
  82. const BNDIST * pbndist = (*itpd).second;
  83. // Guarantee that the descriptor is of the form "p(X|...)"
  84. if ( vtknpd.size() < 2
  85. || vtknpd[0] != TKNPD(DTKN_PD)
  86. || ! vtknpd[1].BStr() )
  87. throw GMException( EC_INV_PD, "invalid token descriptor on PD");
  88. // Get the name of the node whose distribution this is
  89. SZC szc = vtknpd[1].Szc();
  90. assert( szc ) ;
  91. // Find that named thing in the graph
  92. GOBJMBN * pbnobj = Mpsymtbl().find( szc );
  93. assert( pbnobj && pbnobj->EType() == GOBJMBN::EBNO_NODE );
  94. // Guarantee that it's a node
  95. GNODEMBN * pgndbn = dynamic_cast<GNODEMBN *> (pbnobj);
  96. ASSERT_THROW( pgndbn, EC_INV_PD, "token on PD references non-node");
  97. // Verify the node's distribution
  98. if ( ! pgndbn->BMatchTopology( *this, vtknpd ) )
  99. {
  100. throw GMException( EC_TOPOLOGY_MISMATCH,
  101. "topology mismatch between PD and network");
  102. }
  103. }
  104. }
  105. MBNET_MODIFIER * MBNET :: PModifierStackTop ()
  106. {
  107. return _vpModifiers.size() > 0
  108. ? _vpModifiers[ _vpModifiers.size() - 1 ]
  109. : NULL;
  110. }
  111. void MBNET :: PushModifierStack ( MBNET_MODIFIER * pmodf )
  112. {
  113. assert( pmodf );
  114. pmodf->Create();
  115. _vpModifiers.push_back( pmodf );
  116. }
  117. void MBNET :: PopModifierStack ( bool bAll )
  118. {
  119. int iPop = _vpModifiers.size();
  120. while ( iPop > 0 )
  121. {
  122. MBNET_MODIFIER * pmodf = _vpModifiers[ --iPop ];
  123. assert ( pmodf );
  124. // NOTE: Deleting the object should be all that's necessary;
  125. // object's destructor should call its Destroy() function.
  126. delete pmodf;
  127. if ( ! bAll )
  128. break;
  129. }
  130. if ( iPop == 0 )
  131. _vpModifiers.clear();
  132. else
  133. _vpModifiers.resize(iPop);
  134. }
  135. // Find the named object by index
  136. GOBJMBN * MBNET :: PgobjFindByIndex ( int inm )
  137. {
  138. ZSREF zsMt;
  139. if ( inm >= _vzsrNames.size()
  140. || _vzsrNames[inm] == zsMt )
  141. return NULL;
  142. return Mpsymtbl().find( _vzsrNames[inm] );
  143. }
  144. int MBNET :: INameIndex ( ZSREF zsr )
  145. {
  146. return ifind( _vzsrNames, zsr );
  147. }
  148. int MBNET :: INameIndex ( const GOBJMBN * pgobj )
  149. {
  150. return INameIndex( pgobj->ZsrefName() );
  151. }
  152. int MBNET :: CreateNameIndex ( const GOBJMBN * pgobj )
  153. {
  154. int ind = -1;
  155. if ( _inmFree >= _vzsrNames.size() )
  156. {
  157. // No free slots; grow the array
  158. ind = _vzsrNames.size();
  159. _vzsrNames.push_back( pgobj->ZsrefName() );
  160. _inmFree = _vzsrNames.size();
  161. }
  162. else
  163. {
  164. // Use the given free slot, find the next
  165. _vzsrNames[ind = _inmFree] = pgobj->ZsrefName();
  166. ZSREF zsMt;
  167. for ( ; _inmFree < _vzsrNames.size() ; _inmFree++ )
  168. {
  169. if ( zsMt == _vzsrNames[_inmFree] )
  170. break;
  171. }
  172. }
  173. return ind;
  174. }
  175. void MBNET :: DeleteNameIndex ( int inm )
  176. {
  177. ASSERT_THROW( inm < _vzsrNames.size(),
  178. EC_INTERNAL_ERROR,
  179. "MBNET name index out of range" );
  180. _vzsrNames[inm] = ZSREF();
  181. if ( inm < _inmFree )
  182. _inmFree = inm;
  183. }
  184. void MBNET :: DeleteNameIndex ( const GOBJMBN * pgobj )
  185. {
  186. int inm = INameIndex( pgobj );
  187. if ( inm >= 0 )
  188. DeleteNameIndex(inm);
  189. }
  190. // Add a named object to the graph and symbol table
  191. void MBNET :: AddElem ( SZC szcName, GOBJMBN * pgelm )
  192. {
  193. if ( szcName == NULL || ::strlen(szcName) == 0 )
  194. {
  195. MODEL::AddElem( pgelm ); // empty name
  196. }
  197. else
  198. {
  199. MODEL::AddElem( szcName, pgelm );
  200. assert( INameIndex( pgelm ) < 0 ); // guarantee no duplicates
  201. CreateNameIndex( pgelm );
  202. }
  203. }
  204. void MBNET :: DeleteElem ( GOBJMBN * pgobj )
  205. {
  206. DeleteNameIndex( pgobj );
  207. MODEL::DeleteElem( pgobj );
  208. }
  209. /*
  210. Iterator has moved into the MODEL class... I've left the code here
  211. in case MBNET needs its own iterator. (Max, 05/12/97)
  212. MBNET::ITER :: ITER ( MBNET & bnet, GOBJMBN::EBNOBJ eType )
  213. : _eType(eType),
  214. _bnet(bnet)
  215. {
  216. Reset();
  217. }
  218. void MBNET::ITER :: Reset ()
  219. {
  220. _pCurrent = NULL;
  221. _itsym = _bnet.Mpsymtbl().begin();
  222. BNext();
  223. }
  224. bool MBNET::ITER :: BNext ()
  225. {
  226. while ( _itsym != _bnet.Mpsymtbl().end() )
  227. {
  228. _pCurrent = (*_itsym).second.Pobj();
  229. _zsrCurrent = (*_itsym).first;
  230. _itsym++;
  231. if ( _pCurrent->EType() == _eType )
  232. return true;
  233. }
  234. _pCurrent = NULL;
  235. return false;
  236. }
  237. */
  238. void MBNET :: CreateTopology ()
  239. {
  240. if ( BFlag( EIBF_Topology ) )
  241. return;
  242. // Walk the map of distributions. For each one, extract the node
  243. // name and find it. Then add arcs for each parent.
  244. #ifdef _DEBUG
  245. UINT iCycleMax = 2;
  246. #else
  247. UINT iCycleMax = 1;
  248. #endif
  249. UINT iIter = 0;
  250. for ( UINT iCycle = 0 ; iCycle < iCycleMax ; iCycle++ )
  251. {
  252. for ( MPPD::iterator itpd = Mppd().begin();
  253. itpd != Mppd().end();
  254. itpd++, iIter++ )
  255. {
  256. const VTKNPD & vtknpd = (*itpd).first;
  257. const BNDIST * pbndist = (*itpd).second;
  258. // Guarantee that the descriptor is of the form "p(X|...)"
  259. if ( vtknpd.size() < 2
  260. || vtknpd[0] != TKNPD(DTKN_PD)
  261. || ! vtknpd[1].BStr() )
  262. throw GMException( EC_INV_PD, "invalid token descriptor on PD");
  263. // Get the name of the node whose distribution this is
  264. SZC szcChild = vtknpd[1].Szc();
  265. assert( szcChild ) ;
  266. // Find that named thing in the graph
  267. GOBJMBN * pbnobjChild = Mpsymtbl().find( szcChild );
  268. assert( pbnobjChild && pbnobjChild->EType() == GOBJMBN::EBNO_NODE );
  269. // Guarantee that it's a node
  270. GNODEMBN * pgndbnChild = dynamic_cast<GNODEMBN *> (pbnobjChild);
  271. ASSERT_THROW( pgndbnChild, EC_INV_PD, "token on PD references non-node");
  272. UINT cParents = 0;
  273. UINT cChildren = pgndbnChild->CChild();
  274. for ( int i = 2; i < vtknpd.size(); i++ )
  275. {
  276. if ( ! vtknpd[i].BStr() )
  277. continue;
  278. SZC szcParent = vtknpd[i].Szc();
  279. assert( szcParent) ;
  280. GOBJMBN * pbnobjParent = Mpsymtbl().find( szcParent );
  281. assert( pbnobjParent && pbnobjParent->EType() == GOBJMBN::EBNO_NODE );
  282. GNODEMBN * pgndbnParent = (GNODEMBN *) pbnobjParent;
  283. UINT cPrChildren = pgndbnParent->CChild();
  284. if ( iCycle == 0 )
  285. {
  286. AddElem( new GEDGEMBN_PROB( pgndbnParent, pgndbnChild ) );
  287. }
  288. cParents++;
  289. if ( iCycle == 0 )
  290. {
  291. UINT cChNew = pgndbnChild->CChild();
  292. UINT cPrNew = pgndbnChild->CParent();
  293. UINT cPrChNew = pgndbnParent->CChild();
  294. assert( cPrChNew = cPrChildren + 1 );
  295. assert( cChildren == cChNew );
  296. }
  297. }
  298. if ( iCycle )
  299. {
  300. UINT cPrNew = pgndbnChild->CParent();
  301. assert( cParents == cPrNew );
  302. }
  303. if ( iCycle == 0 )
  304. {
  305. #ifdef _DEBUG
  306. if ( ! pgndbnChild->BMatchTopology( *this, vtknpd ) )
  307. {
  308. throw GMException( EC_TOPOLOGY_MISMATCH,
  309. "topology mismatch between PD and network");
  310. }
  311. #endif
  312. }
  313. }
  314. }
  315. BSetBFlag( EIBF_Topology );
  316. }
  317. DEFINEVP(GEDGEMBN);
  318. void MBNET :: DestroyTopology ( bool bDirectedOnly )
  319. {
  320. // Size up an array to hold pointers to all the edges
  321. VPGEDGEMBN vpgedge;
  322. int cItem = Grph().Chn().Count();
  323. vpgedge.resize(cItem);
  324. // Find all the arcs/edges
  325. int iItem = 0;
  326. GELEMLNK * pgelm;
  327. MODELENUM mdlenum( self );
  328. while ( pgelm = mdlenum.PlnkelNext() )
  329. {
  330. // Check that it's an edge
  331. if ( ! pgelm->BIsEType( GELEM::EGELM_EDGE ) )
  332. continue;
  333. // Check that it's a directed probabilistic arc
  334. if ( bDirectedOnly && pgelm->EType() != GEDGEMBN::ETPROB )
  335. continue;
  336. GEDGEMBN * pgedge;
  337. DynCastThrow( pgelm, pgedge );
  338. vpgedge[iItem++] = pgedge;
  339. }
  340. // Delete all the accumulated edges
  341. for ( int i = 0; i < iItem; )
  342. {
  343. GEDGEMBN * pgedge = vpgedge[i++];
  344. delete pgedge;
  345. }
  346. assert( Grph().Chn().Count() + iItem == cItem );
  347. BSetBFlag( EIBF_Topology, false );
  348. }
  349. //
  350. // Bind distributions to nodes. If they're already bound, exit.
  351. // If the node has a distribution already, leave it.
  352. //
  353. void MBNET :: BindDistributions ( bool bBind )
  354. {
  355. bool bDist = BFlag( EIBF_Distributions );
  356. if ( ! (bDist ^ bBind) )
  357. return;
  358. ITER itnd( self, GOBJMBN::EBNO_NODE );
  359. for ( ; *itnd ; itnd++ )
  360. {
  361. GNODEMBND * pgndd = dynamic_cast<GNODEMBND *>(*itnd);
  362. if ( pgndd == NULL )
  363. continue;
  364. if ( ! bBind )
  365. {
  366. pgndd->ClearDist();
  367. }
  368. else
  369. if ( ! pgndd->BHasDist() )
  370. {
  371. pgndd->SetDist( self );
  372. }
  373. }
  374. BSetBFlag( EIBF_Distributions, bBind );
  375. }
  376. void MBNET :: ClearNodeMarks ()
  377. {
  378. ITER itnd( self, GOBJMBN::EBNO_NODE );
  379. for ( ; *itnd ; itnd++ )
  380. {
  381. GNODEMBN * pgndbn = NULL;
  382. DynCastThrow( *itnd, pgndbn );
  383. pgndbn->IMark() = 0;
  384. }
  385. }
  386. void MBNET :: TopSortNodes ()
  387. {
  388. ClearNodeMarks();
  389. ITER itnd( self, GOBJMBN::EBNO_NODE );
  390. for ( ; *itnd ; itnd++ )
  391. {
  392. GNODEMBN * pgndbn = NULL;
  393. DynCastThrow( *itnd, pgndbn );
  394. pgndbn->Visit();
  395. }
  396. itnd.Reset();
  397. for ( ; *itnd ; itnd++ )
  398. {
  399. GNODEMBN * pgndbn = NULL;
  400. DynCastThrow( *itnd, pgndbn );
  401. pgndbn->ITopLevel() = pgndbn->IMark();
  402. }
  403. }
  404. void MBNET :: Dump ()
  405. {
  406. TopSortNodes();
  407. UINT iEntry = 0;
  408. for ( MPSYMTBL::iterator itsym = Mpsymtbl().begin();
  409. itsym != Mpsymtbl().end();
  410. itsym++ )
  411. {
  412. GOBJMBN * pbnobj = (*itsym).second.Pobj();
  413. if ( pbnobj->EType() != GOBJMBN::EBNO_NODE )
  414. continue; // It's not a node
  415. GNODEMBN * pgndbn;
  416. DynCastThrow(pbnobj,pgndbn);
  417. int iNode = INameIndex( pbnobj );
  418. assert( iNode == INameIndex( pbnobj->ZsrefName() ) );
  419. cout << "\n\tEntry "
  420. << iEntry++
  421. << ", inode "
  422. << iNode
  423. << " ";
  424. pgndbn->Dump();
  425. }
  426. }
  427. GOBJMBN_INFER_ENGINE * MBNET :: PInferEngine ()
  428. {
  429. GOBJMBN_INFER_ENGINE * pInferEng = NULL;
  430. for ( int iMod = _vpModifiers.size(); --iMod >= 0; )
  431. {
  432. MBNET_MODIFIER * pmodf = _vpModifiers[iMod];
  433. pInferEng = dynamic_cast<GOBJMBN_INFER_ENGINE *> ( pmodf );
  434. if ( pInferEng )
  435. break;
  436. }
  437. return pInferEng;
  438. }
  439. void MBNET :: ExpandCI ()
  440. {
  441. PushModifierStack( new GOBJMBN_MBNET_EXPANDER( self ) );
  442. }
  443. void MBNET :: UnexpandCI ()
  444. {
  445. MBNET_MODIFIER * pmodf = PModifierStackTop();
  446. if ( pmodf == NULL )
  447. return;
  448. if ( pmodf->EType() == GOBJMBN::EBNO_MBNET_EXPANDER )
  449. PopModifierStack();
  450. }
  451. // Return true if an edge is allowed between these two nodes
  452. bool MBNET :: BAcyclicEdge ( GNODEMBN * pgndSource, GNODEMBN * pgndSink )
  453. {
  454. ClearNodeMarks();
  455. pgndSink->Visit( false );
  456. return pgndSource->IMark() == 0;
  457. }