PS/C++

[백준 1471] 사탕 돌리기

siyamaki 2024. 6. 20. 10:35

1 ~ N까지의 수가 시계방향 원형으로 만들어져 있고 N -> 1로 연결되어 있다

 

어느 노드가 가지고 있는 자릿수의 합만큼 이동한다 된다.(ex : 123 (1 + 2 + 3) → 123 + 6)

 

N이 17일 경우 아래와 같이 그래프가 그려지는걸 볼 수가 있는데

         7 → 14  15 ← 12 ← 6 ← 3
              ↓   ↓        ↑
    9 → 1 → 2 → 4 → 8 → 16
                       ↑
5 → 10 → 11 → 13 → 17

 

자세히 보면 그래프는 전부 단방향이고 1차원 배열로 그래프를 구성할 수 있다.(arr[현재값] = 다음값)

 

15 12 6 16 8 4의 경우는 사이클이 생기면 해당 노드에 도착할 땐 항상 6개의 정점을 방문 후 종료된다. 이러한 사이클이 만들어지는 정점을 찾기 위해 SCC를 이용한다.

 

DP배열을 만들고 -1로 초기화 한 다음 사이클이 생기는 노드의 번호에 scc의 size가 2 이상일 경우 scc에 속한 정점 번호에 scc의 size의 갯수를 저장해준다

 

사이클이 생기지 않은 경우는 단독 값이므로 전부 1인 경우이니 저장을 할 필요가 없다.

 

그 다음으로는 1번정점부터 N번정점까지 dfs를 이용해 dp배열에 값을 채워나간 후 최댓값을 찾으면 된다

#include <bits/stdc++.h>
using namespace std;
int N, id = 1, res = 0, scc[200001], dp[200001], arr[200001];
bool finished[200001];
vector<vector<int>> t;
stack<int> s;
int dfs(int current);
int getSum(int n);
int cntDfs(int current);
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr); cout.tie(nullptr);

    cin >> N;

    for(int i = 1; i <= N; i++) {   // 그래프 생성
        int jump = getSum(i);
        int next = (i + jump) % N == 0 ? N : (i + jump) % N;
        arr[i] = next;
    }
    scc[0] = 987654321;
    for(int i = 1; i <= N; i++) {
        if(scc[i] == 0) {
            dfs(i);
        }
    }
    fill(dp, dp + N + 1, -1);
    for(const auto &k : t) {
        if(k.size() >= 2) {
            for(int n : k) {
                dp[n] = k.size();
            }
        }
    }
    for(int i = 1; i <= N; i++) {
        if(dp[i] == -1) {
            cntDfs(i);
        }
    }
    for(int i = 1; i <= N; i++) {
        res = max(res, dp[i]);
    }
    cout << res;
}
int getSum(int n) {
    int k = 1;
    int sum = 0;
    while(k <= n) {
        sum += (n / k) % 10;
        k *= 10;
    }
    return sum;
}
int dfs(int current) {
    scc[current] = id++;
    s.push(current);
    int pid = scc[current];

    int next = arr[current];
    if(scc[next] == 0) {
        pid = min(pid, dfs(next));
    } else if(!finished[next]) {
        pid = min(pid, scc[next]);
    }

    if(pid == scc[current]) {
        vector<int> tmp;
        while(current) {
            int y = s.top(); s.pop();
            tmp.push_back(y);
            finished[y]= true;
            if(y == current) {
                break;
            }
        }
        t.push_back(tmp);
    }

    return pid;
}
int cntDfs(int current) {
    int next = arr[current];
    if(current == next) {
        dp[current] = 1;
    } else if (dp[next] != -1) {
        dp[current] = dp[next] + 1;
    } else {
        dp[current] = cntDfs(next) + 1;
    }
    return dp[current];
}