Kth Largest Element in an Array

Question

Find the K-th largest element in an array. Note that it is the kth largest element in sorted order, not the kth distinct element.

Notice

You can swap elements in the array

Example

In array [9,3,2,4,8], the 3rd largest element is 4.

In array [1,2,3,4,5], the 1st largest element is 5, the 2nd largest element is 4, the 3rd largest element is 3, and so on.
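A quick way to pin down the definition is to sort a copy and read index n - k. The sketch below (the class and method names are illustrative, not part of the problem) checks the two examples above plus a made-up input with duplicates, [5, 5, 4], where the 2nd largest is 5 because duplicates are counted.

```java
import java.util.Arrays;

class KthLargestDefinition {
    // Reference definition: sort a copy ascending; the kth largest sits at index n - k.
    // Duplicates are counted, so [5, 5, 4] with k = 2 gives 5, not 4.
    static int kthLargestBySorting(int[] nums, int k) {
        int[] copy = Arrays.copyOf(nums, nums.length);
        Arrays.sort(copy);
        return copy[copy.length - k];
    }

    public static void main(String[] args) {
        System.out.println(kthLargestBySorting(new int[]{9, 3, 2, 4, 8}, 3)); // 4
        System.out.println(kthLargestBySorting(new int[]{1, 2, 3, 4, 5}, 1)); // 5
        System.out.println(kthLargestBySorting(new int[]{5, 5, 4}, 2));       // 5
    }
}
```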

Challenge

O(n) time, O(1) extra memory.

Analysis

Sort Array

  • O(nlogn) time

Max Heap

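  • O(nlogn) running time + O(n) memory when all n elements are pushed into the heap; a min heap capped at size k brings this down to O(nlogk) time + O(k) memory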

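To find the kth largest number in a collection, a Max Heap can be used: traverse the array once, adding every element to a PriorityQueue (created with initial capacity k, although it still grows to hold all n elements), then call poll() k-1 times; the number left at the top of the heap is the kth largest.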

Alternatively, this problem can be solved more efficiently with quick select, which applies the partition idea of quick sort. Reference: http://www.jiuzhang.com/solutions/kth-largest-element/

Quick Select

  • Time Complexity: average = O(n); worst case O(n^2), O(1) space

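Notes:

  • Main idea of partition: move elements smaller than the pivot to its left and elements larger than the pivot to its right.

  • The choice of pivot determines how efficient the partition is. Using the element at the left pointer works, but a better choice is a random index within [left, right].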
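The two notes above can be sketched as a Lomuto-style partition with a random pivot. This is one common partition scheme, assumed here for illustration (the solutions below also use Sedgewick's scheme and a two-pointer "hole" scheme), and the class name `PartitionDemo` is hypothetical. The property quick select relies on is that the returned index is the pivot's final position in sorted order.

```java
import java.util.Arrays;
import java.util.Random;

class PartitionDemo {
    private static final Random RAND = new Random();

    // Partition nums[lo..hi] around a randomly chosen pivot (Lomuto scheme).
    // Afterwards: nums[lo..p-1] < pivot, nums[p] == pivot, nums[p+1..hi] >= pivot.
    static int partition(int[] nums, int lo, int hi) {
        int pivotIndex = lo + RAND.nextInt(hi - lo + 1); // uniform in [lo, hi]
        swap(nums, pivotIndex, hi);                       // park the pivot at the end
        int pivot = nums[hi];
        int store = lo;                                   // next slot for a smaller element
        for (int i = lo; i < hi; i++) {
            if (nums[i] < pivot) {
                swap(nums, store++, i);
            }
        }
        swap(nums, store, hi);                            // pivot moves to its final sorted index
        return store;
    }

    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        int[] nums = {9, 3, 2, 4, 8};
        int p = partition(nums, 0, nums.length - 1);
        System.out.println("pivot " + nums[p] + " at index " + p + ": " + Arrays.toString(nums));
    }
}
```

Because the pivot is random, the printed index varies between runs, but elements left of it are always smaller and elements right of it are never smaller.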

The O(n) time complexity comes from O(n) + O(n/2) + O(n/4) + ... ~ O(2n), which assumes each partition's pivot splits the range roughly in half.

Source: https://discuss.leetcode.com/topic/15256/4-c-solutions-using-partition-max-heap-priority_queue-and-multiset-respectively

So, on average, each partition reduces the problem to about half of its original size, giving the recursion T(n) = T(n/2) + O(n), where O(n) is the cost of the partition. Solving this recursion gives T(n) = O(n), a linear-time solution. The key is that only one half of the array needs to be searched; if both halves had to be processed, as in quicksort, the recursion would be T(n) = 2T(n/2) + O(n) and the complexity O(nlogn).
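Unrolling the two recurrences makes the difference concrete, assuming each partition of m elements costs cm, n is a power of two, and every pivot splits its range exactly in half (the idealized average case):

```latex
\begin{aligned}
\text{Quick select: } T(n) &= T(n/2) + cn
  = \sum_{i=0}^{\log_2 n - 1} c\,\frac{n}{2^{i}} + T(1)
  < 2cn + T(1) = O(n) \\
\text{Quicksort: } T(n) &= 2T(n/2) + cn
  = \sum_{i=0}^{\log_2 n - 1} 2^{i}\cdot c\,\frac{n}{2^{i}} + n\,T(1)
  = cn\log_2 n + n\,T(1) = O(n\log n)
\end{aligned}
```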

Of course, O(n) is the average time complexity. In the worst case, the recursion may become T(n) = T(n - 1) + O(n) and the complexity will be O(n^2).
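A small experiment makes the worst case visible. This sketch assumes a deterministic pivot rule (always the last element, a hypothetical choice for the demo) and searches a sorted array for its smallest element, i.e. k = n. Each partition then removes only the current maximum, so the comparison count is (n-1) + (n-2) + ... + 1 = n(n-1)/2. The class name and counter are illustrative.

```java
class WorstCaseDemo {
    static long comparisons = 0;

    // Lomuto partition that always uses the last element as pivot (no randomization).
    static int partition(int[] a, int lo, int hi) {
        int pivot = a[hi];
        int store = lo;
        for (int i = lo; i < hi; i++) {
            comparisons++;
            if (a[i] < pivot) {
                int t = a[store]; a[store] = a[i]; a[i] = t;
                store++;
            }
        }
        int t = a[store]; a[store] = a[hi]; a[hi] = t;
        return store;
    }

    // Iterative quick select for the element at 0-based sorted index target.
    static int select(int[] a, int target) {
        int lo = 0, hi = a.length - 1;
        while (lo < hi) {
            int p = partition(a, lo, hi);
            if (p == target) return a[p];
            if (p < target) lo = p + 1; else hi = p - 1;
        }
        return a[lo];
    }

    public static void main(String[] args) {
        int n = 1000;
        int[] sorted = new int[n];
        for (int i = 0; i < n; i++) sorted[i] = i;
        // k = n (the smallest element) corresponds to target index n - k = 0.
        int result = select(sorted, 0);
        System.out.println(result + " found after " + comparisons + " comparisons");
    }
}
```

For n = 1000 this prints `0 found after 499500 comparisons`. Randomizing the pivot, or shuffling the input first as in the solutions below, makes a sorted input no worse than any other.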

Solution

Sort Array

O(nlogn)

import java.util.Arrays;

public class Solution {
    public int findKthLargest(int[] nums, int k) {
        final int N = nums.length;
        // After an ascending sort, the kth largest sits at index N - k
        Arrays.sort(nums);
        return nums[N - k];
    }
}

Max Heap

O(N lg N) running time + O(N) memory (every element is pushed; the heap is not capped at size K)

import java.util.Collections;
import java.util.PriorityQueue;

class Solution {
    /*
     * @param k : the rank to find (1 = largest)
     * @param nums : array of nums
     * @return: the kth largest element, or -1 for invalid input
     */
    public int kthLargestElement(int k, int[] nums) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            return -1;
        }
        // Max heap; Collections.reverseOrder() avoids the overflow risk of "o2 - o1"
        PriorityQueue<Integer> heap = new PriorityQueue<Integer>(k, Collections.reverseOrder());
        for (int i = 0; i < nums.length; i++) {
            heap.offer(nums[i]);
        }
        // Remove the k - 1 largest elements; the kth largest is then at the top
        for (int j = 0; j < k - 1; j++) {
            heap.poll();
        }
        return heap.peek();
    }
}

Min Heap

O(N lg K) running time + O(K) memory

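import java.util.PriorityQueue;

class Solution {
    public int findKthLargest(int[] nums, int k) {
        // Min heap holding the k largest elements seen so far; its root is the kth largest
        final PriorityQueue<Integer> pq = new PriorityQueue<>();
        for(int val : nums) {
            pq.offer(val);
            // Evict the smallest once the heap holds more than k elements
            if(pq.size() > k) {
                pq.poll();
            }
        }
        return pq.peek();
    }
}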
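To see why the root of this min heap is the answer, the sketch below replays the same loop on the example [9, 3, 2, 4, 8] with k = 3 and prints the heap after each element (sorted for readability, since PriorityQueue iteration order is unspecified). The `verbose` flag and the class name are additions for the demo.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

class MinHeapTrace {
    // Keeps the k largest elements seen so far; the root is the smallest of them.
    static int kthLargest(int[] nums, int k, boolean verbose) {
        PriorityQueue<Integer> pq = new PriorityQueue<>();
        for (int val : nums) {
            pq.offer(val);
            if (pq.size() > k) {
                pq.poll(); // evict the smallest, which can no longer be in the top k
            }
            if (verbose) {
                List<Integer> snapshot = new ArrayList<>(pq);
                Collections.sort(snapshot); // iteration order of PriorityQueue is unspecified
                System.out.println("after " + val + ": " + snapshot);
            }
        }
        return pq.peek();
    }

    public static void main(String[] args) {
        System.out.println("answer: " + kthLargest(new int[]{9, 3, 2, 4, 8}, 3, true));
    }
}
```

The heap grows to [2, 3, 9], then each later element evicts the current minimum, ending at [4, 8, 9]: the three largest values, whose root 4 is the 3rd largest.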

Quick Select: O(N) average case / O(N^2) worst case running time + O(1) memory

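    public int findKthLargest(int[] nums, int k) {

        // The kth largest is the element at 0-based index n - k in sorted order
        k = nums.length - k;
        int lo = 0;
        int hi = nums.length - 1;
        while (lo < hi) {
            // After partitioning, nums[j] is in its final sorted position
            final int j = partition(nums, lo, hi);
            if(j < k) {
                lo = j + 1;
            } else if (j > k) {
                hi = j - 1;
            } else {
                break;
            }
        }
        return nums[k];
    }

    // Sedgewick-style partition around a[lo]; returns the pivot's final index
    private int partition(int[] a, int lo, int hi) {

        int i = lo;
        int j = hi + 1;
        while(true) {
            // Scan right for an element >= pivot, and left for an element <= pivot
            while(i < hi && less(a[++i], a[lo]));
            while(j > lo && less(a[lo], a[--j]));
            if(i >= j) {
                break;
            }
            exch(a, i, j);
        }
        // Place the pivot between the two partitions
        exch(a, lo, j);
        return j;
    }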

    private void exch(int[] a, int i, int j) {
        final int tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }

    private boolean less(int v, int w) {
        return v < w;
    }

Use Shuffle: O(N) expected running time + O(1) space (the O(N^2) worst case is still possible but very unlikely after shuffling)

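    // Reuses partition() and exch() from the Quick Select solution above
    public int findKthLargest(int[] nums, int k) {

        // Randomize the order once so that no fixed input (e.g. already sorted) triggers the worst case
        shuffle(nums);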
        k = nums.length - k;
        int lo = 0;
        int hi = nums.length - 1;
        while (lo < hi) {
            final int j = partition(nums, lo, hi);
            if(j < k) {
                lo = j + 1;
            } else if (j > k) {
                hi = j - 1;
            } else {
                break;
            }
        }
        return nums[k];
    }

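    // Fisher-Yates shuffle; requires import java.util.Random
    private void shuffle(int[] a) {

        final Random random = new Random();
        for(int ind = 1; ind < a.length; ind++) {
            // Swap a[ind] with a uniformly random element of a[0..ind]
            final int r = random.nextInt(ind + 1);
            exch(a, ind, r);
        }
    }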

Quick Select (Partition with two pointers)

O(N) average case / O(N^2) worst case running time; the recursion uses O(log N) stack on average (O(N) in the worst case)

class Solution {
    /*
     * @param k : the rank to find (1 = largest)
     * @param nums : array of nums
     * @return: the kth largest element, or 0 for invalid input
     */
    public int kthLargestElement(int k, int[] nums) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            return 0;
        }

        return select(nums, 0, nums.length - 1, nums.length - k);

    }

    public int select(int[] nums, int left, int right, int k) {
        if (left == right) {
            return nums[left];
        }

        int pivotIndex = partition(nums, left, right);
        if (pivotIndex == k) {
            return nums[pivotIndex];
        } else if (pivotIndex < k) {
            return select(nums, pivotIndex + 1, right, k);
        }  else {
            return select(nums, left, pivotIndex - 1, k);
        }
    }

    public int partition(int[] nums, int left, int right) {

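        // Init pivot; a random index would be better (see the next solution)
        int pivot = nums[left];

        // Begin partition: nums[left] is now a "hole" that the two scans fill alternately
        while (left < right) {
            while (left < right && nums[right] >= pivot) { // skip elements >= pivot on the right
                right--;
            }
            nums[left] = nums[right]; // move a smaller element into the left hole
            while (left < right && nums[left] <= pivot) { // skip elements <= pivot on the left
                left++;
            }
            nums[right] = nums[left]; // move a larger element into the right hole
        }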
        }

        // Recover pivot to array
        nums[left] = pivot;
        return left;
    }
}

Quick Select (with random pivot)

Source: Wikipedia, Quickselect


import java.util.Random;

class Solution {
    /*
     * @param k : the rank to find (1 = largest)
     * @param nums : array of nums
     * @return: the kth largest element, or 0 for invalid input
     */
    public int kthLargestElement(int k, int[] nums) {
        if (nums == null || nums.length == 0 || k <= 0 || k > nums.length) {
            return 0;
        }

        return select(nums, 0, nums.length - 1, nums.length - k);

    }

    public int select(int[] nums, int left, int right, int k) {
        if (left == right) {
            return nums[left];
        }

        int pivotIndex = partition(nums, left, right);
        if (pivotIndex == k) {
            return nums[pivotIndex];
        } else if (pivotIndex < k) {
            return select(nums, pivotIndex + 1, right, k);
        }  else {
            return select(nums, left, pivotIndex - 1, k);
        }
    }

    public void swap(int[] nums, int x, int y) {
        int tmp = nums[x];
        nums[x] = nums[y];
        nums[y] = tmp;
    }

    public int partition(int[] nums, int left, int right) {

        Random rand = new Random();
        int pivotIndex = rand.nextInt((right - left) + 1) + left;
        // Init pivot
        int pivotValue = nums[pivotIndex];

        swap(nums, pivotIndex, right);

        // Next slot for an element smaller than pivotValue
        int firstIndex = left;

        for (int i = left; i <= right - 1; i++) {
            if (nums[i] < pivotValue) {
                swap(nums, firstIndex, i);
                firstIndex++;
            }
        }

        // Recover pivot to array
        swap(nums, right, firstIndex);
        return firstIndex;
    }

    public static void main(String[] args) {
        System.out.println("kth Largest Element: Quick Select");
        int[] A = {21, 3, 34, 5, 13, 8, 2, 55, 1, 19};
        Solution search = new Solution();
        int[] expResult = {1, 2, 3, 5, 8, 13, 19, 21, 34, 55};
        // expResult is ascending, so its first entry is the kth largest for k = n
        int k = expResult.length;
        int err = 0;
        for (int exp : expResult) {
            if (exp != search.kthLargestElement(k, A)) {
                System.out.println("Test failed: k = " + k);
                err++;
            }
            k--;
        }
        System.out.println("Test finished, failures: " + err);
    }
}

LeetCode Quick Select

import java.util.Random;
class Solution {
  int [] nums;

  public void swap(int a, int b) {
    int tmp = this.nums[a];
    this.nums[a] = this.nums[b];
    this.nums[b] = tmp;
  }


  public int partition(int left, int right, int pivot_index) {
    int pivot = this.nums[pivot_index];
    // 1. move pivot to end
    swap(pivot_index, right);
    int store_index = left;

    // 2. move all smaller elements to the left
    for (int i = left; i <= right; i++) {
      if (this.nums[i] < pivot) {
        swap(store_index, i);
        store_index++;
      }
    }

    // 3. move pivot to its final place
    swap(store_index, right);

    return store_index;
  }

  public int quickselect(int left, int right, int k_smallest) {
    /*
    Returns the k-th smallest element of list within left..right.
    */

    if (left == right) // If the list contains only one element,
      return this.nums[left];  // return that element

    // select a random pivot_index in [left, right]
    Random random_num = new Random();
    int pivot_index = left + random_num.nextInt(right - left + 1);

    pivot_index = partition(left, right, pivot_index);

    // the pivot is now at its final sorted position
    if (k_smallest == pivot_index)
      return this.nums[k_smallest];
    // go left side
    else if (k_smallest < pivot_index)
      return quickselect(left, pivot_index - 1, k_smallest);
    // go right side
    return quickselect(pivot_index + 1, right, k_smallest);
  }

  public int findKthLargest(int[] nums, int k) {
    this.nums = nums;
    int size = nums.length;
    // kth largest is (N - k)th smallest
    return quickselect(0, size - 1, size - k);
  }
}
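Whichever variant is used, a randomized cross-check against the sorting solution catches off-by-one and pivot bugs quickly. The sketch below pairs an iterative quick select (random pivot, Lomuto partition, no recursion, so O(1) extra space) with `Arrays.sort` as the oracle; the class name, trial count, and value range are arbitrary choices. Note that Lomuto partitioning degrades toward O(n^2) when most values are equal; the small value range here deliberately produces duplicate-heavy inputs to test correctness, not speed.

```java
import java.util.Arrays;
import java.util.Random;

class QuickSelectCheck {
    private static final Random RAND = new Random();

    // Iterative quick select with a random pivot; no recursion, so O(1) extra space.
    static int findKthLargest(int[] nums, int k) {
        int target = nums.length - k; // kth largest == element at sorted index n - k
        int lo = 0, hi = nums.length - 1;
        while (lo < hi) {
            int p = partition(nums, lo, hi);
            if (p == target) {
                break;
            } else if (p < target) {
                lo = p + 1;
            } else {
                hi = p - 1;
            }
        }
        return nums[target];
    }

    // Lomuto partition: random pivot parked at hi, smaller elements swept to the front.
    static int partition(int[] a, int lo, int hi) {
        swap(a, lo + RAND.nextInt(hi - lo + 1), hi);
        int pivot = a[hi];
        int store = lo;
        for (int i = lo; i < hi; i++) {
            if (a[i] < pivot) {
                swap(a, store++, i);
            }
        }
        swap(a, store, hi); // pivot lands at its final sorted index
        return store;
    }

    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    public static void main(String[] args) {
        Random gen = new Random(42); // fixed seed so the generated inputs are reproducible
        for (int trial = 0; trial < 1000; trial++) {
            int n = 1 + gen.nextInt(30);
            int[] nums = new int[n];
            for (int i = 0; i < n; i++) {
                nums[i] = gen.nextInt(10) - 5; // small range: negatives and many duplicates
            }
            int k = 1 + gen.nextInt(n);
            int[] sorted = nums.clone();
            Arrays.sort(sorted);
            int expected = sorted[n - k];
            int actual = findKthLargest(nums.clone(), k);
            if (actual != expected) {
                throw new AssertionError("nums=" + Arrays.toString(nums) + " k=" + k
                        + " expected " + expected + " got " + actual);
            }
        }
        System.out.println("all trials passed");
    }
}
```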

Reference

LeetCode:

Source: Wikipedia, Quickselect

