ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • N - Slimes / atcoder.jp
    atcoder.jp 2020. 1. 21. 10:58
    728x90

    문제링크 : https://atcoder.jp/contests/dp/tasks/dp_n
    문제해설 : https://jinpyo.kim/EducationalDP-solution
    Submission : https://atcoder.jp/contests/dp/submissions/9653540
    Java source : https://github.com/skysign/WSAPT/blob/master/atcoder.jp/N%20-%20Slimes/src/Main.java

    여러개의 슬라임을 2개씩 합칠 때, 합치는 비용을 고려해서, 가장 적은 비용으로 합치는 방법을 찾는 문제입니다.

    이 문제를 DP가 방식이 아니라, 아주 간단하게 생각해서, 인접한 두수의 합이 최소가 되는 방식으로, 2개의 슬라임을 합처나가면, 최소비용으로 합칠 수 있지 않을까? 라고 생각할 수도 있습니다. 문제에서 주어진 Sample Input은 그 방식으로 풀어도 답이 나오지만, 아래의 샘플 인풋은 그 방식으로 풀 수 없습니다. 이 문제는 DP로 풀어야만 합니다.

    4
    40 30 30 50

    (40, 30, 30, 50) → (40, 60, 50) : 60
    (40, 60, 50) → (100, 50) : 100
    (100, 50) → (150) : 150
    60 + 100 + 150 = 310
    그 방식으로 풀면, 310 이 합치는 비용입니다.

    DP 방식으로 풀면, 300 이 합치는 비용입니다.
    (40, 30, 30, 50) → (70, 30, 50) : 70
    (70, 30, 50) → (70, 80) : 80
    (70, 80) → (150) : 150
    70 + 80 + 150 = 300

    문제에서 인접합 슬라임 2개를 합친 다고 했으므로, 아래와 같은 k가 반드시 존재합니다.
    dp_{i,j}가 i 부터 j 까지 슬라임을 합칠 때, 최소비용으로 합치는 경우라고 정의합니다.
    \begin{aligned}
    1 \leq i < j \leq N \\
    (i \leq k \: and \: k < j) \: or \: (i < k \: and \: k \leq j) \\
    dp_{i,j} = dp_{i,k} + dp_{k,j} \\
    \end{aligned}

    dp_{i,j}로 위에서 정의한 것은 뒤에서 다시 사용하도록 하구요,
    dp_1 부터 시작해 보겠습니다.
    \begin{aligned}
    dp_1 = a_1 \\
    \end{aligned}

    dp_2 는
    \begin{aligned}
    dp_2 = a_1 + a_2 \\
    \end{aligned}

    dp_3 는 부터는 아래와 같이 2가지 경우가 생기게 되고, 둘 중의 최소값을 선택해야 합니다.
    \begin{aligned}
    dp_3 = min((a_1 + a_2) + (a_1 + a_2) + a_3 ,\; a_1 + (a_2 + a_3) + (a_2 + a_3)) \\
    \end{aligned}

    위의 식을 dp_3 에서 dp_{1,3}으로 고처 써 보면,
    \begin{aligned}
    dp_{1,3} = min((a_1 + a_2) + (a_1 + a_2) + a_3 ,\; a_1 + (a_2 + a_3) + (a_2 + a_3)) \\
    \end{aligned}

    a_1 + a_2 + a_3 이 양쪽에 모두 있으므로, 아래와 같이 고처 쓸 수 있습니다.
    \begin{aligned}
    dp_{1,3} = min((a_1 + a_2) ,\; (a_2 + a_3)) + (a_1+a_2+a_3) \\
    dp_{1,3} = min((a_1 + a_2) ,\; (a_2 + a_3)) + \sum_{x=1}^3 a_x \\
    \end{aligned}

    위의 식을 그대로, 1,2에서 적용해 보면,
    \begin{aligned}
    dp_{1,2} = min((a_1) ,\; (a_2)) + \sum_{x=1}^2 a_x \\
    \end{aligned}

    위의 식을 그대로, 2,3에서 적용해 보면,
    \begin{aligned}
    dp_{2,3} = min((a_2) ,\; (a_3)) + \sum_{x=2}^3 a_x \\
    \end{aligned}

    아래 Sample Input 에 대해서, 위의 점화식을 적용해 보면 아래 그림과 같이 오답이 나오게 됩니다.

    3
    1 2 3

    위의 그림에서, 빨간색이 오답, 파란색이 정답인대, 위에서 새운 점화식 대로 하면, 빨간색의 오답이 계산됩니다. 여기서 좀더 고민해 볼 부분은, 아래와 같습니다. dp_{1,2}의 정답 3은 사실 오답 4에서 -1 하면 정답이 됩니다. 만약에 dp_1_2가 올바르게 계산되어서, 3이 되었다면, dp_{1,3}은 우리가 새운 점화식 대로, 정답을 계산할 수 있습니다.

    위에서 dp_1 부터 dp_3 까지 오면, dp_1을 a_1으로 정의한 것 때문에, 사실 dp_1,1 이 a_1이 되었는대, dp_{1,1} 과 dp_{2,2}, dp_{3,3}을 0 으로 두면, dp_{1,2}를 계산할 때, a_1을 더하지 않게 됨으로, 정답을 계산하게 됩니다.

    즉, 위의 그림에 보라색 X 표 처럼, dp_{i,j}일 때, i == j 이면, a_i가 아닌 0으로 두고, 위의 점화식대로 계산하면, dp_{1,N}에 정답을 계산할 수 있습니다.

    위의 점화식을 고처쓰면,
    \begin{aligned}
    dp_{1,3} = min(dp_{1,2} ,\; dp_{2,3}) + \sum_{x=1}^3 a_x \\
    \end{aligned}

    1,3을 i,j를 사용해서, 일반화 시킬 수 있습니다.

    \begin{aligned}
    dp_{i,j} = min(dp_{i,j-1} ,\; dp_{i+1,j}) + \sum_{x=i}^j a_x \\
    \end{aligned}

    위의 점화식에 보면, a_x 의 특정 구간 부분집합을 자주 계산하게 되는 것을 알 수 있습니다. 왜냐하면, 문제 풀이를 시작할 때, k 의 존재에 대해서 설명했었구요, 이에 따라서, 아래와 같은 i, k구간 k+1, j 구간의 합을 자주 계산하게 됩니다. 이부분을 좀더 빠르게 계산하기 위해서 prefix sum 알고리즘을 사용해서, 보다 빠르게 계산할 수 있습니다. prefix sum은 여기 참고

    마지막으로 dp_{1_4} 를 설명하고, 이번 문제 풀이를 마무리 하겠습니다. 글의 앞부분에서 k가 존재하는 범위에 대해서 설명한 것 과 같이 dp_{1_4} 는 아래와 같이 3가지 경우중에 최소값을 선택해야 합니다.
    \begin{aligned}
    dp_{1,4} = min(dp_{1,1}+dp_{2,4} \:, dp_{1,2}+dp_{3,4} \:, dp_{1,3}+dp_{4,4}) + \sum_{x=1}^4 a_x \\
    \end{aligned}

    수식으로 표현하면, j-i 값이 커짐에 다라서, 위와 같이 min 에 들어가야할 항의 개수가 많아 지면서, 점점 길어짐니다.
    자바 코드로 표현하면 좀더 간결하게 표현할 수 있습니다.

    for(int k=i; k<j; ++k) {
        dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k+1][j] + (prefixSum[j] - prefixSum[i-1]));
    }

    마지막으로 dp_{i_j}를 계산하는 방향은 아래와 그림과 같습니다. #1~#3 순서로 계산합니다.

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.util.Scanner;
    
    /**
     * N - Slimes / atcoder.jp
     * 문제링크 : https://atcoder.jp/contests/dp/tasks/dp_n
     * 문제해설 : https://jinpyo.kim/EducationalDP-solution
     * Submission : https://atcoder.jp/contests/dp/submissions/9653540
     */
    
    public class Main {
        public static void main(String[] args) throws IOException {
    //        Scanner sc = new Scanner(System.in);
            Reader sc = new Reader();
            int N = sc.nextInt();
            long[] as = new long[N+1];
            long[][] dp = new long[N+1][N+1];
            long[] prefixSum = new long[N+1];
    
            for(int i=1; i<=N; ++i) {
                as[i] = sc.nextLong();
                prefixSum[i] = prefixSum[i-1] + as[i];
            }
    
            for(int ij_d=1; ij_d<=N; ++ij_d) {
                for(int i=1; i+ij_d<=N; ++i) {
                    int j = i+ij_d;
                    dp[i][j] = Long.MAX_VALUE;
    
                    for(int k=i; k<j; ++k) {
                        dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k+1][j] + (prefixSum[j] - prefixSum[i-1]));
                    }
                }
            }
    
            System.out.println(dp[1][N]);
        }
    
        // https://www.geeksforgeeks.org/fast-io-in-java-in-competitive-programming/
        static class Reader
        {
            final private int BUFFER_SIZE = 1 << 16;
            private DataInputStream din;
            private byte[] buffer;
            private int bufferPointer, bytesRead;
    
            public Reader()
            {
                din = new DataInputStream(System.in);
                buffer = new byte[BUFFER_SIZE];
                bufferPointer = bytesRead = 0;
            }
    
            public Reader(String file_name) throws IOException
            {
                din = new DataInputStream(new FileInputStream(file_name));
                buffer = new byte[BUFFER_SIZE];
                bufferPointer = bytesRead = 0;
            }
    
            public String nextLine() throws IOException
            {
                byte[] buf = new byte[64]; // line length
                int cnt = 0, c;
                while ((c = read()) != -1)
                {
                    if (c == '\n')
                        break;
                    buf[cnt++] = (byte) c;
                }
                return new String(buf, 0, cnt);
            }
    
            public int nextInt() throws IOException
            {
                int ret = 0;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
                do
                {
                    ret = ret * 10 + c - '0';
                }  while ((c = read()) >= '0' && c <= '9');
    
                if (neg)
                    return -ret;
                return ret;
            }
    
            public long nextLong() throws IOException
            {
                long ret = 0;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
                do {
                    ret = ret * 10 + c - '0';
                }
                while ((c = read()) >= '0' && c <= '9');
                if (neg)
                    return -ret;
                return ret;
            }
    
            public double nextDouble() throws IOException
            {
                double ret = 0, div = 1;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
    
                do {
                    ret = ret * 10 + c - '0';
                }
                while ((c = read()) >= '0' && c <= '9');
    
                if (c == '.')
                {
                    while ((c = read()) >= '0' && c <= '9')
                    {
                        ret += (c - '0') / (div *= 10);
                    }
                }
    
                if (neg)
                    return -ret;
                return ret;
            }
    
            private void fillBuffer() throws IOException
            {
                bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
                if (bytesRead == -1)
                    buffer[0] = -1;
            }
    
            private byte read() throws IOException
            {
                if (bufferPointer == bytesRead)
                    fillBuffer();
                return buffer[bufferPointer++];
            }
    
            public void close() throws IOException
            {
                if (din == null)
                    return;
                din.close();
            }
        }
    }
    
    728x90

    'atcoder.jp' 카테고리의 다른 글

    L - Deque / atcoder.jp  (0) 2020.01.18
    K - Stones / atcoder.jp  (0) 2020.01.16
    I - Coins / atcoder.jp  (0) 2020.01.13
    H - Grid 1 / atcoder.jp  (0) 2020.01.01
    G - Longest Path / atcoder.jp  (0) 2020.01.01
Designed by Tistory.