PS/C++

[백준 17410] 수열과 쿼리 1.5

siyamaki 2024. 6. 20. 15:01

https://www.acmicpc.net/problem/17410

 

사용 알고리즘 : 제곱근 분할법

구간을 특정 값으로 쪼개어 나눠 저장하고(bucket) k보다 큰 원소의 개수는 upper_bound로 찾았다.

 

처음에 구간 내에서 값을 update 할 때 해당 bucket에서 value 위치를 upper bound로 찾아 값을 대체 후 sort하는 식으로 풀었는데 계속 시간 초과가 발생하였다.

 

https://www.acmicpc.net/board/view/143801 의 답변을 통해 그냥 naive하게 바꾸려는 값의 위치를 찾아 erase해주고

새로 들어와야 하는 위치를 찾아 insert 해주면 된다.

 

find 할 때에는 l ~  [bucket n][bucket m][bucket o] ~ r 식으로 왼쪽 구간 ~ 다음 bucket 이전까지 반복문으로 구하고 마지막 bucket ~ r 구간까지 반복문으로 구한 후 bucket들 끼리는 upper_bound를 이용해 큰 원소의 개수를 구한다.

 

보통 제곱근 분할을 할 땐 √N을 이용하는데 이걸 써도 시간초과가 났길래 버킷의 크기를 계속 조정해서 맞는 개수를 찾았다

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

int arr[100001];
vector<int> bucket[84];
int N, M, sqrtN;
void update(int target, int value);
int find(int left, int right, int value);
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr); cout.tie(nullptr);

    cin >> N;
    sqrtN = 1200;
    for(int i = 1; i <= N; i++) {
        cin >> arr[i];
    }
    cin >> M;
    for(int i = 1; i <= N; i++) {
        bucket[i / sqrtN].push_back(arr[i]);
    }
    for(auto & i : bucket) {
        sort(i.begin(), i.end());
    }
    for(int i = 0; i < M; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        if(a == 1)  {
            update(b, c);
        } else {
            int d; cin >> d;
            int res = find(b, c, d);
            cout << res << '\n';
        }
    }
}
void update(int target, int value) {
    int idx = target / sqrtN;   // 버킷 위치
    // 값 변경
    for(int i = 0; i < bucket[idx].size(); i++) {
        if(bucket[idx][i] == arr[target]) {
            bucket[idx].erase(bucket[idx].begin() + i);
            break;
        }
    }
    for(int i = 0; i <= bucket[idx].size(); i++) {
        if(i == bucket[idx].size() || bucket[idx][i] >= value) {
            bucket[idx].insert(bucket[idx].begin() + i, value);
            break;
        }
    }
    arr[target] = value;
}
int find(int left, int right, int value) {
    int res = 0;

    while(left % sqrtN != 0 && left <= right) { // 왼쪽 구간
        res += arr[left] > value ? 1 : 0;
        left++;
    }
    while((right + 1) % sqrtN != 0 && left <= right) { // 오른쪽 구간
        res += arr[right] > value ? 1 : 0;
        right--;
    }
    while(left <= right) {
        int idx = left / sqrtN;
        res += bucket[idx].size() - (upper_bound(bucket[idx].begin(), bucket[idx].end(), value) - bucket[idx].begin());
        left += sqrtN;
    }
    return res;
}