PS/C++

[백준 9463] 순열 그래프

siyamaki 2024. 6. 20. 13:30

https://www.acmicpc.net/problem/9463

 

https://www.acmicpc.net/problem/7578 공장 문제와 비슷하다

 

Counting Inversions를 이용해 연결된 갯수를 세는 문제이다

 

기본적으로 counting inversions을 수행하면 시간복잡도가 O(N * N)이 나와서 counting 누적합을 찾기 위해 세그먼트 트리를 이용해야 한다

 

수의 범위가 1 ~ N이고 수의 갯수가 N이니 기본 배열 arr과 N의 인덱스를 기억할 compare배열을 만들어

첫 줄에 입력받는 수는 arr에 그대로 입력받고

두 번째 줄에 입력받는 수는 comp[k] =  i번째 값을 저장해준다

예제 1번으로 보면 아래와 같은 모양이 된다

i 1 2 3 4 5
arr 2 5 4 1 3
comp 1 4 3 5 2

 

예제 1번(2 5 4 1 3) 기준

 

i → 1

comp[arr[1]] = 4 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 0

arr[1] = 2, comp[arr[1]] = 4이므로  4번째 위치의 값을 1 늘린다 →  sum = [0 0 0 1 0]

 

i → 2

comp[arr[2]] = 2 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 1

arr[2] = 5, comp[arr[2]] = 2이므로  2번째 위치의 값을 1 늘린다 → sum = [0 1 0 1 0]

 

i → 3

comp[arr[3]] = 5 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 0

arr[3] = 5, comp[arr[3]] = 5이므로 5번째 위치의 값을 1 늘린다 → sum = [0 1 0 1 1]

 

i → 4

comp[arr[4]] = 1 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 3

arr[4] = 1, comp[arr[4]] = 1이므로 1번째 위치의 값을 1 늘린다 → sum[1 1 0 1 1]

 

i → 5

comp[arr[5]] = 3 ~ N 까지 몇 개가 연결되었는지를 찾는다 → 2

arr[5] = 3, comp[arr[5]] = 3이므로 3번째 위치의 값을 1 늘린다  → sum[1 1 1 1 1]

 

위에서 찾는다의 과정이 u ~ v 까지 누적합을 찾는 과정이므로 누적합 세그먼트 트리를 이용하면 된다.

#include "iostream"
#include <cstring>
#pragma GCC optimize ("O3")
#pragma GCC optimize ("Ofast")
#pragma GCC optimize ("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2")
#define ll long long
using namespace std;
int N;
int arr[100001], segTree[262144], comp[100001];
int find(int start, int end, int idx, int left, int right);
void update(int start, int end, int idx, int target);
void init();
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr); cout.tie(nullptr);
    int t;
    cin >> t;
    while(t-->0) {
        cin >> N;
        init();
        for(int i = 1; i <= N; i++) {
            cin >> arr[i];
        }
        for(int i = 1; i <= N; i++) {
            int k; cin >> k;
            comp[k] = i;
        }
        ll res = 0;
        for(int i = 1; i <= N; i++) {
            res += find(1, N, 1, comp[arr[i]], N);
            update(1, N, 1, comp[arr[i]]);
        }
        cout << res << '\n';
    }

}
void init() {
    memset(arr, 0, sizeof(arr));
    memset(comp, 0, sizeof(comp));
    memset(segTree, 0, sizeof(segTree));
}
int find(int start, int end, int idx, int left, int right) {
    if(right < start || end < left) {
        return 0;
    }
    if(left <= start && end <= right) {
        return segTree[idx];
    }
    int mid = (end + start) / 2;
    int l = find(start, mid, idx << 1, left, right);
    int r = find(mid + 1, end, idx << 1 | 1, left, right);
    return l + r;
}
void update(int start, int end, int idx, int target) {
    if(target < start || end < target) {
        return;
    }
    segTree[idx] += 1;
    if(start == end) {
        return;
    }
    int mid = (end + start) / 2;
    update(start, mid, idx << 1, target);
    update(mid + 1, end, idx << 1 | 1, target);
}