1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
| using namespace std;
// Binary tree node: an integer payload plus left/right child pointers.
// Children start out as nullptr; the node never deletes them itself.
class BinaryTree {
 public:
  int value;
  BinaryTree* left = nullptr;
  BinaryTree* right = nullptr;

  BinaryTree(int value) : value(value) {}
};
void find(BinaryTree* node, int k, vector<int>& res) { if(!node) return; if(!k) { res.push_back(node->value); } else { find(node->left, k - 1, res); find(node->right, k - 1, res); } }
pair<bool, int> travel(BinaryTree* node, int& target, int& k, vector<int>& res) { if(!node) return {false, -1}; if(node->value == target) { find(node, k, res); return {true, 1}; } else { auto [lf, ldis] = travel(node->left, target, k, res); auto [rf, rdis] = travel(node->right, target, k, res); if(!lf and !rf) return {false, -1}; if(lf) { if(ldis > k) return {true, ldis + 1}; if(ldis == k) res.push_back(node->value); else if(ldis < k) find(node->right, k - ldis - 1, res); return {true, ldis + 1}; } if(rf) { if(rdis > k) return {true, rdis + 1}; if(rdis == k) res.push_back(node->value); else if(rdis < k) find(node->left, k - rdis - 1, res); return {true, rdis + 1}; } } }
vector<int> findNodesDistanceK(BinaryTree *tree, int target, int k) { vector<int> res; travel(tree, target, k, res); return res; }
|