Skip to content
Snippets Groups Projects
Commit 3bdc0611 authored by Bruno Freitas Tissei's avatar Bruno Freitas Tissei
Browse files

Add ball tree

parent 3b51a1d5
No related branches found
No related tags found
No related merge requests found
/**
* Balltree (k-Nearest Neighbors)
*
* Complexity (Time): O(n log n)
* Complexity (Space): O(n)
*/
#define x first
#define y second
typedef pair<double, double> point;
typedef vector<point> pset;
typedef struct node {
double radius;
point center;
node *left, *right;
} node;
double distance(point &a, point &b) {
return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
}
// Find furthest point from center and returns <distance,index> of that point
pair<double, int> get_radius(point &center, pset &ps) {
int ind = 0;
double dist, radius = -1.0;
for (int i = 0; i < ps.size(); ++i) {
dist = distance(center, ps[i]);
if (radius < dist) {
radius = dist;
ind = i;
}
}
return pair<double, int>(radius, ind);
}
// Find average point and pretends it's the center of the given set of points
void get_center(pset &ps, point &center) {
center.x = center.y = 0;
for (auto p : ps) {
center.x += p.x;
center.y += p.y;
}
center.x /= (double)ps.size();
center.y /= (double)ps.size();
}
// Splits the set of points in closer to ps[lind] and closer to ps[rind],
// where lind is returned by get_radius and rind is the furthest points
// from ps[lind]
void partition(pset &ps, pset &left, pset &right, int lind) {
int rind = 0;
double dist, grt = -1.0;
double ldist, rdist;
point rmpoint;
point lmpoint = ps[lind];
for (int i = 0; i < ps.size(); ++i)
if (i != lind) {
dist = distance(lmpoint, ps[i]);
if (dist > grt) {
grt = dist;
rind = i;
}
}
rmpoint = ps[rind];
left.push_back(ps[lind]);
right.push_back(ps[rind]);
for (int i = 0; i < ps.size(); ++i)
if (i != lind && i != rind) {
ldist = distance(ps[i], lmpoint);
rdist = distance(ps[i], rmpoint);
if (ldist <= rdist)
left.push_back(ps[i]);
else
right.push_back(ps[i]);
}
}
// Build ball-tree recursively
// ps: vector of points
node *build(pset &ps) {
if (ps.size() == 0)
return nullptr;
node *n = new node;
// When there's only one point in ps, a leaf node is created storing that
// point
if (ps.size() == 1) {
n->center = ps[0];
n->radius = 0.0;
n->right = n->left = nullptr;
// Otherwise, ps gets split into two partitions, one for each child
} else {
get_center(ps, n->center);
auto rad = get_radius(n->center, ps);
pset lpart, rpart;
partition(ps, lpart, rpart, rad.second);
n->radius = rad.first;
n->left = build(lpart);
n->right = build(rpart);
}
return n;
}
// Search the ball-tree recursively
// n: root
// t: query point
// pq: initially empty multiset (will contain the answer after execution)
// k: number of nearest neighbors
void search(node *n, point t, multiset<double> &pq, int &k) {
if (n->left == nullptr && n->right == nullptr) {
double dist = distance(t, n->center);
// (!) Only necessary when the same point needs to be ignored
if (dist < EPS)
return;
else if (pq.size() < k || dist < *pq.rbegin()) {
pq.insert(dist);
if (pq.size() > k)
pq.erase(prev(pq.end()));
}
} else {
double distl = distance(t, n->left->center);
double distr = distance(t, n->right->center);
if (distl <= distr) {
if (pq.size() < k || (distl <= *pq.rbegin() + n->left->radius))
search(n->left, t, pq, k);
if (pq.size() < k || (distr <= *pq.rbegin() + n->right->radius))
search(n->right, t, pq, k);
} else {
if (pq.size() < k || (distr <= *pq.rbegin() + n->right->radius))
search(n->right, t, pq, k);
if (pq.size() < k || (distl <= *pq.rbegin() + n->left->radius))
search(n->left, t, pq, k);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment