문제
수 N개 A1, A2, ..., AN이 주어진다. 이때, 연속된 부분 구간의 합이 M으로 나누어 떨어지는 구간의 개수를 구하는 프로그램을 작성하시오.
즉, Ai + ... + Aj (i ≤ j) 의 합이 M으로 나누어 떨어지는 (i, j) 쌍의 개수를 구해야 한다.
입력
첫째 줄에 N과 M이 주어진다. (1 ≤ N ≤ 106, 2 ≤ M ≤ 103)
둘째 줄에 N개의 수 A1, A2, ..., AN이 주어진다. (0 ≤ Ai ≤ 109)
출력
첫째 줄에 연속된 부분 구간의 합이 M으로 나누어 떨어지는 구간의 개수를 출력한다.
예제 입력 1
5 3
1 2 3 1 2
예제 출력 1
7
풀이
N의 최대 수가 106개이기 때문에 모든 구간 합을 매번 계산하면 제한 시간 안에 절대 통과할 수 없다.
따라서 이 문제를 풀기 위해서는 구간 합 알고리즘에 대한 이해가 필요하다.
예를 들어 arr[] = [1, 4, 5, 2, 7, 3, 4]라는 배열이 주어졌다고 해보자.
구간 합 알고리즘은 위의 배열을 합 배열로 바꿔주는 것부터 시작한다.
[1, 1+4, 1+4+5, 1+4+5+2 …]처럼 누적된 합으로 배열 값을 바꿔주면 s = [1, 5, 10, 12, 19, 22, 26]이 된다.
여기서 임의의 수 i부터 j까지의 구간 합을 공식으로 나타내면 아래와 같다.
만약 i가 2이고, j가 5라면 s[5] - s[1] = 22 - 10 = 12가 된다.
생각해보면 간단하다. s[5]는 arr[0]+arr[1]+arr[2]+…+arr[5]이고, s[1] = arr[0]+arr[1] 이기 때문에,
s[5] - s[1] = arr[2] + arr[3] + arr[4] + arr[5] 즉, 원래 구하려던 2부터 5까지의 합이 된다.
다시 문제로 돌아가서 그럼 나머지 합은 어떻게 구간 합 알고리즘을 적용할 수 있을까?
우리가 구해야 하는 것은 구간합이 M으로 나눈 나머지가 0이어야 한다.
위의 예시에서 설명한 구간 합 배열 s와 임의의 수 i, j라고 할 때 공식으로 나타내면 다음과 같다.
즉, 구간 합 배열을 나머지 배열로 바꾸고 그 나머지가 같은 것들로 조합을 하면 된다는 결론이 나온다.
s = [1, 5, 10, 12, 19, 22, 26], M = 3일 때 나머지 배열 mod[] = [1, 2, 1, 0, 1, 1, 2]이 된다.
나머지 배열에서 원소 값이 0이라는 뜻은 0부터 해당 값까지의 구간 합이 M으로 나누어진다는 뜻이기 때문에 0의 개수를 먼저 정답에 더해준다.
그리고 각각의 나머지 개수를 세서 2개씩 고를 수 있는 경우의 수를 확인하면 된다.
1은 4개니까 (4*3)/2 = 6, 2는 2개니까 (2*1)/2 = 1이고 따라서 정답은 1 + 6 + 1 = 8이다.
코드
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
public static void main(String[] args) throws IOException {
StringBuilder sb = new StringBuilder();
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
long[] numArr = new long[n+1];
st = new StringTokenizer(br.readLine());
for(int i=1; i<=n; i++) {
numArr[i] = numArr[i-1] + Long.parseLong(st.nextToken());
}
long sum = 0;
for(int i=1; i<=n; i++) {
numArr[i] %= m;
if(numArr[i]==0) sum++;
}
Arrays.sort(numArr);
long start = numArr[1];
int count = 1;
for(int i=2; i<=n; i++) {
if(numArr[i]==start) count++;
else {
if(count>=2) sum += (count*(count-1)/2);
count = 1;
start = numArr[i];
}
}
if(count>=2) sum += (count*(count-1)/2);
System.out.println(sum );
}
}
'Problem Solving' 카테고리의 다른 글
[Java] 백준 3079. 입국심사 (0) | 2023.04.27 |
---|---|
[Java] 백준 11003. 최솟값 찾기 (0) | 2023.04.05 |
[Java] 백준 1260. DFS와 BFS (0) | 2023.03.30 |
[Java] 프로그래머스. 전화번호 목록 (0) | 2023.03.16 |
[Java] 프로그래머스. 완주하지 못한 선수 (0) | 2023.03.16 |