[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_visitors.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef VIGRA_RF3_VISITORS_HXX
36 #define VIGRA_RF3_VISITORS_HXX
37 
38 #include <vector>
39 #include <memory>
40 #include "../multi_array.hxx"
41 #include "../multi_shape.hxx"
42 #include <typeinfo>
43 
44 
45 namespace vigra
46 {
47 namespace rf3
48 {
49 
/**
 * @brief Base class from which all random forest visitors derive.
 *
 * @details
 * Due to the parallel training, we cannot simply use a single visitor for all trees.
 * Instead, each tree gets a copy of the original visitor.
 *
 * The random forest training with visitors looks as follows:
 * - Do the random forest preprocessing (translate labels to 0, 1, 2, ...).
 * - Call visit_before_training() on the original visitor.
 * - For each tree:
 *   - Copy the original visitor and give the copy to the tree.
 *   - Do the preprocessing (create the bootstrap sample, assign weights to the data points, ...).
 *   - Call visit_before_tree() on the visitor copy.
 *   - Do the node splitting until the tree is fully trained
 *     (visit_after_split() is the per-split hook).
 *   - Call visit_after_tree() on the visitor copy.
 * - Call visit_after_training() (which gets a vector with pointers to the visitor copies)
 *   on the original visitor.
 *
 * Every hook in this class is a no-op. The hooks are not virtual: a concrete visitor
 * simply redefines (hides) the hooks it needs.
 */
class RFVisitorBase
{
public:

    /**
     * @brief Construct an active visitor.
     */
    RFVisitorBase()
        :
        active_(true)
    {}

    /**
     * @brief Do something before training starts.
     */
    void visit_before_training()
    {}

    /**
     * @brief Do something after all trees have been learned.
     *
     * @param v vector with pointers to the visitor copies
     * @param rf the trained random forest
     */
    template <typename VISITORS, typename RF, typename FEATURES, typename LABELS>
    void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
    {}

    /**
     * @brief Do something before a tree has been learned.
     *
     * @param weights the actual instance weights (after bootstrap sampling and class weights)
     */
    template <typename TREE, typename FEATURES, typename LABELS, typename WEIGHTS>
    void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
    {}

    /**
     * @brief Do something after a tree has been learned.
     */
    template <typename RF, typename FEATURES, typename LABELS, typename WEIGHTS>
    void visit_after_tree(RF &,
                          FEATURES &,
                          LABELS &,
                          WEIGHTS &)
    {}

    /**
     * @brief Do something after the split was made.
     */
    template <typename TREE,
              typename FEATURES,
              typename LABELS,
              typename WEIGHTS,
              typename SCORER,
              typename ITER>
    void visit_after_split(TREE &,
                           FEATURES &,
                           LABELS &,
                           WEIGHTS &,
                           SCORER &,
                           ITER,
                           ITER,
                           ITER)
    {}

    /**
     * @brief Return whether the visitor is active or not.
     */
    bool is_active() const
    {
        return active_;
    }

    /**
     * @brief Activate the visitor.
     */
    void activate()
    {
        active_ = true;
    }

    /**
     * @brief Deactivate the visitor.
     */
    void deactivate()
    {
        active_ = false;
    }

private:

    bool active_;

};
160 
161 /////////////////////////////////////////////
162 // The concrete visitors //
163 /////////////////////////////////////////////
164 
165 /**
166  * @brief Compute the out of bag error.
167  *
168  * After training, each data point is put down those trees for which it is OOB.
169  * Using bootstrap sampling, each data point is OOB for approximately 37% of
170  * the trees.
171  */
172 class OOBError : public RFVisitorBase
173 {
174 public:
175 
176  /**
177  * Save whether a data point is in-bag (weight > 0) or out-of-bag (weight == 0).
178  */
179  template <typename TREE, typename FEATURES, typename LABELS, typename WEIGHTS>
181  TREE & /*tree*/,
182  FEATURES & /*features*/,
183  LABELS & /*labels*/,
184  WEIGHTS & weights
185  ){
186  double const EPS = 1e-20;
187  bool found = false;
188 
189  // Save the in-bags.
190  is_in_bag_.resize(weights.size(), true);
191  for (size_t i = 0; i < weights.size(); ++i)
192  {
193  if (std::abs(weights[i]) < EPS)
194  {
195  is_in_bag_[i] = false;
196  found = true;
197  }
198  }
199 
200  if (!found)
201  throw std::runtime_error("OOBError::visit_before_tree(): The tree has no out-of-bags.");
202  }
203 
204  /**
205  * Compute the out-of-bag error.
206  */
207  template <typename VISITORS, typename RF, typename FEATURES, typename LABELS>
209  VISITORS & visitors,
210  RF & rf,
211  const FEATURES & features,
212  const LABELS & labels
213  ){
214  // Check the input sizes.
215  vigra_precondition(rf.num_trees() > 0, "OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
216  vigra_precondition(visitors.size() == rf.num_trees(), "OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
217  size_t const num_instances = features.shape()[0];
218  auto const num_features = features.shape()[1];
219  for (auto vptr : visitors)
220  vigra_precondition(vptr->is_in_bag_.size() == num_instances, "OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
221 
222  // Get a prediction for each data point using only the trees where it is out of bag.
223  typedef typename std::remove_const<LABELS>::type Labels;
224  Labels pred(Shape1(1));
225  oob_err_ = 0.0;
226  for (size_t i = 0; i < (size_t)num_instances; ++i)
227  {
228  // Get the indices of the trees where the data points is out of bag.
229  std::vector<size_t> tree_indices;
230  for (size_t k = 0; k < visitors.size(); ++k)
231  if (!visitors[k]->is_in_bag_[i])
232  tree_indices.push_back(k);
233 
234  // Get the prediction using the above trees.
235  auto const sub_features = features.subarray(Shape2(i, 0), Shape2(i+1, num_features));
236  rf.predict(sub_features, pred, 1, tree_indices);
237  if (pred(0) != labels(i))
238  oob_err_ += 1.0;
239  }
240  oob_err_ /= num_instances;
241  }
242 
243  /**
244  * the out-of-bag error
245  */
246  double oob_err_;
247 
248 private:
249  std::vector<bool> is_in_bag_; // whether a data point is in-bag or out-of-bag
250 };
251 
252 
253 
254 /**
255  * @brief Compute the variable importance.
256  */
258 {
259 public:
260 
261  VariableImportance(size_t repetition_count = 10)
262  :
263  repetition_count_(repetition_count)
264  {}
265 
266  /**
267  * Resize the variable importance array and store in-bag / out-of-bag information.
268  */
269  template <typename TREE, typename FEATURES, typename LABELS, typename WEIGHTS>
271  TREE & tree,
272  FEATURES & features,
273  LABELS & /*labels*/,
274  WEIGHTS & weights
275  ){
276  // Resize the variable importance array.
277  // The shape differs from the shape of the actual output, since the single trees
278  // only store the impurity decrease without the permutation importances.
279  auto const num_features = features.shape()[1];
280  variable_importance_.reshape(Shape2(num_features, tree.num_classes()+2), 0.0);
281 
282  // Save the in-bags.
283  double const EPS = 1e-20;
284  bool found = false;
285  is_in_bag_.resize(weights.size(), true);
286  for (size_t i = 0; i < weights.size(); ++i)
287  {
288  if (std::abs(weights[i]) < EPS)
289  {
290  is_in_bag_[i] = false;
291  found = true;
292  }
293  }
294  if (!found)
295  throw std::runtime_error("VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
296  }
297 
298  /**
299  * Calculate the impurity decrease based variable importance after every split.
300  */
301  template <typename TREE,
302  typename FEATURES,
303  typename LABELS,
304  typename WEIGHTS,
305  typename SCORER,
306  typename ITER>
307  void visit_after_split(TREE & tree,
308  FEATURES & /*features*/,
309  LABELS & labels,
310  WEIGHTS & weights,
311  SCORER & scorer,
312  ITER begin,
313  ITER /*split*/,
314  ITER end)
315  {
316  // Update the impurity decrease.
317  typename SCORER::Functor functor;
318  auto const region_impurity = functor.region_score(labels, weights, begin, end);
319  auto const split_impurity = scorer.best_score_;
320  variable_importance_(scorer.best_dim_, tree.num_classes()+1) += region_impurity - split_impurity;
321  }
322 
323  /**
324  * Compute the permutation importance.
325  */
326  template <typename RF, typename FEATURES, typename LABELS, typename WEIGHTS>
327  void visit_after_tree(RF & rf,
328  const FEATURES & features,
329  const LABELS & labels,
330  WEIGHTS & /*weights*/)
331  {
332  // Non-const types of features and labels.
333  typedef typename std::remove_const<FEATURES>::type Features;
334  typedef typename std::remove_const<LABELS>::type Labels;
335 
336  typedef typename Features::value_type FeatureType;
337 
338  auto const num_features = features.shape()[1];
339 
340  // For the permutation importance, the features must be permuted (obviously).
341  // This cannot be done on the original feature matrix, since it would interfere
342  // with other threads in concurrent training. Therefore, we have to make a copy.
343  Features feats;
344  Labels labs;
345  copy_out_of_bags(features, labels, feats, labs);
346  auto const num_oobs = feats.shape()[0];
347 
348  // Compute (standard and class-wise) out-of-bag success rate with the original sample.
349  MultiArray<1, double> oob_right(Shape1(rf.num_classes()+1), 0.0);
350  vigra::MultiArray<1,int> pred( (Shape1(num_oobs)) );
351  rf.predict(feats, pred, 1);
352  for (size_t i = 0; i < (size_t)labs.size(); ++i)
353  {
354  if (labs(i) == pred(i))
355  {
356  oob_right(labs(i)) += 1.0; // per class
357  oob_right(rf.num_classes()) += 1.0; // total
358  }
359  }
360 
361  // Get out-of-bag success rate after permuting the j'th dimension.
363  for (size_t j = 0; j < (size_t)num_features; ++j)
364  {
365  MultiArray<1, FeatureType> backup(( Shape1(num_oobs) ));
366  backup = feats.template bind<1>(j);
367  MultiArray<2, double> perm_oob_right(Shape2(1, rf.num_classes()+1), 0.0);
368 
369  for (size_t k = 0; k < repetition_count_; ++k)
370  {
371  // Permute the current dimension.
372  for (int ii = num_oobs-1; ii >= 1; --ii)
373  std::swap(feats(ii, j), feats(randint(ii+1), j));
374 
375  // Get the out-of-bag success rate after permuting.
376  rf.predict(feats, pred, 1);
377  for (size_t i = 0; i < (size_t)labs.size(); ++i)
378  {
379  if (labs(i) == pred(i))
380  {
381  perm_oob_right(0, labs(i)) += 1.0; // per class
382  perm_oob_right(0, rf.num_classes()) += 1.0; // total
383  }
384  }
385  }
386 
387  // Normalize and add to the importance matrix.
388  perm_oob_right /= repetition_count_;
389  perm_oob_right.bind<0>(0) -= oob_right;
390  perm_oob_right *= -1;
391  perm_oob_right /= num_oobs;
392  variable_importance_.subarray(Shape2(j, 0), Shape2(j+1, rf.num_classes()+1)) += perm_oob_right;
393 
394  // Copy back the permuted dimension.
395  feats.template bind<1>(j) = backup;
396  }
397  }
398 
399  /**
400  * Accumulate the variable importances from the single trees.
401  */
402  template <typename VISITORS, typename RF, typename FEATURES, typename LABELS>
404  VISITORS & visitors,
405  RF & rf,
406  const FEATURES & features,
407  const LABELS & /*labels*/
408  ){
409  vigra_precondition(rf.num_trees() > 0, "VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
410  vigra_precondition(visitors.size() == rf.num_trees(), "VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
411 
412  // Sum the variable importances from the single trees.
413  auto const num_features = features.shape()[1];
414  variable_importance_.reshape(Shape2(num_features, rf.num_classes()+2), 0.0);
415  for (auto vptr : visitors)
416  {
417  vigra_precondition(vptr->variable_importance_.shape() == variable_importance_.shape(),
418  "VariableImportance::visit_after_training(): Shape mismatch.");
419  variable_importance_ += vptr->variable_importance_;
420  }
421 
422  // Normalize the variable importance.
423  variable_importance_ /= rf.num_trees();
424  }
425 
426  /**
427  * This Array has the same entries as the R - random forest variable
428  * importance.
429  * Matrix is featureCount by (classCount +2)
430  * variable_importance_(ii,jj) is the variable importance measure of
431  * the ii-th variable according to:
432  * jj = 0 - (classCount-1)
433  * classwise permutation importance
434  * jj = rowCount(variable_importance_) -2
435  * permutation importance
436  * jj = rowCount(variable_importance_) -1
437  * gini decrease importance.
438  *
439  * permutation importance:
440  * The difference between the fraction of OOB samples classified correctly
441  * before and after permuting (randomizing) the ii-th column is calculated.
442  * The ii-th column is permuted rep_cnt times.
443  *
444  * class wise permutation importance:
445  * same as permutation importance. We only look at those OOB samples whose
446  * response corresponds to class jj.
447  *
448  * gini decrease importance:
449  * row ii corresponds to the sum of all gini decreases induced by variable ii
450  * in each node of the random forest.
451  */
453 
454  /**
455  * how often the permutation takes place
456  */
458 
459 private:
460 
461  /**
462  * Copy the out-of-bag features and labels.
463  */
464  template <typename F0, typename L0, typename F1, typename L1>
465  void copy_out_of_bags(
466  F0 const & features_in,
467  L0 const & labels_in,
468  F1 & features_out,
469  L1 & labels_out
470  ) const {
471  auto const num_instances = features_in.shape()[0];
472  auto const num_features = features_in.shape()[1];
473 
474  // Count the out-of-bags.
475  size_t num_oobs = 0;
476  for (auto x : is_in_bag_)
477  if (!x)
478  ++num_oobs;
479 
480  // Copy the out-of-bags.
481  features_out.reshape(Shape2(num_oobs, num_features));
482  labels_out.reshape(Shape1(num_oobs));
483  size_t current = 0;
484  for (size_t i = 0; i < (size_t)num_instances; ++i)
485  {
486  if (!is_in_bag_[i])
487  {
488  auto const src = features_in.template bind<0>(i);
489  auto out = features_out.template bind<0>(current);
490  out = src;
491  labels_out(current) = labels_in(i);
492  ++current;
493  }
494  }
495  }
496 
497  std::vector<bool> is_in_bag_; // whether a data point is in-bag or out-of-bag
498 };
499 
500 
501 
502 /////////////////////////////////////////////
503 // The visitor chain //
504 /////////////////////////////////////////////
505 
506 /**
507  * @brief The default visitor node (= "do nothing").
508  */
510 {};
511 
512 namespace detail
513 {
514 
515 /**
516  * @brief Container elements of the statically linked visitor list. Use the create_visitor() functions to create visitors up to size 10.
517  */
518 template <typename VISITOR, typename NEXT = RFStopVisiting, bool CPY = false>
520 {
521 public:
522 
523  typedef VISITOR Visitor;
524  typedef NEXT Next;
525 
526  typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
527  Next next_;
528 
529  RFVisitorNode(Visitor & visitor, Next next)
530  :
531  visitor_(visitor),
532  next_(next)
533  {}
534 
535  explicit RFVisitorNode(Visitor & visitor)
536  :
537  visitor_(visitor),
538  next_(RFStopVisiting())
539  {}
540 
542  :
543  visitor_(other.visitor_),
544  next_(other.next_)
545  {}
546 
547  explicit RFVisitorNode(RFVisitorNode<Visitor, Next, !CPY> const & other)
548  :
549  visitor_(other.visitor_),
550  next_(other.next_)
551  {}
552 
553  void visit_before_training()
554  {
555  if (visitor_.is_active())
556  visitor_.visit_before_training();
557  next_.visit_before_training();
558  }
559 
560  template <typename VISITORS, typename RF, typename FEATURES, typename LABELS>
561  void visit_after_training(VISITORS & v, RF & rf, const FEATURES & features, const LABELS & labels)
562  {
563  typedef typename VISITORS::value_type VisitorNodeType;
564  typedef typename VisitorNodeType::Visitor VisitorType;
565  typedef typename VisitorNodeType::Next NextType;
566 
567  // We want to call the visit_after_training function of the concrete visitor (e. g. OOBError).
568  // Since v is a vector of visitor nodes (and not a vector of concrete visitors), we have to
569  // extract the concrete visitors.
570  // A vector cannot hold references, so we use pointers instead.
571  if (visitor_.is_active())
572  {
573  std::vector<VisitorType*> visitors;
574  for (auto & x : v)
575  visitors.push_back(&x.visitor_);
576  visitor_.visit_after_training(visitors, rf, features, labels);
577  }
578 
579  // Remove the concrete visitors that we just visited.
580  std::vector<NextType> nexts;
581  for (auto & x : v)
582  nexts.push_back(x.next_);
583 
584  // Call the next visitor node in the chain.
585  next_.visit_after_training(nexts, rf, features, labels);
586  }
587 
588  template <typename TREE, typename FEATURES, typename LABELS, typename WEIGHTS>
589  void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
590  {
591  if (visitor_.is_active())
592  visitor_.visit_before_tree(tree, features, labels, weights);
593  next_.visit_before_tree(tree, features, labels, weights);
594  }
595 
596  template <typename RF, typename FEATURES, typename LABELS, typename WEIGHTS>
597  void visit_after_tree(RF & rf,
598  FEATURES & features,
599  LABELS & labels,
600  WEIGHTS & weights)
601  {
602  if (visitor_.is_active())
603  visitor_.visit_after_tree(rf, features, labels, weights);
604  next_.visit_after_tree(rf, features, labels, weights);
605  }
606 
607  template <typename TREE,
608  typename FEATURES,
609  typename LABELS,
610  typename WEIGHTS,
611  typename SCORER,
612  typename ITER>
613  void visit_after_split(TREE & tree,
614  FEATURES & features,
615  LABELS & labels,
616  WEIGHTS & weights,
617  SCORER & scorer,
618  ITER begin,
619  ITER split,
620  ITER end)
621  {
622  if (visitor_.is_active())
623  visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
624  next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
625  }
626 
627 };
628 
629 } // namespace detail
630 
631 /**
632  * The VisitorCopy can be used to set the copy argument of the given visitor chain to true.
633  */
634 template <typename VISITOR>
636 {
638 };
639 
640 template <>
642 {
643  typedef RFStopVisiting type;
644 };
645 
646 
647 
648 //////////////////////////////////////////////////////////
649 // Visitor factory functions for up to 10 visitors. //
650 // FIXME: This should be a variadic template. //
651 //////////////////////////////////////////////////////////
652 
653 template<typename A>
654 detail::RFVisitorNode<A>
655 create_visitor(A & a)
656 {
657  typedef detail::RFVisitorNode<A> _0_t;
658  _0_t _0(a);
659  return _0;
660 }
661 
662 template<typename A, typename B>
663 detail::RFVisitorNode<A, detail::RFVisitorNode<B> >
664 create_visitor(A & a, B & b)
665 {
666  typedef detail::RFVisitorNode<B> _1_t;
667  _1_t _1(b);
668  typedef detail::RFVisitorNode<A, _1_t> _0_t;
669  _0_t _0(a, _1);
670  return _0;
671 }
672 
673 template<typename A, typename B, typename C>
674 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C> > >
675 create_visitor(A & a, B & b, C & c)
676 {
677  typedef detail::RFVisitorNode<C> _2_t;
678  _2_t _2(c);
679  typedef detail::RFVisitorNode<B, _2_t> _1_t;
680  _1_t _1(b, _2);
681  typedef detail::RFVisitorNode<A, _1_t> _0_t;
682  _0_t _0(a, _1);
683  return _0;
684 }
685 
686 template<typename A, typename B, typename C, typename D>
687 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
688  detail::RFVisitorNode<D> > > >
689 create_visitor(A & a, B & b, C & c, D & d)
690 {
691  typedef detail::RFVisitorNode<D> _3_t;
692  _3_t _3(d);
693  typedef detail::RFVisitorNode<C, _3_t> _2_t;
694  _2_t _2(c, _3);
695  typedef detail::RFVisitorNode<B, _2_t> _1_t;
696  _1_t _1(b, _2);
697  typedef detail::RFVisitorNode<A, _1_t> _0_t;
698  _0_t _0(a, _1);
699  return _0;
700 }
701 
702 template<typename A, typename B, typename C, typename D, typename E>
703 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
704  detail::RFVisitorNode<D, detail::RFVisitorNode<E> > > > >
705 create_visitor(A & a, B & b, C & c, D & d, E & e)
706 {
707  typedef detail::RFVisitorNode<E> _4_t;
708  _4_t _4(e);
709  typedef detail::RFVisitorNode<D, _4_t> _3_t;
710  _3_t _3(d, _4);
711  typedef detail::RFVisitorNode<C, _3_t> _2_t;
712  _2_t _2(c, _3);
713  typedef detail::RFVisitorNode<B, _2_t> _1_t;
714  _1_t _1(b, _2);
715  typedef detail::RFVisitorNode<A, _1_t> _0_t;
716  _0_t _0(a, _1);
717  return _0;
718 }
719 
720 template<typename A, typename B, typename C, typename D, typename E,
721  typename F>
722 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
723  detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F> > > > > >
724 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f)
725 {
726  typedef detail::RFVisitorNode<F> _5_t;
727  _5_t _5(f);
728  typedef detail::RFVisitorNode<E, _5_t> _4_t;
729  _4_t _4(e, _5);
730  typedef detail::RFVisitorNode<D, _4_t> _3_t;
731  _3_t _3(d, _4);
732  typedef detail::RFVisitorNode<C, _3_t> _2_t;
733  _2_t _2(c, _3);
734  typedef detail::RFVisitorNode<B, _2_t> _1_t;
735  _1_t _1(b, _2);
736  typedef detail::RFVisitorNode<A, _1_t> _0_t;
737  _0_t _0(a, _1);
738  return _0;
739 }
740 
741 template<typename A, typename B, typename C, typename D, typename E,
742  typename F, typename G>
743 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
744  detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
745  detail::RFVisitorNode<G> > > > > > >
746 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g)
747 {
748  typedef detail::RFVisitorNode<G> _6_t;
749  _6_t _6(g);
750  typedef detail::RFVisitorNode<F, _6_t> _5_t;
751  _5_t _5(f, _6);
752  typedef detail::RFVisitorNode<E, _5_t> _4_t;
753  _4_t _4(e, _5);
754  typedef detail::RFVisitorNode<D, _4_t> _3_t;
755  _3_t _3(d, _4);
756  typedef detail::RFVisitorNode<C, _3_t> _2_t;
757  _2_t _2(c, _3);
758  typedef detail::RFVisitorNode<B, _2_t> _1_t;
759  _1_t _1(b, _2);
760  typedef detail::RFVisitorNode<A, _1_t> _0_t;
761  _0_t _0(a, _1);
762  return _0;
763 }
764 
765 template<typename A, typename B, typename C, typename D, typename E,
766  typename F, typename G, typename H>
767 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
768  detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
769  detail::RFVisitorNode<G, detail::RFVisitorNode<H> > > > > > > >
770 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
771 {
772  typedef detail::RFVisitorNode<H> _7_t;
773  _7_t _7(h);
774  typedef detail::RFVisitorNode<G, _7_t> _6_t;
775  _6_t _6(g, _7);
776  typedef detail::RFVisitorNode<F, _6_t> _5_t;
777  _5_t _5(f, _6);
778  typedef detail::RFVisitorNode<E, _5_t> _4_t;
779  _4_t _4(e, _5);
780  typedef detail::RFVisitorNode<D, _4_t> _3_t;
781  _3_t _3(d, _4);
782  typedef detail::RFVisitorNode<C, _3_t> _2_t;
783  _2_t _2(c, _3);
784  typedef detail::RFVisitorNode<B, _2_t> _1_t;
785  _1_t _1(b, _2);
786  typedef detail::RFVisitorNode<A, _1_t> _0_t;
787  _0_t _0(a, _1);
788  return _0;
789 }
790 
791 template<typename A, typename B, typename C, typename D, typename E,
792  typename F, typename G, typename H, typename I>
793 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
794  detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
795  detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I> > > > > > > > >
796 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
797 {
798  typedef detail::RFVisitorNode<I> _8_t;
799  _8_t _8(i);
800  typedef detail::RFVisitorNode<H, _8_t> _7_t;
801  _7_t _7(h, _8);
802  typedef detail::RFVisitorNode<G, _7_t> _6_t;
803  _6_t _6(g, _7);
804  typedef detail::RFVisitorNode<F, _6_t> _5_t;
805  _5_t _5(f, _6);
806  typedef detail::RFVisitorNode<E, _5_t> _4_t;
807  _4_t _4(e, _5);
808  typedef detail::RFVisitorNode<D, _4_t> _3_t;
809  _3_t _3(d, _4);
810  typedef detail::RFVisitorNode<C, _3_t> _2_t;
811  _2_t _2(c, _3);
812  typedef detail::RFVisitorNode<B, _2_t> _1_t;
813  _1_t _1(b, _2);
814  typedef detail::RFVisitorNode<A, _1_t> _0_t;
815  _0_t _0(a, _1);
816  return _0;
817 }
818 
819 template<typename A, typename B, typename C, typename D, typename E,
820  typename F, typename G, typename H, typename I, typename J>
821 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C,
822  detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
823  detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I,
824  detail::RFVisitorNode<J> > > > > > > > > >
825 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
826  J & j)
827 {
828  typedef detail::RFVisitorNode<J> _9_t;
829  _9_t _9(j);
830  typedef detail::RFVisitorNode<I, _9_t> _8_t;
831  _8_t _8(i, _9);
832  typedef detail::RFVisitorNode<H, _8_t> _7_t;
833  _7_t _7(h, _8);
834  typedef detail::RFVisitorNode<G, _7_t> _6_t;
835  _6_t _6(g, _7);
836  typedef detail::RFVisitorNode<F, _6_t> _5_t;
837  _5_t _5(f, _6);
838  typedef detail::RFVisitorNode<E, _5_t> _4_t;
839  _4_t _4(e, _5);
840  typedef detail::RFVisitorNode<D, _4_t> _3_t;
841  _3_t _3(d, _4);
842  typedef detail::RFVisitorNode<C, _3_t> _2_t;
843  _2_t _2(c, _3);
844  typedef detail::RFVisitorNode<B, _2_t> _1_t;
845  _1_t _1(b, _2);
846  typedef detail::RFVisitorNode<A, _1_t> _0_t;
847  _0_t _0(a, _1);
848  return _0;
849 }
850 
851 
852 
853 } // namespace rf3
854 } // namespace vigra
855 
856 #endif
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned.
Definition: random_forest_visitors.hxx:99
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:270
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made.
Definition: random_forest_visitors.hxx:121
const difference_type & shape() const
Definition: multi_array.hxx:1648
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned.
Definition: random_forest_visitors.hxx:106
void deactivate()
Deactivate the visitor.
Definition: random_forest_visitors.hxx:150
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Base class from which all random forest visitors derive.
Definition: random_forest_visitors.hxx:68
size_t repetition_count_
Definition: random_forest_visitors.hxx:457
The default visitor node (= "do nothing").
Definition: random_forest_visitors.hxx:509
Definition: random.hxx:669
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition: random_forest_visitors.hxx:208
Compute the variable importance.
Definition: random_forest_visitors.hxx:257
Compute the out of bag error.
Definition: random_forest_visitors.hxx:172
double oob_err_
Definition: random_forest_visitors.hxx:246
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition: random_forest_visitors.hxx:327
Definition: random_forest_visitors.hxx:635
void activate()
Activate the visitor.
Definition: random_forest_visitors.hxx:142
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition: random_forest_visitors.hxx:307
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:180
Container elements of the statically linked visitor list. Use the create_visitor() functions to creat...
Definition: random_forest_visitors.hxx:519
MultiArray< 2, double > variable_importance_
Definition: random_forest_visitors.hxx:452
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
bool is_active() const
Return whether the visitor is active or not.
Definition: random_forest_visitors.hxx:134
void visit_before_training()
Do something before training starts.
Definition: random_forest_visitors.hxx:80
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1528
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned.
Definition: random_forest_visitors.hxx:90
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition: random_forest_visitors.hxx:403

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)