[백준 9463] 순열 그래프
https://www.acmicpc.net/problem/9463
https://www.acmicpc.net/problem/7578 공장 문제와 비슷하다
Counting Inversions를 이용해 연결된 갯수를 세는 문제이다
기본적으로 counting inversions을 수행하면 시간복잡도가 O(N * N)이 나와서 counting 누적합을 찾기 위해 세그먼트 트리를 이용해야 한다
수의 범위가 1 ~ N이고 수의 갯수가 N이니 기본 배열 arr과 N의 인덱스를 기억할 compare배열을 만들어
첫 줄에 입력받는 수는 arr에 그대로 입력받고
두 번째 줄에 입력받는 수는 comp[k] = i번째 값을 저장해준다
예제 1번으로 보면 아래와 같은 모양이 된다
i | 1 | 2 | 3 | 4 | 5 |
arr | 2 | 5 | 4 | 1 | 3 |
comp | 1 | 4 | 3 | 5 | 2 |
예제 1번(2 5 4 1 3) 기준
i → 1
comp[arr[1]] = 4 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 0
arr[1] = 2, comp[arr[1]] = 4이므로 4번째 위치의 값을 1 늘린다 → sum = [0 0 0 1 0]
i → 2
comp[arr[2]] = 2 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 1
arr[2] = 5, comp[arr[2]] = 2이므로 2번째 위치의 값을 1 늘린다 → sum = [0 1 0 1 0]
i → 3
comp[arr[3]] = 5 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 0
arr[3] = 5, comp[arr[3]] = 5이므로 5번째 위치의 값을 1 늘린다 → sum = [0 1 0 1 1]
i → 4
comp[arr[4]] = 1 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 3
arr[4] = 1, comp[arr[4]] = 1이므로 1번째 위치의 값을 1 늘린다 → sum[1 1 0 1 1]
i → 5
comp[arr[5]] = 3 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 2
arr[5] = 3, comp[arr[5]] = 3이므로 3번째 위치의 값을 1 늘린다 → sum[1 1 1 1 1]
위에서 찾는다의 과정이 u ~ v 까지 누적합을 찾는 과정이므로 누적합 세그먼트 트리를 이용하면 된다.
#include "iostream"
#include <cstring>
#pragma GCC optimize ("O3")
#pragma GCC optimize ("Ofast")
#pragma GCC optimize ("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
#define ll long long
using namespace std;
int N;
int arr[100001], segTree[262144], comp[100001];
int find(int start, int end, int idx, int left, int right);
void update(int start, int end, int idx, int target);
void init();
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
int t;
cin >> t;
while(t-->0) {
cin >> N;
init();
for(int i = 1; i <= N; i++) {
cin >> arr[i];
}
for(int i = 1; i <= N; i++) {
int k; cin >> k;
comp[k] = i;
}
ll res = 0;
for(int i = 1; i <= N; i++) {
res += find(1, N, 1, comp[arr[i]], N);
update(1, N, 1, comp[arr[i]]);
}
cout << res << '\n';
}
}
void init() {
memset(arr, 0, sizeof(arr));
memset(comp, 0, sizeof(comp));
memset(segTree, 0, sizeof(segTree));
}
int find(int start, int end, int idx, int left, int right) {
if(right < start || end < left) {
return 0;
}
if(left <= start && end <= right) {
return segTree[idx];
}
int mid = (end + start) / 2;
int l = find(start, mid, idx << 1, left, right);
int r = find(mid + 1, end, idx << 1 | 1, left, right);
return l + r;
}
void update(int start, int end, int idx, int target) {
if(target < start || end < target) {
return;
}
segTree[idx] += 1;
if(start == end) {
return;
}
int mid = (end + start) / 2;
update(start, mid, idx << 1, target);
update(mid + 1, end, idx << 1 | 1, target);
}