ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • I - Coins / atcoder.jp
    atcoder.jp 2020. 1. 13. 13:12
    728x90

    동전이 홀수개가 있고, 각 동전이 앞면이 나올 확률이 주어짐니다. 앞면개수가 뒷면이 개수보다 많이 나오는, 확률의 합을 구하는 문제입니다.

    i는 i번째 코인을 의미합니다. 따라서, N(동전의 숫자)와 i의 관계는 아래와 같구요.
    \begin{aligned}
    1 \leq i \leq N \:,\: 단\: N은\: 홀수
    \end{aligned}

    Sample Input 1 에는 동전 3개의 확률이 주어졌습니다.

    3
    0.30 0.60 0.80

    동전의 앞면의 나올 확률이 서로 다르기 때문에, 확률변수와 확률을 정의하는 것 부터 시작해 보겠습니다.
    pi를 i번째 동전이 앞면이 나올 확률이 라고 하면, 아래와 같 p_i가 정의 됩니다.
    \begin{aligned}
    p_1 = 0.3 \\
    p_2 = 0.6 \\
    p_3 = 0.8 \\
    \end{aligned}

    1번째 동전이 앞면이 나올 확률을 아래와 같이 임의로 d1이라고 정의해 보겠습니다.
    \begin{aligned}
    d_1 = p_1 = 0.3
    \end{aligned}

    2번째 동전까지 고려했을 때, 앞면이 1개 나올 확률을 d2라고 정의해 보면

    • 1번째 동전이 앞면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 뒷면이 나올 확률
    • 1번째 동전이 뒷면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 앞면이 나올 확률
      depedent event이기 때문에, 2개의 확률을 서로 곱합니다.
      \begin{aligned}
      d_2 = p_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
      \end{aligned}

    위의 식에서 + 를 중심으로 왼쪽의 식만 따로 분리하면, 아래와 같구요.
    \begin{aligned}
    p_1 \cdot (1-p_2)
    \end{aligned}
    이것은 위의 d1 * (1-p2)로 아래와 같이 다시 쓸수 있습니다.
    \begin{aligned}
    d_2 = d_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
    \end{aligned}

    위의 식의 아래 부분만, 다시 생각해 보면, 첫번째 동전이 뒷면이 나올 확률입니다.
    \begin{aligned}
    (1-p_1)
    \end{aligned}

    이제 d에 동전의 개수, 동전이 앞면이 나오는 개수 2가지 의미를 모두 부여하면,
    위에서는 d에 동전의 개수만 주었습니다.즉 di 는 1번째 동전부터 i번째 동전까지 고려한 것을 의미했습니다.
    \begin{aligned}
    d_i = d_1부터\: d_i\: 번째\: 동전까지\: 고려함
    \end{aligned}
    여기에, 앞면의 숫자 j를 추가하면, di,j가 되구요, j는 동전의 앞면이 나오는 개수입니다.
    \begin{aligned}
    d_{i,j} = d_1부터\: d_i\: 번째\: 동전까지\: 고려했을 때, 앞면이 j 번 나오는 확률
    \end{aligned}

    위와 같이 di,j를 정의 하면, d2,1, 2번째 동전까지 고려했을 때, 동전의 앞면이 1개 나오는 확률을 식으로 풀어보면
    \begin{aligned}
    d_{2,1} = d_{1,1} \cdot (1-p_2) + d_{1,0} \cdot (p_2)
    \end{aligned}
    위의 식을 2, 1이 아닌, i, j를 사용해서, 좀더 일반화 시켜 보면 아래와 같습니다.
    \begin{aligned}
    d_{i,j} = d_{i-1,1} \cdot (1-p_i) + d_{i-1,j-1} \cdot (p_i)
    \end{aligned}
    di,j가 배열이라고 할 때, 배열의 순서대로 식이 진행되도록하러면 {i-1,j} 보다 {i-1, j-1}이 앞쪽입니다. 따라서, +를 중심으로 좌우식의 위치를 아래와 같이 바꿔주겠습니다.
    \begin{aligned}
    d_{i,j} = d_{i-1,j-1} \cdot (p_i) + d_{i-1,1} \cdot (1-p_i)
    \end{aligned}

    일반화된 식의 의미를 풀어보면,

    • i번째 동전에서, 앞면이 j번 나올 확률을 구하려면
      • i-1번째 동전까지 앞면이 j-1번나올 확률에서, i번째 동전이 앞면이 나올 확률을 곱한다
      • i-1번째 동전까지 앞면이 j번 나올 확률에서, i번째 동전이 뒷면이 나올 확률을 곱한다
      • 그리고, 위의 2개의 확률의 합이 d_{i,j} 입니다.

    문제에서 각 동전은 한번씩만 던지기 때문에, 위에서 정의한, 동전의 앞면의 개수 j의 범위는 아래와 같습니다.
    \begin{aligned}
    0 \leq j \leq i
    \end{aligned}
    즉, j가 0이 되는 경우가 존재합니다. d_{i,0}은 i번째 동전까지 모두 던졌는대, 모두다 뒷면이 나올 확률입니다.
    주의해야할 점은 위의 일반화된 식에서 j-1이 있기 때문에, + 중심으로 수식의 오른쪽은 j 가 0보다 큰 경우에만 계산해야 합니다. 동전의 앞면이 -1번 나오는 경우는 없기 때문에 앞면이 0개수가 나오는 부분 부터 계산하기 시작합니다.

    위의 점화식(recurrence relation)을 유도하기 까지 과정을 살펴 보았습니다. 문제풀기 마지막 단계인, 동전의 앞면의 개수가 뒷면의 개수보다 큰 확률의 합을 구하는 방법을 알아 보겠습니다. 위의 점화식을 2차원 배열로 구현해서 계산하면 아래와 같이 됩니다.

    j는 앞면의 개수 이기 때문에 위에서 설명했듯이 j의 최개값은 i와 같습니다. d_{i,j}를 계산할 때는, 위의 노란색의 좌측 하단, 밑금친 영역만 계산하면 되구요, 답은 녹색으로 칠한 부분의 합이 답이 됩니다.

    5개의 동전을 예를 들어 설명하면, 5개의 동전을 던졌을 때, 앞면의 개수가 뒷면의 개수보다 큰 경우는,

    • 5개 동전중에 앞면이 3개 나오는 확률
    • 5개 동전중에 앞면이 4개 나오는 확률
    • 5개 동전중에 앞면이 5개 나오는 확률
      \begin{aligned}
      답 = d_{5, 3} + d_{5, 4} + d_{5, 5}
      \end{aligned}

    위의 답을 구하는 방법을 식으로 정리해 보겠습니다.
    k를 '동전의 앞면의 개수가, 동전의 뒷면의 개수보다 큰 수'라고 정의 하면,
    가장 작은 k 는 수식으로는(N-1)/2+1이 됩니다. 코드에서는 (N>>1)+1 로 보다 간략하게 코딩합니다.
    가장 큰 K는 N 입니다. 모든 동전이 앞면이 나오는 확률입니다.
    \begin{aligned}
    \sum_{k=(N-1)/2+1}^N d_{N,k}
    \end{aligned}

    코딩을 시작하기 전에 마지막 주의할 점,d_{0,0}에 1을 주고, d_{i,j}를 계산을 시작해야합니다.
    왜냐하면, d_{1,0}을 계산할 때, d_{0,0}을 참조하게 되는대, 이때, d_{0,0}이 0이면, 잘못된 값을 계산하게 됩니다.

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.PrintWriter;
    import java.util.HashSet;
    import java.util.Scanner;
    
    // I - Coins / atcoder.jp
    // 문제 링크 : https://atcoder.jp/contests/dp/tasks/dp_i
    // 문제 해설 : https://jinpyo.kim/EducationalDP-solution
    // Submission : https://atcoder.jp/contests/dp/submissions/9282200
    
    public class Main {
        public static int N;
        public static HashSet<Integer> hs;
        public static void main(String[] args) throws IOException {
    //        Scanner sc  = new Scanner(System.in);
            Reader sc  = new Reader();
            PrintWriter pw = new PrintWriter(System.out);
    
            N = sc.nextInt();
            double[] ps = new double[N+1];
            double[][] dp = new double[N+1][N+1];
    
            for(int i=1; i<=N; ++i) {
                ps[i] = sc.nextDouble();
            }
    
            dp[0][0] = 1;
    
            // i동전까지 고려 했을 때, Head가 j번 나올 확률
            // Head가 J번 나온다는 것은 2가지 확률의 합을 의미합니다.
            // i-1 동전까지 j-1번 Head가 나올 확률에 * i동전이 Head가 나올 확률
            // i-1 동전까지 j번 Head가 나올 확률에 * i동전이 Tail이 나올 확률
            // 이 두 확률의 합입니다.
            for(int i=1; i<=N; ++i) {
                for(int j=0; j<=i; ++j) {
                    if(j-1>=0)
                        dp[i][j] += dp[i-1][j-1] * ps[i];
                    dp[i][j] += dp[i-1][j] * (1-ps[i]);
                }
            }
    
            d0(dp);
    
            int k = (N>>1) + 1;
            double p = 0;
            for(int j=k; j<=N; ++j) {
                p += dp[N][j];
            }
    
            pw.printf("%1.10f\n", p);
            pw.close();
        }
    
        public static void d0(double[][] dp) {
            System.out.print("    ");
    
            for(int j=0; j<dp[0].length; ++j){
                System.out.printf("%3d ", j);
            }
            System.out.println();
    
            for(int i=0; i<dp.length; ++i){
                System.out.printf("%d|", i);
                for(int j=0; j<dp[0].length; ++j){
                    System.out.printf("%f ", dp[i][j]);
                }
                System.out.println();
            }
        }
    
        // https://www.geeksforgeeks.org/fast-io-in-java-in-competitive-programming/
        static class Reader
        {
            final private int BUFFER_SIZE = 1 << 16;
            private DataInputStream din;
            private byte[] buffer;
            private int bufferPointer, bytesRead;
    
            public Reader()
            {
                din = new DataInputStream(System.in);
                buffer = new byte[BUFFER_SIZE];
                bufferPointer = bytesRead = 0;
            }
    
            public Reader(String file_name) throws IOException
            {
                din = new DataInputStream(new FileInputStream(file_name));
                buffer = new byte[BUFFER_SIZE];
                bufferPointer = bytesRead = 0;
            }
    
            public String nextLine() throws IOException
            {
                byte[] buf = new byte[64]; // line length
                int cnt = 0, c;
                while ((c = read()) != -1)
                {
                    if (c == '\n')
                        break;
                    buf[cnt++] = (byte) c;
                }
                return new String(buf, 0, cnt);
            }
    
            public int nextInt() throws IOException
            {
                int ret = 0;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
                do
                {
                    ret = ret * 10 + c - '0';
                }  while ((c = read()) >= '0' && c <= '9');
    
                if (neg)
                    return -ret;
                return ret;
            }
    
            public long nextLong() throws IOException
            {
                long ret = 0;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
                do {
                    ret = ret * 10 + c - '0';
                }
                while ((c = read()) >= '0' && c <= '9');
                if (neg)
                    return -ret;
                return ret;
            }
    
            public double nextDouble() throws IOException
            {
                double ret = 0, div = 1;
                byte c = read();
                while (c <= ' ')
                    c = read();
                boolean neg = (c == '-');
                if (neg)
                    c = read();
    
                do {
                    ret = ret * 10 + c - '0';
                }
                while ((c = read()) >= '0' && c <= '9');
    
                if (c == '.')
                {
                    while ((c = read()) >= '0' && c <= '9')
                    {
                        ret += (c - '0') / (div *= 10);
                    }
                }
    
                if (neg)
                    return -ret;
                return ret;
            }
    
            private void fillBuffer() throws IOException
            {
                bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
                if (bytesRead == -1)
                    buffer[0] = -1;
            }
    
            private byte read() throws IOException
            {
                if (bufferPointer == bytesRead)
                    fillBuffer();
                return buffer[bufferPointer++];
            }
    
            public void close() throws IOException
            {
                if (din == null)
                    return;
                din.close();
            }
        }
    }
    
    728x90

    'atcoder.jp' 카테고리의 다른 글

    L - Deque / atcoder.jp  (0) 2020.01.18
    K - Stones / atcoder.jp  (0) 2020.01.16
    H - Grid 1 / atcoder.jp  (0) 2020.01.01
    G - Longest Path / atcoder.jp  (0) 2020.01.01
    F - LCS / atcoder.jp  (0) 2019.12.31
Designed by Tistory.