Binary Search Template for Coding Interviews

October 20, 2020

Binary search is something we all learn in CS101 classes. It is a very simple concept; however, implementing it in 15-20 minutes in a high-pressure interview and making sure all the bounds checks are satisfied can be tricky. I’ve been using this binary search template to solve a lot of problems:

def condition(e) -> bool:
  '''
  Predicate that splits the search space into 2 halves.

  `e` is the candidate index (or value) probed by the binary search.
  Must return False for every candidate in the left half and True for
  every candidate in the right half; the search loop moves its upper
  bound down whenever this returns True, so it converges on the first
  candidate for which the predicate holds.

  This is a template placeholder: replace the body per problem.
  '''
  pass

def binary_search(array):
  '''
  Template search over the half-open index range [0, len(array)).

  Relies on `condition(index)` being False on a prefix of the indices
  and True on the remaining suffix. Returns the first index for which
  `condition` is True, or len(array) when it is never True.
  '''
  left = 0
  right = len(array)
  while right > left:
    middle = (left + right) // 2
    if not condition(middle):
      # Everything up to and including `middle` is in the False half.
      left = middle + 1
      continue
    # `middle` may be the answer, so keep it as the exclusive upper bound.
    right = middle
  return left

The key is finding a function that splits the array into two halves. Here are some examples of the template applied to LeetCode problems. I’ve intentionally created an explicit condition method, which might be overkill for some examples, just to demonstrate the template.

Traditional Binary Search

def search(self, nums, target):
    """Return the index of ``target`` in the sorted list ``nums``, or -1.

    Wrapped in a function so the snippet is valid Python (a bare
    ``return`` at module level is a syntax error) and so ``nums`` and
    ``target`` are explicit inputs. ``self`` is unused; it is kept so the
    signature matches the LeetCode-style methods used elsewhere.

    When ``target`` occurs more than once, the index of its first
    occurrence is returned.
    """
    def condition(index):
        # False for elements smaller than target, True from the first
        # element that is >= target onwards.
        return nums[index] >= target

    lo, hi = 0, len(nums)

    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1
    # lo is the first index with nums[lo] >= target (or len(nums) if none),
    # so the target is present only if that element equals it.
    return lo if lo < len(nums) and nums[lo] == target else -1

LC 69: Square root of number

def mySqrt(self, x: int) -> int:
    """Return the integer square root of ``x``, rounded down.

    Binary-searches the half-open range [1, x + 1) for the first value
    whose square exceeds ``x``; the answer is the value just before it.
    """
    if x == 1:
        return 1

    def overshoots(candidate: int) -> bool:
        # True once the candidate's square is strictly larger than x.
        return candidate * candidate > x

    low = 1
    high = x + 1
    while low < high:
        middle = (low + high) // 2
        if overshoots(middle):
            high = middle
        else:
            low = middle + 1
    # `low` is the first overshooting value, so step back by one.
    return low - 1

LC 374: Guess number higher or lower

def guessNumber(self, n: int) -> int:
    """Find the picked number in 1..n using the external ``guess`` API.

    ``guess(num)`` is provided by the judge and is not defined here; this
    code only relies on it returning a value <= 0 once ``num`` is at or
    above the picked number and a positive value while ``num`` is below it.

    The search starts at 1 rather than 0 so that ``guess`` is never
    called with a value outside the valid range 1..n.
    """
    lo, hi = 1, n + 1

    def condition(num: int) -> bool:
        # True from the picked number upwards.
        return guess(num) <= 0

    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo

LC 153: Find minimum in rotated sorted array

def findMin(self, nums: List[int]) -> int:
    """Return the minimum of a rotated sorted array.

    Searches for the first element that is smaller than ``nums[0]``.
    If no such element exists, ``nums[0]`` itself is returned.
    """
    left, right = 0, len(nums)
    while left < right:
        middle = (left + right) // 2
        if nums[middle] < nums[0]:
            # `middle` already belongs to the "smaller than first" part.
            right = middle
        else:
            left = middle + 1

    if left >= len(nums):
        # No element is smaller than the first one.
        return nums[0]
    return nums[left]

LC 658: Find k closest Elements

def findClosestElements(self, arr: List[int], k: int, x: int) -> List[int]:
    """Return the ``k`` elements of sorted ``arr`` closest to ``x``, in order.

    Ties in distance are resolved in favour of the smaller element.

    Steps:
      1. Binary-search for the first index whose element is >= ``x``.
      2. Grow a window outwards from that boundary, always taking the
         closer of the two neighbouring elements, until it holds ``k``
         elements.

    The window is contiguous in the sorted input, so it is returned as a
    slice and needs no sorting. If ``k`` exceeds ``len(arr)`` the whole
    array is returned; if ``k`` <= 0 the result is empty.
    """
    def condition(mid: int) -> bool:
        return arr[mid] >= x

    lo, hi = 0, len(arr)
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1

    # lo may equal len(arr) when x is larger than every element, so arr[lo]
    # is never read without a bounds check below.
    # `left` and `right` are exclusive bounds of the chosen window:
    # arr[left] < x <= arr[right] whenever those indices are in range.
    left, right = lo - 1, lo
    wanted = min(k, len(arr))  # cannot pick more elements than exist
    while right - left - 1 < wanted:
        if left < 0:
            right += 1
        elif right >= len(arr):
            left -= 1
        elif x - arr[left] <= arr[right] - x:
            # Left neighbour is closer, or equally close and smaller.
            left -= 1
        else:
            right += 1
    return arr[left + 1:right]