Save my data

백준 2042 : 구간 합 구하기 본문

알고리즘/백준

백준 2042 : 구간 합 구하기

양을 좋아하는 문씨 2024. 4. 22. 23:24

 

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net


문제만 대충 보면 어떻게든 구현할 수는 있는데 시간제한이 있고 해서 특정한 성능이 나오는 알고리즘을 적용해야 하는데,

알고리즘 분류 탭에도 설명되어 있듯 세그먼트 트리를 활용하라고 되어있다.

 

사실 이 문제는 예전에 알고리즘 특강을 들을 때 유튜브나 인강에서 대충 다루고 넘어갔던 기억이 있다.

위낙 DP, 그래프 위주로 출제가 되니까 그쪽으로 신경을 많이 썼고 이건 대충 아 이렇게 구현하는 구나 정도만 해보고 넘어갔던 기억이 있어서 다시 풀려니까 기억이 잘 안났다.

 

예전 코드를 보면서 기억을 더듬는데 주석도 안 달려있고, 블로그에도 기록이 안 되어 있어서 이 참에 다시 풀어보면서 문서화 할 겸 기록을 남겨둔다.

 

주석을 달 때 참고하였음 :

https://yoongrammer.tistory.com/103

 

세그먼트 트리(Segment Tree) 개념 및 구현

목차 세그먼트 트리(Segment Tree) 세그먼트 트리(Segment Tree)는 배열 간격에 대한 정보를 이진 트리에 저장하는 자료구조입니다. 다음 예를 보겠습니다. A = {1, 2, 3, 4, 5 … ,N} 라는 배열에 아래 연산을

yoongrammer.tistory.com


참고 코드(예전 풀이) :

import sys
N, M, K = map(int, sys.stdin.readline().split())

# 이진트리 빌드 함수
def build(n):
    """
    2의 p승 -> m
    이진 트리이므로 2의 N제곱 형태로 나타내어진다.
    구간별로 모든 수를 담으려면 이런 구조여야 한다.
    """
    p = 0
    m = 0
    
    """
    2의 N제곱 형태를 가지면서 전체 구간의 길이를 모두 담을 수 있는 리스트 중 최소 길이를 찾기 위한 반복문이다.
    """
    while n > m:
        m = 2 ** p
        p += 1
   	# 만들어진 이진 트리를 (1차원으로)표현하기 위한 리스트를 초기화한다.
    tree = [0] * (m * 2)
    
    """
    이진 트리의 뒷부분부터 채워나간다.
    뒷부분을 모두 채운 다음 앞에는 뒷부분 값(하위 노드)들의 합을 담은 값으로 채워나간다.
    """
    for i in range(N):
        tree[i + m] = int(sys.stdin.readline())
    for i in range(m - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[(2 * i) + 1]
    # 완성된 이진 트리 반환
    return tree, m
tree, leaf = build(N)

# 값이 변하는 경우 그것을 적용하는 함수
def change(tree, b, c, leaf):
    # 아래 노드부터 변화량을 적용해주기 위하여 leaf라는 값을 더한 것임
    idx = leaf + b - 1
    # 변화량이 원본보다 크다면 양수, 적다면 음수
    delta = c - tree[idx] if tree[idx] < c else -(tree[idx] - c)
    # 노드를 타고 올라가면서 값들을 갱신한다.
    while idx:
        tree[idx] += delta
        idx //= 2
    return tree
    
# 구간합을 구하는 함수
def interval_sum(tree, b, c, leaf):
    # 구간합을 담을 리스트
    total = []
    start = leaf + b - 1
    end = leaf + c - 1
    
    while start < end:
    	"""
        1. 시작 인덱스가 왼쪽 노드를 가리키는 해당하는 경우(짝수 인덱스인 경우) :
        -> 시작 인덱스 + 1 한 것에 2로 나눈 몫이 자신의 부모 인덱스이다.
        -> 왼쪽과 오른쪽을 합한 값이다.
        -> 별도의 처리를 할 필요 없이 부모 노드의 값을 가지고 다시 계산하면 된다.
        
        2. 시작 인덱스가 오른쪽 노드를 가리키는 경우(홀수 인덱스인 경우) :
        -> 부모 인덱스의 양 노드를 합한 값에서 왼쪽 노드의 값을 빼는 작업이 필요하다.
        -> total에 오른쪽 값만 추가해주고 옆의 부모 노드로 넘어간다.
        -> 왜냐하면 자신의 부모 노드는 내 왼쪽에 있는 노드의 값이 합해진 경우이기 때문에,
        그 값은 쓸 수 없고, 자기 자신만의 값을 가진 채 오른쪽의 부모 노드로 넘어가야 한다.
        
        3. 끝 인덱스가 왼쪽 노드를 가리키는 경우(짝수 인덱스인 경우) :
        -> 왼쪽 값과 오른쪽 값이 모두 포함되지 않고 왼쪽 값만 포함되는 경우를 의미한다.
        -> 즉 시작 인덱스를 계산하는 경우의 반대로 계산하면 된다.
        -> 마찬가지로 부모 인덱스의 값은 쓸 수 없으므로 자기 자신의 값만 총합 리스트에 넣고,
        자기 위치의 왼쪽 부모 노드로 이동한다.
        
        4. 반복문이 모두 끝났을 때 시작 인덱스와 끝 인덱스가 같다는 것은
        시작 노드와 끝 노드가 하나의 부모 노드로 모였다는 것을 의미한다.
        그러므로 노드는 더 이상 위로 올라갈 수 없고,
        현재 위치한 노드의 값이 곧 다른 곳에서 올라온 노드를 제외한 모든 하위 노드들의 합이므로 그 값도 총합에 더해야 한다.
        
        5. 모든 계산(시작 인덱스가 끝 인덱스보다 커지는 경우)이 끝나면 총합을 리턴한다.
        """
        if start % 2:
            total.append(tree[start])
        start = (start + 1) // 2
        if not end % 2:
            total.append(tree[end])
        end = (end - 1) // 2
    if start == end:
        total.append(tree[start])
    return sum(total)
for _ in range(M + K):
    a, b, c = map(int, sys.stdin.readline().split())
    if a == 1:
        tree = change(tree, b, c, leaf)
    else:
        answer = interval_sum(tree, b, c, leaf)
        print(answer)

 

원본에는 주석이 달려있지 않아서 하나하나 일일이 작성하였다.

그리고 트리를 초기화 할 때 앞에서부터 채워넣는 방식이랑 뒤에서부터 채우는 방식 두 가지가 있던걸로 기억하는데 정확하지는 않고, 아무튼 나는 몇 가지 방식중에 하나를 선택해서 만들었던 기억이 있다.

 

사실 전체 흐름을 그림으로 그려놓고 보면 직관적인데, 사실 내가 위에 코딩해놓은 것을 그림으로 바로 그리려니까 순서가 좀 안맞는 부분이 있어서 당황했다. 코드를 따라서 그려야만 올바르게 나왔다.

아마 코딩을 먼저 하고 그 다음 설계에 대한 분석을 해서 그랬던 것 같다.

 

그림으로 먼저 그려놓고 그냥 코드로만 옮겼으면 더 리뷰가 쉬웠을 것 같다는 생각을 했다.

Comments