1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
 class Solution {
public:
    /**
     * Returns the maximum, over every simple path in the tree, of the path's
     * price sum minus the price of one of its two endpoints. A single-node
     * path therefore scores 0, and the result is never negative.
     *
     * @param n      number of nodes, labelled 0..n-1.
     * @param edges  undirected edges {u, v}; assumed to form a single tree
     *               over the n nodes (TODO confirm with caller).
     * @param price  price[i] is the value of node i; assumed non-negative.
     * @return       the best score, computed in 64-bit arithmetic so long
     *               paths of large prices cannot overflow.
     */
    long long maxOutput(int n, std::vector<std::vector<int>>& edges, std::vector<int>& price) {
        if (n <= 0) {
            return 0;
        }

        std::vector<std::vector<int>> adj(n);
        for (const auto& e : edges) {
            adj[e[0]].push_back(e[1]);
            adj[e[1]].push_back(e[0]);
        }

        // Iterative pre-order walk from node 0. This avoids deep recursion on
        // path-shaped trees. -1 is the root's parent sentinel; unlike a real
        // label such as 1, it can never match a neighbour of the root.
        std::vector<int> parent(n, -1);
        std::vector<int> order;
        order.reserve(n);
        std::vector<int> pending{0};
        while (!pending.empty()) {
            const int u = pending.back();
            pending.pop_back();
            order.push_back(u);
            for (int v : adj[u]) {
                if (v != parent[u]) {
                    parent[v] = u;
                    pending.push_back(v);
                }
            }
        }

        // full[u]:    best price sum of a downward path that starts at u (u included).
        // trimmed[u]: the same, but with the far endpoint's price left out
        //             (0 when the path is u alone).
        std::vector<long long> full(n, 0);
        std::vector<long long> trimmed(n, 0);
        long long best = 0;  // a single-node path scores exactly 0

        // Reverse pre-order visits every child before its parent.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int u = *it;
            long long withEnd = price[u];  // best branch so far, endpoint kept
            long long withoutEnd = 0;      // best branch so far, endpoint dropped
            for (int v : adj[u]) {
                if (v == parent[u]) {
                    continue;
                }
                // Join a previously seen branch with v's branch through u.
                // Exactly one of the two far endpoints has its price removed.
                best = std::max({best, withEnd + trimmed[v], withoutEnd + full[v]});
                withEnd = std::max(withEnd, full[v] + price[u]);
                withoutEnd = std::max(withoutEnd, trimmed[v] + price[u]);
            }
            full[u] = withEnd;
            trimmed[u] = withoutEnd;
        }
        return best;
    }
};
