38 #ifndef VIGRA_HIERARCHICAL_CLUSTERING_HXX
39 #define VIGRA_HIERARCHICAL_CLUSTERING_HXX
48 #include "priority_queue.hxx"
49 #include "metrics.hxx"
53 namespace cluster_operators{
// NOTE(review): this listing is fragmentary (source line numbers are embedded in
// the text and many lines are missing); the comments below describe only what
// the visible lines show.
//
// Template parameters of an edge-weight / node-feature cluster operator. The
// class head and the remaining parameters (MERGE_GRAPH, EDGE_SIZE_MAP,
// NODE_SIZE_MAP, MIN_WEIGHT_MAP, ...) are not visible here; they are referenced
// by the typedefs, the constructor and the members further down.
58 class EDGE_INDICATOR_MAP,
60 class NODE_FEATURE_MAP,
// Edge weights share the value type of the edge-indicator map.
77 typedef typename EDGE_INDICATOR_MAP::Value ValueType;
78 typedef ValueType WeightType;
// Merge-graph types, and the base-graph types they are mapped onto.
79 typedef MERGE_GRAPH MergeGraph;
80 typedef typename MergeGraph::Graph Graph;
81 typedef typename Graph::Edge BaseGraphEdge;
82 typedef typename Graph::Node BaseGraphNode;
83 typedef typename MergeGraph::Edge Edge;
84 typedef typename MergeGraph::Node Node;
85 typedef typename MergeGraph::EdgeIt EdgeIt;
86 typedef typename MergeGraph::NodeIt NodeIt;
87 typedef typename MergeGraph::IncEdgeIt IncEdgeIt;
88 typedef typename MergeGraph::index_type index_type;
// Helpers that translate merge-graph items into base-graph items. The property
// maps are indexed with base-graph items.
89 typedef MergeGraphItemHelper<MergeGraph,Edge> EdgeHelper;
90 typedef MergeGraphItemHelper<MergeGraph,Node> NodeHelper;
// Types returned by the maps' operator[]. The merge callbacks further down
// modify map entries through them, so they are presumably writable references
// into the maps — confirm.
93 typedef typename EDGE_INDICATOR_MAP::Reference EdgeIndicatorReference;
94 typedef typename NODE_FEATURE_MAP::Reference NodeFeatureReference;
// Cluster-operator constructor. Its opening line, including the mergeGraph
// parameter, is not visible in this listing.
//   edgeIndicatorMap  per-edge indicator; averaged on edge merges, part of the weight
//   edgeSizeMap       per-edge size, used as the averaging weight for edges
//   nodeFeatureMap    per-node feature, compared pairwise with metric_
//   nodeSizeMap       per-node size; averaging weight for nodes and Ward-factor input
//   minWeightEdgeMap  receives the current weight of every edge
//   beta              blend: (1-beta)*edgeIndicator + beta*nodeDistance (times the Ward factor)
//   metricType        presumably selects metric_ (its initializer is not visible) — confirm
//   wardness          1.0 (default) applies the size-dependent Ward factor fully, 0.0 disables it
// The maps are taken and stored by value. They are presumably lightweight views
// that share storage with the caller's arrays — confirm.
98 EDGE_INDICATOR_MAP edgeIndicatorMap,
99 EDGE_SIZE_MAP edgeSizeMap,
100 NODE_FEATURE_MAP nodeFeatureMap,
101 NODE_SIZE_MAP nodeSizeMap,
102 MIN_WEIGHT_MAP minWeightEdgeMap,
103 const ValueType beta,
104 const metrics::MetricType metricType,
105 const ValueType wardness=1.0
107 : mergeGraph_(mergeGraph),
108 edgeIndicatorMap_(edgeIndicatorMap),
109 edgeSizeMap_(edgeSizeMap),
110 nodeFeatureMap_(nodeFeatureMap),
111 nodeSizeMap_(nodeSizeMap),
112 minWeightEdgeMap_(minWeightEdgeMap),
// One priority-queue slot per possible edge id (0..maxEdgeId()).
113 pq_(mergeGraph.maxEdgeId()+1),
118 typedef typename MergeGraph::MergeNodeCallBackType MergeNodeCallBackType;
119 typedef typename MergeGraph::MergeEdgeCallBackType MergeEdgeCallBackType;
120 typedef typename MergeGraph::EraseEdgeCallBackType EraseEdgeCallBackType;
// Bind this object's member functions as merge-graph callbacks and register
// them. The merge graph then notifies this operator about node merges, edge
// merges and edge erasure (presumably from within contractEdge — confirm).
123 MergeNodeCallBackType cbMn(MergeNodeCallBackType:: template from_method<SelfType,&SelfType::mergeNodes>(
this));
124 MergeEdgeCallBackType cbMe(MergeEdgeCallBackType:: template from_method<SelfType,&SelfType::mergeEdges>(
this));
125 EraseEdgeCallBackType cbEe(EraseEdgeCallBackType:: template from_method<SelfType,&SelfType::eraseEdge>(
this));
127 mergeGraph_.registerMergeNodeCallBack(cbMn);
128 mergeGraph_.registerMergeEdgeCallBack(cbMe);
129 mergeGraph_.registerEraseEdgeCallBack(cbEe);
// Seed the queue: every edge enters with its initial weight, and that weight is
// also recorded in minWeightEdgeMap.
132 for(EdgeIt e(mergeGraph);e!=lemon::INVALID;++e){
133 const Edge
edge = *e;
134 const BaseGraphEdge graphEdge=EdgeHelper::itemToGraphItem(mergeGraph_,edge);
135 const index_type edgeId = mergeGraph_.id(edge);
136 const ValueType currentWeight = this->getEdgeWeight(edge);
137 pq_.
push(edgeId,currentWeight);
138 minWeightEdgeMap_[graphEdge]=currentWeight;
// Edge-merge callback body. The signature is not visible; this is presumably
// mergeEdges(a, b), registered in the constructor. The merged values are
// written to a's map entries.
// Both indicators are scaled by their edge sizes. The line that is not visible
// here (151) presumably adds vb into va. After va is divided by the combined
// size it is then the size-weighted mean. a's size becomes size(a)+size(b).
145 const BaseGraphEdge aa=EdgeHelper::itemToGraphItem(mergeGraph_,a);
146 const BaseGraphEdge bb=EdgeHelper::itemToGraphItem(mergeGraph_,b);
147 EdgeIndicatorReference va=edgeIndicatorMap_[aa];
148 EdgeIndicatorReference vb=edgeIndicatorMap_[bb];
149 va*=edgeSizeMap_[aa];
150 vb*=edgeSizeMap_[bb];
152 edgeSizeMap_[aa]+=edgeSizeMap_[bb];
153 va/=(edgeSizeMap_[aa]);
// b's size was not changed above, so this restores b's original indicator.
154 vb/=edgeSizeMap_[bb];
// Node-merge callback body. The signature is not visible; this is presumably
// mergeNodes(a, b), registered in the constructor. The merged values are
// written to a's map entries.
// Both features are scaled by their node sizes. The line that is not visible
// here (167) presumably adds vb into va. After va is divided by the combined
// size it is then the size-weighted mean feature. a's size becomes
// size(a)+size(b).
161 const BaseGraphNode aa=NodeHelper::itemToGraphItem(mergeGraph_,a);
162 const BaseGraphNode bb=NodeHelper::itemToGraphItem(mergeGraph_,b);
163 NodeFeatureReference va=nodeFeatureMap_[aa];
164 NodeFeatureReference vb=nodeFeatureMap_[bb];
165 va*=nodeSizeMap_[aa];
166 vb*=nodeSizeMap_[bb];
168 nodeSizeMap_[aa]+=nodeSizeMap_[bb];
169 va/=(nodeSizeMap_[aa]);
// b's size was not changed above, so this restores b's original feature.
170 vb/=nodeSizeMap_[bb];
// Edge-erase callback body. The signature is not visible; this is presumably
// eraseEdge(edge), registered in the constructor. It takes the merge graph's
// inactiveEdgesNode(edge), presumably the node produced by contracting this
// edge. It then recomputes the weight of every edge incident to that node,
// because that node's size and feature were changed by the node-merge callback.
182 const Node newNode = mergeGraph_.inactiveEdgesNode(edge);
187 for (IncEdgeIt e(mergeGraph_,newNode);e!=lemon::INVALID;++e){
190 const Edge incEdge(*e);
193 const BaseGraphEdge incGraphEdge = EdgeHelper::itemToGraphItem(mergeGraph_,incEdge);
198 const ValueType newWeight = getEdgeWeight(incEdge);
// Re-queue with the new weight. Pushing an id that is already queued presumably
// updates its priority (check priority_queue.hxx). NOTE(review): the constructor
// uses mergeGraph_.id(edge) instead of edge.id(); presumably equivalent.
202 pq_.
push(incEdge.id(),newWeight);
206 minWeightEdgeMap_[incGraphEdge]=newWeight;
// contractionEdge() body (head not visible): returns the edge at the top of the
// queue that still exists in the merge graph. Given the name minLabel, the
// queue presumably yields the smallest weight first. Ids that no longer exist
// in the merge graph are stale entries. The line that is not visible here (216)
// presumably pops them; otherwise this loop could not terminate — confirm.
214 index_type minLabel = pq_.
top();
215 while(mergeGraph_.hasEdgeId(minLabel)==
false){
217 minLabel = pq_.
top();
219 return Edge(minLabel);
// contractionWeight() body (head and return line not visible). It skips stale
// queue entries the same way contractionEdge does. The value it returns is
// presumably the top priority, i.e. the weight of the edge that contractionEdge
// returns. The clustering loop calls it right after contractionEdge and before
// contracting.
224 index_type minLabel = pq_.
top();
225 while(mergeGraph_.hasEdgeId(minLabel)==
false){
227 minLabel = pq_.
top();
// Weight of merge-graph edge e between its nodes u and v:
//   H       = 1 / (1/log(size(u)) + 1/log(size(v)))
//   wardFac = wardness*H + (1 - wardness)
//   weight  = ((1-beta)*edgeIndicator(e) + beta*metric(feature(u), feature(v))) * wardFac
// All map lookups use the base-graph items of e, u and v. The return statement
// is not visible; it presumably returns totalWeight.
239 ValueType getEdgeWeight(
const Edge & e){
241 const Node u = mergeGraph_.u(e);
242 const Node v = mergeGraph_.v(e);
244 const BaseGraphEdge ee=EdgeHelper::itemToGraphItem(mergeGraph_,e);
245 const BaseGraphNode uu=NodeHelper::itemToGraphItem(mergeGraph_,u);
246 const BaseGraphNode vv=NodeHelper::itemToGraphItem(mergeGraph_,v);
// NOTE(review): std::log(1) == 0, so a node of size 1 makes 1.0/log(size) a
// division by zero. Under IEEE floating point that term becomes +inf and
// wardFacRaw collapses to 0, so with wardness == 1 the weight is 0. Confirm this
// is the intended treatment of singleton nodes.
248 const ValueType wardFacRaw = 1.0 / ( 1.0/
std::log(nodeSizeMap_[uu]) + 1.0/
std::log(nodeSizeMap_[vv]) );
249 const ValueType wardFac = (wardFacRaw*wardness_) + (1.0-wardness_);
251 const ValueType fromEdgeIndicator = edgeIndicatorMap_[ee];
252 ValueType fromNodeDist = metric_(nodeFeatureMap_[uu],nodeFeatureMap_[vv]);
253 const ValueType totalWeight = ((1.0-beta_)*fromEdgeIndicator + beta_*fromNodeDist)*wardFac;
// Non-owning reference: the merge graph must outlive this operator.
258 MergeGraph & mergeGraph_;
// Property maps, stored by value (copied in the constructor).
259 EDGE_INDICATOR_MAP edgeIndicatorMap_;
260 EDGE_SIZE_MAP edgeSizeMap_;
261 NODE_FEATURE_MAP nodeFeatureMap_;
262 NODE_SIZE_MAP nodeSizeMap_;
263 MIN_WEIGHT_MAP minWeightEdgeMap_;
// NOTE(review): the metric is instantiated for float, while weights use
// ValueType (EDGE_INDICATOR_MAP::Value). If ValueType is double, the node
// distance is presumably computed at float precision — confirm this is intended.
// The members pq_, beta_ and wardness_ are declared on lines not visible here.
268 metrics::Metric<float> metric_;
// Clustering driver parameterized by a cluster operator. It repeatedly asks the
// operator for the next edge to contract (contractionEdge / contractionWeight)
// and can optionally record the merge sequence. The class head is not visible
// in this listing; per the include guard it is presumably HierarchicalClustering
// — confirm.
275 template<
class CLUSTER_OPERATOR>
// Types taken from the operator and its merge graph.
279 typedef CLUSTER_OPERATOR ClusterOperator;
280 typedef typename ClusterOperator::MergeGraph MergeGraph;
281 typedef typename MergeGraph::Graph Graph;
282 typedef typename Graph::Edge BaseGraphEdge;
283 typedef typename Graph::Node BaseGraphNode;
284 typedef typename MergeGraph::Edge Edge;
285 typedef typename MergeGraph::Node Node;
// Weights recorded in the merge tree use the operator's WeightType.
286 typedef typename CLUSTER_OPERATOR::WeightType ValueType;
287 typedef typename MergeGraph::index_type MergeGraphIndexType;
// Parameter constructor arguments (the struct head is not visible):
//   nodeNumStopCond  clustering continues only while nodeNum() > this (and edges remain)
//   buildMergeTree   record one MergeItem per contraction
//   verbose          print the node count whenever it is a multiple of 10
291 const size_t nodeNumStopCond = 1,
292 const bool buildMergeTree =
true,
293 const bool verbose =
false
295 : nodeNumStopCond_ (nodeNumStopCond),
296 buildMergeTreeEncoding_(buildMergeTree),
// The verbose flag is presumably stored in verbose_ (declaration not visible).
299 size_t nodeNumStopCond_;
300 bool buildMergeTreeEncoding_;
// MergeItem constructor arguments (the struct head and the w parameter are not
// visible). The clustering loop writes one record per contraction:
//   a_ : id of the surviving cluster before the merge (leaf node id or timestamp)
//   b_ : id of the absorbed cluster
//   r_ : timestamp assigned to the merged cluster
//   w_ : contraction weight (member declaration not visible)
// Ids <= graph().maxNodeId() denote base-graph leaf nodes.
306 const MergeGraphIndexType a,
307 const MergeGraphIndexType b,
308 const MergeGraphIndexType r,
311 a_(a),b_(b),r_(r),w_(w){
313 MergeGraphIndexType a_;
314 MergeGraphIndexType b_;
315 MergeGraphIndexType r_;
// The recorded merges, in contraction order.
319 typedef std::vector<MergeItem> MergeTreeEncoding;
// Driver constructor (opening line not visible). It keeps references to the
// operator and its merge graph. Leaf clusters are identified by their base-graph
// node id. Merged clusters get ids ("timestamps") starting at
// graph().maxNodeId()+1.
323 ClusterOperator & clusterOperator,
324 const Parameter & parameter = Parameter()
327 clusterOperator_(clusterOperator),
// mergeGraph_ is initialized on lines not visible here. It must be initialized
// before graph_, which reads it.
330 graph_(mergeGraph_.
graph()),
// Id handed to the first merged cluster.
331 timestamp_(graph_.maxNodeId()+1),
332 toTimeStamp_(graph_.maxNodeId()+1),
333 timeStampIndexToMergeIndex_(graph_.maxNodeId()+1),
334 mergeTreeEndcoding_()
// Each contraction presumably removes one node, so at most nodeNum()-1 merges
// occur; nodeNum()*2 is a generous bound.
338 mergeTreeEndcoding_.reserve(graph_.nodeNum()*2);
// Initially every node is its own cluster, identified by its own node id.
// NOTE(review): toTimeStamp_ is sized from graph_.maxNodeId() but this loop is
// bounded by mergeGraph_.maxNodeId(). It presumes the two are equal at
// construction.
340 for(MergeGraphIndexType nodeId=0;nodeId<=mergeGraph_.maxNodeId();++nodeId){
341 toTimeStamp_[nodeId]=nodeId;
// Main clustering loop (the function head is not visible). It contracts the
// operator's chosen edge until at most nodeNumStopCond_ nodes remain or no
// edges are left.
349 while(mergeGraph_.nodeNum()>param_.nodeNumStopCond_ && mergeGraph_.edgeNum()>0){
352 const Edge edgeToRemove = clusterOperator_.contractionEdge();
353 if(param_.buildMergeTreeEncoding_){
// Capture the endpoint ids and the weight before contracting. Contraction
// changes the graph and, through the registered callbacks, the operator's state.
354 const MergeGraphIndexType uid = mergeGraph_.id(mergeGraph_.u(edgeToRemove));
355 const MergeGraphIndexType vid = mergeGraph_.id(mergeGraph_.v(edgeToRemove));
356 const ValueType w = clusterOperator_.contractionWeight();
358 mergeGraph_.contractEdge( edgeToRemove);
// Assumes exactly one endpoint id survives the contraction.
359 const MergeGraphIndexType aliveNodeId = mergeGraph_.hasNodeId(uid) ? uid : vid;
360 const MergeGraphIndexType deadNodeId = aliveNodeId==vid ? uid : vid;
// Record the merge. The surviving node is identified by the new timestamp from
// now on.
361 timeStampIndexToMergeIndex_[timeStampToIndex(timestamp_)]=mergeTreeEndcoding_.size();
362 mergeTreeEndcoding_.push_back(MergeItem( toTimeStamp_[aliveNodeId],toTimeStamp_[deadNodeId],timestamp_,w));
363 toTimeStamp_[aliveNodeId]=timestamp_;
// Lines 364-367 are not visible. They presumably advance timestamp_ and open the
// branch that contracts without recording — confirm.
368 mergeGraph_.contractEdge( edgeToRemove );
// Progress output whenever the node count is a multiple of 10.
370 if(param_.verbose_ && mergeGraph_.nodeNum()%10==0)
371 std::cout<<
"\rNodes: "<<std::setw(10)<<mergeGraph_.nodeNum()<<std::flush;
// Body of the merge-tree accessor (head not visible). Returns the recorded
// merges in contraction order. The record is filled only when buildMergeTree
// was set.
380 return mergeTreeEndcoding_;
// Writes the ids of the base-graph leaf nodes below merge-tree node treeNodeId
// through the output iterator begin. It returns a size_t, presumably the number
// of leaves written (the return statements are not visible). It relies on the
// recorded merge tree, so buildMergeTree must have been set.
384 template<
class OUT_ITER>
385 size_t leafNodeIds(
const MergeGraphIndexType treeNodeId, OUT_ITER begin)
const{
// Ids up to maxNodeId() are leaves themselves (branch body not visible).
386 if(treeNodeId<=graph_.maxNodeId()){
// Queue-based (breadth-first) walk over the merge tree, starting at treeNodeId.
393 std::queue<MergeGraphIndexType> queue;
394 queue.push(treeNodeId);
396 while(!queue.empty()){
398 const MergeGraphIndexType
id = queue.front();
// Find the MergeItem that created this merged cluster and visit both children.
400 const MergeGraphIndexType mergeIndex = timeStampToMergeIndex(
id);
401 const MergeGraphIndexType ab[]= { mergeTreeEndcoding_[mergeIndex].a_, mergeTreeEndcoding_[mergeIndex].b_};
403 for(
size_t i=0;i<2;++i){
// The child is a leaf. Its handling is not visible; presumably it is written to
// the output, while non-leaf children are enqueued.
404 if(ab[i]<=graph_.maxNodeId()){
// Forwards to the merge graph's reprNodeId(id), presumably the id of the node
// that currently represents id's cluster.
// NOTE(review): the top-level const on the by-value return type has no effect.
429 const MergeGraphIndexType
reprNodeId(
const MergeGraphIndexType
id)
const{
430 return mergeGraph_.reprNodeId(
id);
// Maps a merge timestamp to its slot in timeStampIndexToMergeIndex_. The first
// timestamp, maxNodeId()+1, maps to 1, so the clustering loop never writes
// slot 0. The table has maxNodeId()+1 slots.
434 MergeGraphIndexType timeStampToIndex(
const MergeGraphIndexType timestamp)
const{
435 return timestamp- graph_.maxNodeId();
// Index into mergeTreeEndcoding_ of the MergeItem that created the merged
// cluster with the given timestamp. It is meaningful only for timestamps that
// were assigned while the merge tree was being recorded.
439 MergeGraphIndexType timeStampToMergeIndex(
const MergeGraphIndexType timestamp)
const{
440 return timeStampIndexToMergeIndex_[timeStampToIndex(timestamp)];
// Non-owning references. The operator and its merge graph must outlive this
// object.
443 ClusterOperator & clusterOperator_;
// Must stay declared before graph_. graph_ is initialized from mergeGraph_, and
// members are initialized in declaration order.
445 MergeGraph & mergeGraph_;
446 const Graph & graph_;
// Id given to the next merged cluster; it is advanced on lines not visible here
// — confirm.
451 MergeGraphIndexType timestamp_;
// For each merge-graph node id, the current id (leaf node id or timestamp) of
// the cluster it represents.
452 std::vector<MergeGraphIndexType> toTimeStamp_;
// Maps slot (timestamp - maxNodeId()) to an index into mergeTreeEndcoding_.
453 std::vector<MergeGraphIndexType> timeStampIndexToMergeIndex_;
// Recorded merges. NOTE: "Endcoding" is a misspelling of "Encoding"; renaming
// it would touch every use.
455 MergeTreeEncoding mergeTreeEndcoding_;
463 #endif // VIGRA_HIERARCHICAL_CLUSTERING_HXX