class Solution:
    """Island counting over a grid of "1" (land) cells."""

    def numIslands(self, grid: List[List[str]]) -> int:
        """Return the number of 4-directionally connected groups of "1" cells.

        Every land cell starts as its own island. Each time it is unioned with
        an already-visited land neighbour (above or to the left) and that
        union merges two distinct sets, the running count drops by one.

        Args:
            grid: Grid of strings, indexed as grid[row][col]; a cell equal to
                "1" is land, anything else is water.

        Returns:
            The number of islands; 0 when the grid has no rows or no columns.
        """
        # An empty grid has no islands, and grid[0] would raise IndexError.
        if not grid or not grid[0]:
            return 0

        m, n = len(grid), len(grid[0])
        disjoint_sets = DisjointSets(n * m)
        island_count = 0

        def get_idx(row: int, col: int) -> int:
            # Row-major flattening of (row, col) into a single set index.
            return row * n + col

        def is_island(row: int, col: int) -> bool:
            # Out-of-bounds positions are treated as water.
            return 0 <= row < m and 0 <= col < n and grid[row][col] == "1"

        for row in range(m):
            for col in range(n):
                if not is_island(row, col):
                    continue
                island_count += 1
                # Only the up and left neighbours are checked: they were
                # processed earlier in row-major order, so every adjacent pair
                # of land cells is unioned exactly once.
                for adj_row, adj_col in ((row - 1, col), (row, col - 1)):
                    if is_island(adj_row, adj_col):
                        # NOTE(review): assumes DisjointSets.union returns 1/True
                        # when it merged two distinct sets and 0/False when
                        # both cells were already connected — TODO confirm.
                        island_count -= disjoint_sets.union(
                            get_idx(row, col), get_idx(adj_row, adj_col)
                        )
        return island_count