MTF
GNN.h
1 #ifndef MTF_GNN_H
2 #define MTF_GNN_H
3 
4 #include "mtf/SM/GNNParams.h"
5 
6 #include <vector>
7 #include <memory>
8 
9 _MTF_BEGIN_NAMESPACE
10 
11 namespace gnn{
12  struct Node{
13  VectorXi nns_inds;
14  int size;
15  int capacity;
16  };
/// Couples a distance value with an index so both can be reordered together.
struct IndxDist{
	double dist; ///< sort key compared by cmpQsort()
	int idx;     ///< index carried alongside the distance
};

/**
 * Three-way comparator over IndxDist records, ordered by ascending dist.
 * The signature matches what qsort()-style C sorting routines expect.
 *
 * @param a pointer to the first IndxDist (never modified)
 * @param b pointer to the second IndxDist (never modified)
 * @return -1 if a sorts before b, 0 if they tie, 1 if a sorts after b
 *
 * NaN keys sort after every non-NaN key and tie with each other, so the
 * ordering is consistent no matter which argument holds the NaN.
 */
inline int cmpQsort(const void *a, const void *b){
	// Named casts keep the const qualifier that the parameters promise.
	const IndxDist *lhs = static_cast<const IndxDist*>(a);
	const IndxDist *rhs = static_cast<const IndxDist*>(b);

	if(lhs->dist < rhs->dist){ return -1; }
	if(lhs->dist > rhs->dist){ return 1; }
	if(lhs->dist == rhs->dist){ return 0; }

	// Reaching this point means at least one key is NaN (all ordered
	// comparisons were false); a value is NaN exactly when it differs from itself.
	const int lhs_is_nan = (lhs->dist != lhs->dist) ? 1 : 0;
	const int rhs_is_nan = (rhs->dist != rhs->dist) ? 1 : 0;
	return lhs_is_nan - rhs_is_nan;
}
29 
30  template<class DistType>
31  class GNN{
32  public:
33 
34  typedef GNNParams ParamType;
35  typedef std::shared_ptr<const DistType> DistTypePtr;
36 
37  GNN(DistTypePtr _dist_func, int _n_samples, int _n_dims,
38  bool _is_symmetrical = true, const ParamType *gnn_params = nullptr);
39  ~GNN(){}
40  void computeDistances(const double *dataset);
41  void buildGraph(const double *dataset);
42  void searchGraph(const double *query, const double *dataset,
43  int *nn_ids, double *nn_dists, int K = 1);
44  void saveGraph(const char* file_name);
45  void loadGraph(const char* file_name);
46 
47  void buildGraph(const double *X, int k);
48  int searchGraph(const double *Xq, const double *X,
49  int NNs, int K);
50 
51  protected:
52 
53  DistTypePtr dist_func;
54  const int n_samples, n_dims;
55  const bool is_symmetrical;
56  ParamType params;
57  std::vector<Node> nodes;
58  MatrixXd dataset_distances;
59 
60  int start_node_idx;
61  bool dist_computed;
62 
63  int getRandNum(int lb, int ub){
64  // time_t sec;
65  // time(&sec);
66  // srand((unsigned int) sec);
67  return (rand() % (ub - lb + 1) + lb);
68  }
69  template<typename ScalarT>
70  void swap(ScalarT *i, ScalarT *j){
71  ScalarT temp;
72  temp = *i;
73  *i = *j;
74  *j = temp;
75  }
76  void knnSearch2(const double *Q, IndxDist *dists, const double *X,
77  int rows, int cols, int k);
78  void knnSearch11(const double *Q, IndxDist *dists, const double *X, int rows,
79  int cols, int k, int *X_inds);
80 
81  int min(int a, int b){ return a < b ? a : b; }
82 
83  void pickKNNs(IndxDist *vis_nodes, int visited, IndxDist **gnn_dists,
84  int K, int *gnns_cap);
85 
86  void addNode(Node *node_i, int nn);
87  };
88 }
89 _MTF_END_NAMESPACE
90 
91 #endif
92 
Definition: GNNParams.h:9
Definition: GNN.h:31
Definition: GNN.h:12
GNN with FLANN support.
Definition: FGNN.h:12
Definition: GNN.h:17