Sum of Top K Elements

The question is to write a data structure that keeps track of the sum of the top ‘K’ elements. Queries can add an element, remove an element, or return the max k sum of the current numbers (the sum of the k largest).

Example:

Initially, the numbers are [4 3 5 1 2] and let’s say k = 3.

Q1 -> maxKSum() = 5 + 4 + 3 = 12
Q2 -> add(7), current numbers = [4 3 5 1 2 7]
Q3 -> maxKSum() = 7 + 5 + 4 = 16
Q4 -> remove(5), current numbers = [4 3 1 2 7]
Q5 -> maxKSum() = 7 + 4 + 3 = 14
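A brute-force reference is handy for checking the example numbers, and later for testing a faster version. `max_k_sum_naive` is a hypothetical helper, not part of the original question; it sorts on every query, so each call costs O(n log n).

```python
def max_k_sum_naive(nums, k):
    # Sort descending and sum the first k values (fewer if len(nums) < k)
    return sum(sorted(nums, reverse=True)[:k])

nums = [4, 3, 5, 1, 2]
k = 3
print(max_k_sum_naive(nums, k))  # 12
nums.append(7)
print(max_k_sum_naive(nums, k))  # 16
nums.remove(5)
print(max_k_sum_naive(nums, k))  # 14
```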

Is there any way to implement the maxKSum() function such that it takes O(1) time?

Yes. maxKSum() can run in O(1) if the sum of the top k elements is maintained incrementally, using a min-heap and a few tracking variables; the extra work moves into add and remove. Here’s a brief outline of the approach:

  1. Min-Heap for K Largest Elements: Use a min-heap to hold the top k elements. Because it’s a min-heap, the smallest of the top k elements sits at the root, so a new element can be compared against it in O(1) and swapped in with O(log k) work.

  2. Total Sum of K Largest Elements: Maintain a variable to keep track of the total sum of the top k elements. You’ll update this variable whenever you add or remove elements.

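  3. Multiset of All Elements: Keep a count of every current element (a multiset, since values can repeat) so that a removal can be validated and the largest element outside the top k can be found when one of the top k is removed.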

  4. Add Function: When adding an element, check whether it belongs among the top k. If the heap holds fewer than k elements, push it and add it to the sum. Otherwise, if it is larger than the heap’s root, replace the root with it and adjust the sum by the difference.

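  5. Remove Function: When removing an element, check whether it is among the top k (it will be in the min-heap). If it is, remove it from the heap, subtract it from the sum, and promote the largest remaining element outside the heap (if there is one) so the heap again holds k elements.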

  6. MaxKSum Function: Since you’re maintaining the sum of the top k elements, you can simply return this value in O(1) time.
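The add side of this outline (points 1, 2, 4 and 6) can be sketched on its own for an insert-only stream, with removals left out. `running_top_k_sums` is a hypothetical helper name; it uses only the standard `heapq` module.

```python
import heapq

def running_top_k_sums(stream, k):
    """Yield the sum of the k largest values seen so far after each insertion."""
    heap = []   # min-heap of at most k values; heap[0] is the smallest of the top k
    total = 0   # running sum of the values currently in the heap
    for x in stream:
        if len(heap) < k:
            # Fewer than k values so far: x is automatically in the top k
            heapq.heappush(heap, x)
            total += x
        elif k > 0 and x > heap[0]:
            # x evicts the current smallest of the top k: O(log k)
            total += x - heapq.heappushpop(heap, x)
        yield total  # reading the maintained sum is O(1)
```

For the example numbers followed by `add(7)`, `list(running_top_k_sums([4, 3, 5, 1, 2, 7], 3))` gives `[4, 7, 12, 12, 12, 16]`.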

Here’s the code:

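import heapq
from collections import Counter

class MaxKSum:
    def __init__(self, nums, k):
        self.k = k
        # Multiset of every current number (a Counter keeps duplicates)
        self.counts = Counter(nums)
        # Min-heap of the k largest numbers; the root is the smallest of them
        self.k_largest_heap = heapq.nlargest(k, nums)
        heapq.heapify(self.k_largest_heap)
        self.k_largest_sum = sum(self.k_largest_heap)

    def add(self, num):
        self.counts[num] += 1
        if len(self.k_largest_heap) < self.k:
            # Fewer than k numbers so far, so num joins the top k
            heapq.heappush(self.k_largest_heap, num)
            self.k_largest_sum += num
        elif self.k_largest_heap and num > self.k_largest_heap[0]:
            # num evicts the smallest of the current top k: O(log k)
            self.k_largest_sum += num - heapq.heappushpop(self.k_largest_heap, num)

    def remove(self, num):
        if self.counts[num] == 0:
            return  # num is not present
        self.counts[num] -= 1
        if num in self.k_largest_heap:
            self.k_largest_heap.remove(num)
            heapq.heapify(self.k_largest_heap)
            self.k_largest_sum -= num
            # Promote the largest number outside the top k, if any: O(n)
            rest = self.counts - Counter(self.k_largest_heap)
            if rest:
                replacement = max(rest)
                heapq.heappush(self.k_largest_heap, replacement)
                self.k_largest_sum += replacement

    def maxKSum(self):
        # The sum is maintained incrementally, so this is O(1)
        return self.k_largest_sum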

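This supports the add, remove, and maxKSum operations as described. maxKSum returns the stored sum in O(1) time; add costs O(log k), and remove costs O(n) in the worst case because it has to search the elements outside the heap for a replacement.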
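If the O(n) remove matters, one alternative (a different technique from the outline above: two heaps with lazy deletion) keeps the top k in a min-heap and every other value in a max-heap, and only discards removed values once they reach a heap root. The class name `MaxKSumLazy` and its helper methods are hypothetical; this is a sketch assuming hashable, comparable numbers, not a tuned implementation.

```python
import heapq
from collections import Counter


class MaxKSumLazy:
    """Sum of the k largest values with lazy-deletion heaps (hypothetical sketch).

    top  : min-heap of the k largest live values (root = smallest of them)
    rest : max-heap, stored negated, of every other live value
    Removed values are counted in pending_* and popped only when they
    reach a heap root, so add/remove are O(log n) amortized.
    """

    def __init__(self, nums, k):
        self.k = k
        self.counts = Counter()             # multiset of live values
        self.top, self.rest = [], []
        self.top_size = self.rest_size = 0  # live entries in each heap
        self.pending_top, self.pending_rest = Counter(), Counter()
        self.total = 0                      # sum of live values in top
        for x in nums:
            self.add(x)

    def _prune(self, heap, pending, sign):
        # Pop roots that were already removed; sign=-1 undoes the negation in rest
        while heap and pending[sign * heap[0]] > 0:
            pending[sign * heap[0]] -= 1
            heapq.heappop(heap)

    def _rebalance(self):
        # Move values between heaps until top holds min(k, n) live values
        while True:
            self._prune(self.top, self.pending_top, 1)
            self._prune(self.rest, self.pending_rest, -1)
            if self.top_size > self.k:
                x = heapq.heappop(self.top)
                self.top_size -= 1
                self.total -= x
                heapq.heappush(self.rest, -x)
                self.rest_size += 1
            elif self.top_size < self.k and self.rest_size > 0:
                x = -heapq.heappop(self.rest)
                self.rest_size -= 1
                heapq.heappush(self.top, x)
                self.top_size += 1
                self.total += x
            else:
                break

    def add(self, x):
        self.counts[x] += 1
        self._prune(self.top, self.pending_top, 1)
        if self.top_size > 0 and x > self.top[0]:
            heapq.heappush(self.top, x)
            self.top_size += 1
            self.total += x
        else:
            heapq.heappush(self.rest, -x)
            self.rest_size += 1
        self._rebalance()

    def remove(self, x):
        if self.counts[x] == 0:
            return                          # x is not currently present
        self.counts[x] -= 1
        self._prune(self.top, self.pending_top, 1)
        if self.top_size > 0 and x >= self.top[0]:
            # A value >= top's live root has a live copy in top
            self.pending_top[x] += 1
            self.top_size -= 1
            self.total -= x
        else:
            self.pending_rest[x] += 1
            self.rest_size -= 1
        self._rebalance()

    def maxKSum(self):
        return self.total                   # O(1)
```

With the example from the question, `MaxKSumLazy([4, 3, 5, 1, 2], 3)` reports 12, then 16 after `add(7)`, then 14 after `remove(5)`. The trade-off is memory: stale entries stay in a heap until they surface at its root.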