Binary search is something we all learn in CS101 classes. It is a very simple concept; however, implementing it in 15-20 minutes in a high-pressure interview and making sure all the bounds checks are satisfied can be tricky. I’ve been using this binary search template to solve a lot of problems:
def condition(e) -> bool:
    """Predicate that drives the binary search template.

    The predicate must split the search space into two contiguous halves:
    it returns False for every candidate in the first half and True for
    every candidate in the second half (either half may be empty).

    This is a template placeholder; replace the body with the
    problem-specific check.

    Raises:
        NotImplementedError: always, until a real predicate is supplied.
    """
    raise NotImplementedError("replace with a problem-specific predicate")
def binary_search(array):
    """Return the first index of ``array`` for which ``condition`` is true.

    ``condition`` is called with candidate *indices* (not elements) and
    must be False for a prefix of the index range and True for the rest.

    Args:
        array: The sequence whose index range ``[0, len(array))`` is searched.

    Returns:
        The smallest index where ``condition`` holds, or ``len(array)`` when
        it holds nowhere (including when ``array`` is empty).
    """
    # Half-open search interval [lo, hi); hi starts one past the last index
    # so that "no index satisfies the condition" is representable.
    lo, hi = 0, len(array)
    while lo < hi:
        # Written as lo + (hi - lo) // 2 to mirror the overflow-safe form
        # used in fixed-width-integer languages.
        mid = lo + (hi - lo) // 2
        if condition(mid):
            # mid may be the answer, so keep it inside the interval.
            hi = mid
        else:
            # mid is definitely not the answer; discard it.
            lo = mid + 1
    return lo
The key is finding a function that splits the array into 2 halves. Here are some examples of the template applied to LeetCode problems. I’ve intentionally created an explicit condition
method, which might be overkill for some examples, just to demonstrate the template.
Traditional Binary Search
def search(self, nums, target):
    """Return the index of ``target`` in sorted ``nums``, or -1 if absent.

    Args:
        nums: Sequence sorted in ascending order.
        target: Value to look for.

    Returns:
        The index of the first element equal to ``target``, or -1 when
        ``target`` does not occur in ``nums`` (including when ``nums`` is
        empty).
    """
    def condition(index):
        # False for every element smaller than target, True from the first
        # element that is >= target onwards.
        return nums[index] >= target

    lo, hi = 0, len(nums)
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1

    # lo is the first index with nums[lo] >= target; it only counts as a hit
    # if it is in range and the element is exactly the target.
    return lo if lo < len(nums) and nums[lo] == target else -1
LC 69: Square root of number
def mySqrt(self, x: int) -> int:
    """Return the integer square root of ``x`` (the floor of its square root).

    Args:
        x: Non-negative integer.

    Returns:
        The largest integer ``r`` such that ``r * r <= x``.

    Raises:
        ValueError: If ``x`` is negative.
    """
    if x < 0:
        raise ValueError("mySqrt() is not defined for negative values")

    def condition(mid: int) -> bool:
        # False while mid is still a valid root candidate, True once its
        # square overshoots x.
        return mid * mid > x

    # Search [1, x + 1) for the first value whose square exceeds x. For
    # x == 0 the interval is empty and lo - 1 == 0 is returned directly.
    lo, hi = 1, x + 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1

    # lo is the first overshooting value, so the root is the one before it.
    return lo - 1
LC 374: Guess number higher or lower
def guessNumber(self, n: int) -> int:
    """Find the picked number in ``[1, n]`` using the external ``guess`` API.

    NOTE(review): assumes the LeetCode 374 contract for ``guess(num)`` --
    a result <= 0 means ``num`` is at or above the picked number, and a
    positive result means ``num`` is below it. ``guess`` is not defined in
    this snippet; confirm against the judge environment.

    Args:
        n: Upper bound (inclusive) of the range the number was picked from.

    Returns:
        The smallest ``num`` in ``[1, n]`` for which ``guess(num) <= 0``.
    """
    def condition(num: int) -> bool:
        return guess(num) <= 0

    # Search the half-open range [1, n + 1). Starting at 1 rather than 0
    # guarantees guess() is never called with a value outside [1, n].
    lo, hi = 1, n + 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1
    return lo
LC 153: Find minimum in rotated sorted array
def findMin(self, nums: List[int]) -> int:
    """Return the minimum of an ascending array that has been rotated.

    The search relies on ``nums[i] < nums[0]`` being False for a prefix of
    indices and True for the rest, which holds for a rotated ascending
    array of distinct values.

    Args:
        nums: Non-empty rotated sorted list.

    Returns:
        The smallest element of ``nums``.

    Raises:
        IndexError: If ``nums`` is empty.
    """
    def condition(mid: int) -> bool:
        # True exactly for the elements that were rotated to the back,
        # i.e. those smaller than the first element.
        return nums[mid] < nums[0]

    lo, hi = 0, len(nums)
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1

    # lo == len(nums) means no element is smaller than nums[0]: the array
    # is not rotated, so the first element is the minimum.
    return nums[lo] if lo < len(nums) else nums[0]
LC 658: Find k closest Elements
def findClosestElements(self, arr: List[int], k: int, x: int) -> List[int]:
    """Return the ``k`` elements of sorted ``arr`` that are closest to ``x``.

    Closeness is measured by absolute difference; on a tie the smaller
    element is preferred.

    Args:
        arr: List sorted in ascending order.
        k: Number of elements to return.
        x: Value to measure distance from.

    Returns:
        A new list, in ascending order, with the ``k`` closest elements.
        If ``k`` exceeds ``len(arr)`` the whole array is returned; if ``k``
        is not positive the result is empty.
    """
    def condition(mid: int) -> bool:
        return arr[mid] >= x

    # Template search: lo ends at the first index whose element is >= x,
    # or len(arr) when x is larger than every element.
    lo, hi = 0, len(arr)
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid
        else:
            lo = mid + 1

    # Grow an open window (left, right) around the split point. Everything
    # at index <= left is < x and everything at index >= right is >= x, so
    # the distances are x - arr[left] and arr[right] - x (no abs needed).
    want = min(k, len(arr))
    left, right = lo - 1, lo
    while right - left - 1 < want:
        # Take from the left when the right side is exhausted, or when the
        # left candidate is at least as close (ties go to the smaller one).
        take_left = right >= len(arr) or (
            left >= 0 and x - arr[left] <= arr[right] - x
        )
        if take_left:
            left -= 1
        else:
            right += 1

    # The window is contiguous in a sorted array, so the slice is sorted.
    return arr[left + 1:right]