diff --git a/algorithms/graph/lca.cpp b/algorithms/graph/lca.cpp index d2d68bf598d5f86b2af73e2d936c9a40a6251c9b..50f87011deaf64db426d396478a63edc5d24006c 100644 --- a/algorithms/graph/lca.cpp +++ b/algorithms/graph/lca.cpp @@ -4,26 +4,40 @@ /// - preprocess: O(V log V) /// - query: O(log V) /// Complexity (Space): O(V + E + V log V) -/// -/// OBS: * = return sum path to LCA -/// ** = return max value on path to LCA -/// *** = used in both * and ** #define MAXLOG 20 //log2(MAX) -int h[MAX]; -int par[MAX][MAXLOG]; -//*** int cost[MAX][MAXLOG]; - -vector<int> graph[MAX]; //*** vector<ii> +vector<ii> graph[MAX]; struct LCA { + vector<int> h; + vector<vector<int>> par, cost; + + LCA(int N) : + h(N), + par(N, vector<int>(MAXLOG)), + cost(N, vector<int>(MAXLOG)) + { + init(); + } + + void init() { + for (auto &i : par) + fill(all(i), -1); + for (auto &i : cost) + fill(all(i), 0); + dfs(0); // Assuming root is 0 + } + + inline int op(int a, int b) { + return a + b; // or max(a, b) + } /// Performs DFS while filling h, par, and cost. /// @param v root of the tree void dfs(int v, int p = -1, int c = 0) { par[v][0] = p; - //*** cost[v][0] = c; + cost[v][0] = c; if (p != -1) h[v] = h[p] + 1; @@ -31,8 +45,7 @@ struct LCA { for (int i = 1; i < MAXLOG; ++i) if (par[v][i - 1] != -1) { par[v][i] = par[par[v][i - 1]][i - 1]; - //* cost[v][i] += cost[v][i - 1] + cost[par[v][i - 1]][i - 1]; - //** cost[v][i] = max(cost[v][i], max(cost[par[v][i-1]][i-1], cost[v][i-1])); + cost[v][i] = op(cost[v][i], op(cost[par[v][i-1]][i-1], cost[v][i-1])); } for (auto u : graph[v]) @@ -40,46 +53,40 @@ struct LCA { dfs(u.fi, v, u.se); } - /// Preprocess tree. - /// @param v root of the tree - void preprocess(int v) { - mset(par, -1); - //*** mset(cost, 0); - dfs(v); - } - /// Returns LCA (or sum or max). /// @param p,q query nodes int query(int p, int q) { - //*** int ans = 0; + int ans = 0; if (h[p] < h[q]) swap(p, q); for (int i = MAXLOG - 1; i >= 0; --i) if (par[p][i] != -1 && h[par[p][i]] >= h[q]) { - //* ans += cost[p][i]; - //** ans = max(ans, cost[p][i]); + ans = op(ans, cost[p][i]); p = par[p][i]; } - if (p == q) - return p; //*** return ans; + if (p == q) { + #ifdef COST + return ans; + #else + return p; + #endif + } for (int i = MAXLOG - 1; i >= 0; --i) if (par[p][i] != -1 && par[p][i] != par[q][i]) { - //* ans += cost[p][i] + cost[q][i]; - //** ans = max(ans, max(cost[p][i], cost[q][i])); + ans = op(ans, op(cost[p][i], cost[q][i])); // * p = par[p][i]; q = par[q][i]; } - //* if (p == q) return ans; - //* else return ans + cost[p][0] + cost[q][0]; - - //** if (p == q) return ans; - //** else return max(ans, max(cost[p][0], cost[q][0])); - - return par[p][0]; + #ifdef COST + if (p == q) return ans; + else return op(ans, op(cost[p][0], cost[q][0])); + #else + return par[p][0]; + #endif } };