Given a binary search tree, how do we find its kth smallest or kth largest element?
For example, consider the following binary search tree:
The third largest element is 6.
The second smallest element is 2.
The fifth smallest element is 5, and so on.
Solution:
We can solve this problem by modifying the in-order traversal of a binary tree. In addition to the root node, we pass two more parameters: K, and the count of nodes visited so far. The count is passed by reference so that it is shared across all recursive calls. When the count reaches K, the current node holds the kth order element.
To find the kth smallest element, we visit the left subtree, then the root, and then the right subtree, as usual. To find the kth largest element, we do a reverse in-order traversal, i.e., first visit the right subtree, then the root, and then the left subtree.
Here is the C++ implementation. It includes both recursive and iterative versions for finding the kth smallest and kth largest elements. The iterative versions simply simulate the recursion manually with an explicit stack.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <stack> | |
using namespace std; | |
// Binary tree node type.
// Children start out empty; they are filled in by insert().
struct TreeNode
{
    int val;                    // key stored in this node
    TreeNode *left = nullptr;   // left child, or nullptr if none
    TreeNode *right = nullptr;  // right child, or nullptr if none

    // Builds a leaf holding `v`.
    TreeNode(int v) : val(v)
    {
    }
};
// Binary search tree insert procedure, used to build the test trees.
//
// Inserts `data` into the subtree rooted at `root` and returns the (possibly
// new) root of that subtree, so callers write `root = insert(root, x);`.
// Keys smaller than a node's key go left; equal or larger keys go right, so
// duplicates are kept in the right subtree. Nodes are allocated with `new`
// and are owned by the caller.
TreeNode* insert(TreeNode *root, int data)
{
    if (root == nullptr)
        return new TreeNode(data);

    if (data < root->val)
        root->left = insert(root->left, data);
    else
        root->right = insert(root->right, data);
    return root;
}
// Recursive solution to find the kth smallest element (1-based) using an
// in-order traversal (left subtree, node, right subtree).
//
// `cur` counts the nodes visited so far and is shared across the recursion by
// reference; callers pass a variable initialized to 0. When `cur` reaches k
// the node's key is printed on its own line and the rest of the traversal is
// cut short, leaving `cur == k`. Nothing is printed if k < 1 or the tree has
// fewer than k nodes.
void kth_smallest_recur(TreeNode *root, int k, int& cur)
{
    // Empty subtree, or the answer has already been found.
    if (root == nullptr || cur >= k)
        return;

    kth_smallest_recur(root->left, k, cur);
    // The kth node may have been found in the left subtree; don't count this
    // node (or visit the right subtree) in that case.
    if (cur >= k)
        return;

    if (++cur == k)
    {
        std::cout << root->val << std::endl;
        return;
    }
    kth_smallest_recur(root->right, k, cur);
}
// Iterative version to find the kth smallest element (1-based): an in-order
// traversal driven by an explicit stack instead of recursion.
//
// Returns the kth smallest key, or -1 if the tree is empty, k < 1, or the tree
// has fewer than k nodes. Note that -1 is ambiguous when the tree itself may
// contain the key -1.
int kth_smallest(TreeNode *root, int k)
{
    // Nothing to find for non-positive k; this also keeps `--k` below from
    // ever overflowing (e.g. for k == INT_MIN).
    if (k < 1)
        return -1;

    std::stack<TreeNode*> pending;  // ancestors whose key is not yet visited
    TreeNode *node = root;
    while (node != nullptr || !pending.empty())
    {
        // Walk down the left spine; the last node pushed holds the smallest
        // key not yet visited.
        while (node != nullptr)
        {
            pending.push(node);
            node = node->left;
        }

        node = pending.top();
        pending.pop();
        if (--k == 0)
            return node->val;

        // All keys in the right subtree come next in sorted order.
        node = node->right;
    }
    return -1;
}
// Recursive method to find the kth largest element (1-based) in a BST using a
// reverse in-order traversal (right subtree, node, left subtree).
//
// `current` counts the nodes visited so far and is shared across the recursion
// by reference; callers pass a variable initialized to 0. When `current`
// reaches k the node's key is printed on its own line and the rest of the
// traversal is cut short, leaving `current == k`. Nothing is printed if k < 1
// or the tree has fewer than k nodes.
void kth_largest_recur(TreeNode *root, int k, int &current)
{
    // Empty subtree, or the answer has already been found.
    if (root == nullptr || current >= k)
        return;

    kth_largest_recur(root->right, k, current);
    // The kth node may have been found in the right subtree; don't count this
    // node (or visit the left subtree) in that case.
    if (current >= k)
        return;

    if (++current == k)
    {
        std::cout << root->val << std::endl;
        return;
    }
    kth_largest_recur(root->left, k, current);
}
// Iterative method to find the kth largest element (1-based) in a BST: a
// reverse in-order traversal driven by an explicit stack.
//
// Prints the kth largest key on its own line. Prints nothing if the tree is
// empty, k < 1, or the tree has fewer than k nodes.
void kth_largest(TreeNode *root, int k)
{
    // Nothing to find for non-positive k; this also keeps `--k` below from
    // ever overflowing (e.g. for k == INT_MIN).
    if (root == nullptr || k < 1)
        return;

    std::stack<TreeNode *> pending;  // ancestors whose key is not yet visited
    TreeNode *node = root;
    while (node != nullptr || !pending.empty())
    {
        // Walk down the right spine; the last node pushed holds the largest
        // key not yet visited.
        while (node != nullptr)
        {
            pending.push(node);
            node = node->right;
        }

        node = pending.top();
        pending.pop();
        if (--k == 0)
        {
            std::cout << node->val << std::endl;
            return;
        }

        // All keys in the left subtree come next in descending order.
        node = node->left;
    }
}
void Test1() | |
{ | |
TreeNode *root = NULL; | |
root = insert(root, 2); | |
root = insert(root, 1); | |
root = insert(root, 4); | |
root = insert(root, 3); | |
root = insert(root, 5); | |
int c = 0; | |
cout << kth_smallest(root,1) << endl; | |
kth_smallest_recur(root,1,c); | |
c = 0; | |
cout << kth_smallest(root,3) << endl; | |
kth_smallest_recur(root,3,c); | |
c = 0; | |
kth_largest(root,1); | |
kth_largest_recur(root,1,c); | |
c = 0; | |
kth_largest(root,4); | |
kth_largest_recur(root,4,c); | |
cout << endl; | |
} | |
void Test2() | |
{ | |
TreeNode *root = NULL; | |
root = insert(root, 4); | |
root = insert(root, 2); | |
root = insert(root, 6); | |
root = insert(root, 1); | |
root = insert(root, 3); | |
root = insert(root, 5); | |
root = insert(root, 7); | |
int c = 0; | |
cout << kth_smallest(root,5) << endl; | |
kth_smallest_recur(root,5,c); | |
c = 0; | |
cout << kth_smallest(root,7) << endl; | |
kth_smallest_recur(root,7,c); | |
c = 0; | |
kth_largest(root,7); | |
kth_largest_recur(root,7,c); | |
c = 0; | |
kth_largest(root,4); | |
kth_largest_recur(root,4,c); | |
} | |
int main() | |
{ | |
Test1(); | |
Test2(); | |
return 0; | |
} |