Source code of Windows XP (NT5)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

523 lines
12 KiB

//+-------------------------------------------------------------------------
//
// Microsoft Windows
//
// Copyright (C) Microsoft Corporation, 1997 - 1998
//
// File: mbnet.cpp
//
//--------------------------------------------------------------------------
//
// mbnet.cpp: Belief network model member functions
//
#include <basetsd.h>
#include "basics.h"
#include "algos.h"
#include "gmprop.h"
#include "gmobj.h"
#include "cliqset.h"
#include "clique.h"
#include "expand.h"
//
// Construct an empty belief network. The first free name-index slot is 0
// and _iInferEngID starts at 0; all other members are default-constructed.
//
MBNET :: MBNET ()
    : _inmFree( 0 ),
      _iInferEngID( 0 )
{
    // Nothing further to initialize.
}
//
// Tear down the network: delete every modifier still on the modifier
// stack, then empty the node-index-to-name table.
//
MBNET :: ~ MBNET ()
{
    // true => pop the whole stack, not just its top entry
    PopModifierStack( true );

    // Forget every name-index slot
    _vzsrNames.clear();
    _inmFree = 0;
}
//
// Clone this belief network from another MODEL, which must really be an
// MBNET (checked with DynCastThrow). The MODEL base is cloned first, then
// the node-index-to-name table is rebuilt with one slot per node, and
// finally the distribution map is cloned.
//
// The modifier stack (inference engines, expanders, etc.) is NOT cloned.
// This object must start with an empty modifier stack and name table,
// otherwise ASSERT_THROW raises EC_INVALID_CLONE.
//
void MBNET :: Clone ( MODEL & model )
{
    ASSERT_THROW( _vpModifiers.size() == 0 && _vzsrNames.size() == 0,
                  EC_INVALID_CLONE,
                  "cannot clone into non-empty structure" );

    MODEL::Clone( model );

    MBNET * pmbnetSrc;
    DynCastThrow( & model, pmbnetSrc );
    MBNET & mbnetSrc = *pmbnetSrc;

    {
        // Scoped so the enumerator is finished before the PD map is cloned
        MODELENUM mdlenum( mbnetSrc );
        for ( GELEMLNK * plnk = mdlenum.PlnkelNext();
              plnk != NULL;
              plnk = mdlenum.PlnkelNext() )
        {
            if ( ! plnk->BIsEType( GELEM::EGELM_NODE ) )
                continue;   // only nodes get a name slot here
            GOBJMBN * pgobjNode;
            DynCastThrow( plnk, pgobjNode );
            _vzsrNames.push_back( pgobjNode->ZsrefName() );
        }
        // Slots were appended contiguously; the next free one is past the end
        _inmFree = _vzsrNames.size();
    }

    _mppd.Clone( _mpsymtbl, mbnetSrc._mppd );

#ifdef _DEBUG
    // Cross-check arcs against distributions only if the source had a topology
    if ( mbnetSrc.BFlag( EIBF_Topology ) )
        VerifyTopology();
#endif
}
//
// Consistency check: iterate over the distributions, find the node each one
// belongs to, and ask that node (GNODEMBN::BMatchTopology) whether its
// topology matches the distribution's descriptor.
//
// Throws GMException:
//   EC_INV_PD             - descriptor is not of the form "p(X|...)", or X
//                           does not name a node
//   EC_TOPOLOGY_MISMATCH  - BMatchTopology rejected the node
//
void MBNET :: VerifyTopology ()
{
    for ( MPPD::iterator itpd = Mppd().begin();
          itpd != Mppd().end();
          ++itpd )
    {
        const VTKNPD & vtknpd = (*itpd).first;

        // Guarantee that the descriptor is of the form "p(X|...)"
        if ( vtknpd.size() < 2
            || vtknpd[0] != TKNPD(DTKN_PD)
            || ! vtknpd[1].BStr() )
            throw GMException( EC_INV_PD, "invalid token descriptor on PD");

        // Get the name of the node whose distribution this is
        SZC szc = vtknpd[1].Szc();
        assert( szc ) ;

        // Find that named thing in the graph
        GOBJMBN * pbnobj = Mpsymtbl().find( szc );
        assert( pbnobj && pbnobj->EType() == GOBJMBN::EBNO_NODE );

        // Guarantee that it's a node (a NULL lookup result also fails here)
        GNODEMBN * pgndbn = dynamic_cast<GNODEMBN *> (pbnobj);
        ASSERT_THROW( pgndbn, EC_INV_PD, "token on PD references non-node");

        // Verify the node's distribution
        if ( ! pgndbn->BMatchTopology( *this, vtknpd ) )
        {
            throw GMException( EC_TOPOLOGY_MISMATCH,
                               "topology mismatch between PD and network");
        }
    }
}
//
// Return the most recently pushed modifier without removing it, or NULL
// when the modifier stack is empty.
//
MBNET_MODIFIER * MBNET :: PModifierStackTop ()
{
    if ( _vpModifiers.size() == 0 )
        return NULL;
    return _vpModifiers[ _vpModifiers.size() - 1 ];
}
//
// Push a modifier onto the top of the modifier stack. The modifier's
// Create() is called first, and the modifier is only placed on the stack
// if Create() returns normally. Once on the stack, it is deleted by
// PopModifierStack().
//
void MBNET :: PushModifierStack ( MBNET_MODIFIER * pmodf )
{
    assert( pmodf != NULL );

    // Let the modifier set itself up before it becomes visible on the stack
    pmodf->Create();
    _vpModifiers.push_back( pmodf );
}
//
// Pop and delete modifiers from the top of the modifier stack: only the top
// one when bAll is false, every one when bAll is true. Popping an empty
// stack is a no-op.
//
// Each entry is unlinked from the stack *before* it is deleted, so the stack
// never holds a dangling pointer. This covers two cases:
//   - while a modifier's destructor runs, in case it consults the network;
//   - after a destructor throws. Previously the deleted pointer stayed on the
//     stack, and a later pop (e.g. from ~MBNET) would delete it again.
// The size is re-read each iteration rather than cached for the same reason.
//
void MBNET :: PopModifierStack ( bool bAll )
{
    while ( _vpModifiers.size() > 0 )
    {
        MBNET_MODIFIER * pmodf = _vpModifiers[ _vpModifiers.size() - 1 ];
        assert ( pmodf );

        // Unlink first, then destroy
        _vpModifiers.resize( _vpModifiers.size() - 1 );

        // NOTE: Deleting the object should be all that's necessary;
        // object's destructor should call its Destroy() function.
        delete pmodf;

        if ( ! bAll )
            break;
    }
}
//
// Find the named object stored in name-index slot inm.
// Returns NULL when inm is out of range (negative included) or the slot is
// empty (holds a default-constructed ZSREF). Otherwise returns the result of
// looking the slot's name up in the symbol table.
//
GOBJMBN * MBNET :: PgobjFindByIndex ( int inm )
{
    // Reject negative indices explicitly. The old code relied on the implicit
    // int -> unsigned conversion in "inm >= size()" to do this.
    if ( inm < 0 || (size_t) inm >= _vzsrNames.size() )
        return NULL;

    ZSREF zsMt;     // empty name == free slot
    if ( _vzsrNames[inm] == zsMt )
        return NULL;

    return Mpsymtbl().find( _vzsrNames[inm] );
}
//
// Return the name-index slot holding zsr, as reported by ifind() over the
// name table. The result is presumably negative when zsr is absent
// (AddElem asserts "< 0" before indexing a new object). TODO confirm
// against ifind().
//
int MBNET :: INameIndex ( ZSREF zsr )
{
    const int inm = ifind( _vzsrNames, zsr );
    return inm;
}
int MBNET :: INameIndex ( const GOBJMBN * pgobj )
{
return INameIndex( pgobj->ZsrefName() );
}
//
// Record pgobj's name in the name table and return the slot index used.
// If _inmFree designates an existing slot, that slot is filled. Otherwise
// a new slot is appended. Afterwards _inmFree designates the next empty
// slot at or after the one used, or the table size if there is none.
//
int MBNET :: CreateNameIndex ( const GOBJMBN * pgobj )
{
    if ( ! ( _inmFree < _vzsrNames.size() ) )
    {
        // No free slot: append, and note that none remains
        const int indNew = _vzsrNames.size();
        _vzsrNames.push_back( pgobj->ZsrefName() );
        _inmFree = _vzsrNames.size();
        return indNew;
    }

    // Fill the free slot, then scan forward (starting at that same slot)
    // for the next empty one, stopping at the end of the table.
    const int indUsed = _inmFree;
    _vzsrNames[indUsed] = pgobj->ZsrefName();

    ZSREF zsMt;
    while ( _inmFree < _vzsrNames.size() && ! ( zsMt == _vzsrNames[_inmFree] ) )
        _inmFree++;

    return indUsed;
}
//
// Free name-index slot inm: reset it to the empty ZSREF, and lower _inmFree
// to inm if this slot precedes it, so CreateNameIndex fills the lowest hole.
// Raises EC_INTERNAL_ERROR via ASSERT_THROW if inm is not a valid slot
// index (negative values included).
//
void MBNET :: DeleteNameIndex ( int inm )
{
    // Check the sign explicitly. The old code relied on the int -> unsigned
    // conversion in "inm < size()" to reject negative indices.
    ASSERT_THROW( inm >= 0 && (size_t) inm < _vzsrNames.size(),
                  EC_INTERNAL_ERROR,
                  "MBNET name index out of range" );

    _vzsrNames[inm] = ZSREF();
    if ( inm < _inmFree )
        _inmFree = inm;
}
//
// Free the name-index slot of an object, if it has one. An object that is
// not in the name table is ignored.
//
void MBNET :: DeleteNameIndex ( const GOBJMBN * pgobj )
{
    const int inm = INameIndex( pgobj );
    if ( inm < 0 )
        return;     // never indexed
    DeleteNameIndex( inm );
}
//
// Add an object to the graph and symbol table. A non-empty name is passed
// to MODEL::AddElem together with the object, and the object also receives
// a name-index slot. An object whose name is NULL or "" is added through
// MODEL::AddElem without a name and gets no slot.
//
void MBNET :: AddElem ( SZC szcName, GOBJMBN * pgelm )
{
    const bool bUnnamed = szcName == NULL || szcName[0] == '\0';
    if ( bUnnamed )
    {
        MODEL::AddElem( pgelm );
        return;
    }

    MODEL::AddElem( szcName, pgelm );
    assert( INameIndex( pgelm ) < 0 );  // must not already have a slot
    CreateNameIndex( pgelm );
}
//
// Remove an object from the network. Its name-index slot is released first,
// because that lookup reads the object's name. Then MODEL::DeleteElem is
// called, which may destroy the object.
//
void MBNET :: DeleteElem ( GOBJMBN * pgobj )
{
    // Order matters: the object must still be intact for the name lookup
    DeleteNameIndex( pgobj );
    MODEL::DeleteElem( pgobj );
}
/*
Iterator has moved into the MODEL class... I've left the code here
in case MBNET needs its own iterator. (Max, 05/12/97)
MBNET::ITER :: ITER ( MBNET & bnet, GOBJMBN::EBNOBJ eType )
: _eType(eType),
_bnet(bnet)
{
Reset();
}
void MBNET::ITER :: Reset ()
{
_pCurrent = NULL;
_itsym = _bnet.Mpsymtbl().begin();
BNext();
}
bool MBNET::ITER :: BNext ()
{
while ( _itsym != _bnet.Mpsymtbl().end() )
{
_pCurrent = (*_itsym).second.Pobj();
_zsrCurrent = (*_itsym).first;
_itsym++;
if ( _pCurrent->EType() == _eType )
return true;
}
_pCurrent = NULL;
return false;
}
*/
void MBNET :: CreateTopology ()
{
if ( BFlag( EIBF_Topology ) )
return;
// Walk the map of distributions. For each one, extract the node
// name and find it. Then add arcs for each parent.
#ifdef _DEBUG
UINT iCycleMax = 2;
#else
UINT iCycleMax = 1;
#endif
UINT iIter = 0;
for ( UINT iCycle = 0 ; iCycle < iCycleMax ; iCycle++ )
{
for ( MPPD::iterator itpd = Mppd().begin();
itpd != Mppd().end();
itpd++, iIter++ )
{
const VTKNPD & vtknpd = (*itpd).first;
const BNDIST * pbndist = (*itpd).second;
// Guarantee that the descriptor is of the form "p(X|...)"
if ( vtknpd.size() < 2
|| vtknpd[0] != TKNPD(DTKN_PD)
|| ! vtknpd[1].BStr() )
throw GMException( EC_INV_PD, "invalid token descriptor on PD");
// Get the name of the node whose distribution this is
SZC szcChild = vtknpd[1].Szc();
assert( szcChild ) ;
// Find that named thing in the graph
GOBJMBN * pbnobjChild = Mpsymtbl().find( szcChild );
assert( pbnobjChild && pbnobjChild->EType() == GOBJMBN::EBNO_NODE );
// Guarantee that it's a node
GNODEMBN * pgndbnChild = dynamic_cast<GNODEMBN *> (pbnobjChild);
ASSERT_THROW( pgndbnChild, EC_INV_PD, "token on PD references non-node");
UINT cParents = 0;
UINT cChildren = pgndbnChild->CChild();
for ( int i = 2; i < vtknpd.size(); i++ )
{
if ( ! vtknpd[i].BStr() )
continue;
SZC szcParent = vtknpd[i].Szc();
assert( szcParent) ;
GOBJMBN * pbnobjParent = Mpsymtbl().find( szcParent );
assert( pbnobjParent && pbnobjParent->EType() == GOBJMBN::EBNO_NODE );
GNODEMBN * pgndbnParent = (GNODEMBN *) pbnobjParent;
UINT cPrChildren = pgndbnParent->CChild();
if ( iCycle == 0 )
{
AddElem( new GEDGEMBN_PROB( pgndbnParent, pgndbnChild ) );
}
cParents++;
if ( iCycle == 0 )
{
UINT cChNew = pgndbnChild->CChild();
UINT cPrNew = pgndbnChild->CParent();
UINT cPrChNew = pgndbnParent->CChild();
assert( cPrChNew = cPrChildren + 1 );
assert( cChildren == cChNew );
}
}
if ( iCycle )
{
UINT cPrNew = pgndbnChild->CParent();
assert( cParents == cPrNew );
}
if ( iCycle == 0 )
{
#ifdef _DEBUG
if ( ! pgndbnChild->BMatchTopology( *this, vtknpd ) )
{
throw GMException( EC_TOPOLOGY_MISMATCH,
"topology mismatch between PD and network");
}
#endif
}
}
}
BSetBFlag( EIBF_Topology );
}
DEFINEVP(GEDGEMBN);
void MBNET :: DestroyTopology ( bool bDirectedOnly )
{
// Size up an array to hold pointers to all the edges
VPGEDGEMBN vpgedge;
int cItem = Grph().Chn().Count();
vpgedge.resize(cItem);
// Find all the arcs/edges
int iItem = 0;
GELEMLNK * pgelm;
MODELENUM mdlenum( self );
while ( pgelm = mdlenum.PlnkelNext() )
{
// Check that it's an edge
if ( ! pgelm->BIsEType( GELEM::EGELM_EDGE ) )
continue;
// Check that it's a directed probabilistic arc
if ( bDirectedOnly && pgelm->EType() != GEDGEMBN::ETPROB )
continue;
GEDGEMBN * pgedge;
DynCastThrow( pgelm, pgedge );
vpgedge[iItem++] = pgedge;
}
// Delete all the accumulated edges
for ( int i = 0; i < iItem; )
{
GEDGEMBN * pgedge = vpgedge[i++];
delete pgedge;
}
assert( Grph().Chn().Count() + iItem == cItem );
BSetBFlag( EIBF_Topology, false );
}
//
// Bind (bBind == true) or unbind (bBind == false) distributions on every
// node that is a GNODEMBND. Returns immediately if the EIBF_Distributions
// flag is already in the requested state. When binding, a node that
// already has a distribution keeps it. Finally, the flag is set to bBind.
//
void MBNET :: BindDistributions ( bool bBind )
{
    const bool bDist = BFlag( EIBF_Distributions );
    if ( bDist == bBind )
        return;     // already bound / unbound as requested

    ITER itnd( self, GOBJMBN::EBNO_NODE );
    while ( *itnd )
    {
        GNODEMBND * pgndd = dynamic_cast<GNODEMBND *>(*itnd);
        if ( pgndd != NULL )
        {
            if ( bBind )
            {
                if ( ! pgndd->BHasDist() )
                    pgndd->SetDist( self );
            }
            else
            {
                pgndd->ClearDist();
            }
        }
        itnd++;
    }
    BSetBFlag( EIBF_Distributions, bBind );
}
//
// Reset the mark (IMark()) of every node in the network to 0.
//
void MBNET :: ClearNodeMarks ()
{
    for ( ITER itnd( self, GOBJMBN::EBNO_NODE ); *itnd; itnd++ )
    {
        GNODEMBN * pgndbn = NULL;
        DynCastThrow( *itnd, pgndbn );
        pgndbn->IMark() = 0;
    }
}
//
// Assign each node its topological level. All marks are cleared, Visit() is
// called on every node, and each node's resulting IMark() is copied into its
// ITopLevel(). The meaning of the mark is defined by GNODEMBN::Visit.
//
void MBNET :: TopSortNodes ()
{
    ClearNodeMarks();

    ITER itnd( self, GOBJMBN::EBNO_NODE );

    // Pass 1: let every node compute its mark
    while ( *itnd )
    {
        GNODEMBN * pgndbn = NULL;
        DynCastThrow( *itnd, pgndbn );
        pgndbn->Visit();
        itnd++;
    }

    // Pass 2: publish the marks as topological levels
    for ( itnd.Reset(); *itnd; itnd++ )
    {
        GNODEMBN * pgndbn = NULL;
        DynCastThrow( *itnd, pgndbn );
        pgndbn->ITopLevel() = pgndbn->IMark();
    }
}
//
// Write every node to cout in symbol-table order. Each node is prefixed with
// a running entry number and its name index, then dumped via
// GNODEMBN::Dump(). TopSortNodes() runs first, presumably so the node dumps
// show up-to-date topological levels.
//
void MBNET :: Dump ()
{
    TopSortNodes();

    UINT iEntry = 0;
    MPSYMTBL::iterator itsym = Mpsymtbl().begin();
    for ( ; itsym != Mpsymtbl().end(); itsym++ )
    {
        GOBJMBN * pbnobj = (*itsym).second.Pobj();
        if ( pbnobj->EType() == GOBJMBN::EBNO_NODE )
        {
            GNODEMBN * pgndbn;
            DynCastThrow( pbnobj, pgndbn );

            const int iNode = INameIndex( pbnobj );
            // Lookup by object and by name must agree
            assert( iNode == INameIndex( pbnobj->ZsrefName() ) );

            cout << "\n\tEntry " << iEntry << ", inode " << iNode << " ";
            iEntry++;
            pgndbn->Dump();
        }
    }
}
//
// Return the inference engine nearest the top of the modifier stack, or
// NULL if no modifier on the stack is a GOBJMBN_INFER_ENGINE.
//
GOBJMBN_INFER_ENGINE * MBNET :: PInferEngine ()
{
    int iMod = _vpModifiers.size();
    while ( iMod > 0 )
    {
        --iMod;
        GOBJMBN_INFER_ENGINE * pInferEng =
            dynamic_cast<GOBJMBN_INFER_ENGINE *> ( _vpModifiers[iMod] );
        if ( pInferEng != NULL )
            return pInferEng;
    }
    return NULL;
}
void MBNET :: ExpandCI ()
{
PushModifierStack( new GOBJMBN_MBNET_EXPANDER( self ) );
}
void MBNET :: UnexpandCI ()
{
MBNET_MODIFIER * pmodf = PModifierStackTop();
if ( pmodf == NULL )
return;
if ( pmodf->EType() == GOBJMBN::EBNO_MBNET_EXPANDER )
PopModifierStack();
}
//
// Return true if an edge pgndSource -> pgndSink is allowed. All node marks
// are cleared, pgndSink->Visit(false) is run, and the result reports whether
// pgndSource was left unmarked.
// Assumes Visit(false) marks the nodes reachable from the sink, so an
// unmarked source means the new edge would not close a cycle. TODO confirm
// against GNODEMBN::Visit.
//
bool MBNET :: BAcyclicEdge ( GNODEMBN * pgndSource, GNODEMBN * pgndSink )
{
    ClearNodeMarks();
    pgndSink->Visit( false );
    const bool bSourceUnmarked = ( pgndSource->IMark() == 0 );
    return bSourceUnmarked;
}