atcoder.jp

I - Coins / atcoder.jp

건이두 2020. 1. 13. 13:12
728x90

동전이 홀수개가 있고, 각 동전이 앞면이 나올 확률이 주어짐니다. 앞면개수가 뒷면이 개수보다 많이 나오는, 확률의 합을 구하는 문제입니다.

i는 i번째 코인을 의미합니다. 따라서, N(동전의 숫자)와 i의 관계는 아래와 같구요.
\begin{aligned}
1 \leq i \leq N \:,\: 단\: N은\: 홀수
\end{aligned}

Sample Input 1 에는 동전 3개의 확률이 주어졌습니다.

3
0.30 0.60 0.80

동전의 앞면의 나올 확률이 서로 다르기 때문에, 확률변수와 확률을 정의하는 것 부터 시작해 보겠습니다.
pi를 i번째 동전이 앞면이 나올 확률이 라고 하면, 아래와 같 p_i가 정의 됩니다.
\begin{aligned}
p_1 = 0.3 \\
p_2 = 0.6 \\
p_3 = 0.8 \\
\end{aligned}

1번째 동전이 앞면이 나올 확률을 아래와 같이 임의로 d1이라고 정의해 보겠습니다.
\begin{aligned}
d_1 = p_1 = 0.3
\end{aligned}

2번째 동전까지 고려했을 때, 앞면이 1개 나올 확률을 d2라고 정의해 보면

  • 1번째 동전이 앞면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 뒷면이 나올 확률
  • 1번째 동전이 뒷면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 앞면이 나올 확률
    depedent event이기 때문에, 2개의 확률을 서로 곱합니다.
    \begin{aligned}
    d_2 = p_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
    \end{aligned}

위의 식에서 + 를 중심으로 왼쪽의 식만 따로 분리하면, 아래와 같구요.
\begin{aligned}
p_1 \cdot (1-p_2)
\end{aligned}
이것은 위의 d1 * (1-p2)로 아래와 같이 다시 쓸수 있습니다.
\begin{aligned}
d_2 = d_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
\end{aligned}

위의 식의 아래 부분만, 다시 생각해 보면, 첫번째 동전이 뒷면이 나올 확률입니다.
\begin{aligned}
(1-p_1)
\end{aligned}

이제 d에 동전의 개수, 동전이 앞면이 나오는 개수 2가지 의미를 모두 부여하면,
위에서는 d에 동전의 개수만 주었습니다.즉 di 는 1번째 동전부터 i번째 동전까지 고려한 것을 의미했습니다.
\begin{aligned}
d_i = d_1부터\: d_i\: 번째\: 동전까지\: 고려함
\end{aligned}
여기에, 앞면의 숫자 j를 추가하면, di,j가 되구요, j는 동전의 앞면이 나오는 개수입니다.
\begin{aligned}
d_{i,j} = d_1부터\: d_i\: 번째\: 동전까지\: 고려했을 때, 앞면이 j 번 나오는 확률
\end{aligned}

위와 같이 di,j를 정의 하면, d2,1, 2번째 동전까지 고려했을 때, 동전의 앞면이 1개 나오는 확률을 식으로 풀어보면
\begin{aligned}
d_{2,1} = d_{1,1} \cdot (1-p_2) + d_{1,0} \cdot (p_2)
\end{aligned}
위의 식을 2, 1이 아닌, i, j를 사용해서, 좀더 일반화 시켜 보면 아래와 같습니다.
\begin{aligned}
d_{i,j} = d_{i-1,1} \cdot (1-p_i) + d_{i-1,j-1} \cdot (p_i)
\end{aligned}
di,j가 배열이라고 할 때, 배열의 순서대로 식이 진행되도록하러면 {i-1,j} 보다 {i-1, j-1}이 앞쪽입니다. 따라서, +를 중심으로 좌우식의 위치를 아래와 같이 바꿔주겠습니다.
\begin{aligned}
d_{i,j} = d_{i-1,j-1} \cdot (p_i) + d_{i-1,1} \cdot (1-p_i)
\end{aligned}

일반화된 식의 의미를 풀어보면,

  • i번째 동전에서, 앞면이 j번 나올 확률을 구하려면
    • i-1번째 동전까지 앞면이 j-1번나올 확률에서, i번째 동전이 앞면이 나올 확률을 곱한다
    • i-1번째 동전까지 앞면이 j번 나올 확률에서, i번째 동전이 뒷면이 나올 확률을 곱한다
    • 그리고, 위의 2개의 확률의 합이 d_{i,j} 입니다.

문제에서 각 동전은 한번씩만 던지기 때문에, 위에서 정의한, 동전의 앞면의 개수 j의 범위는 아래와 같습니다.
\begin{aligned}
0 \leq j \leq i
\end{aligned}
즉, j가 0이 되는 경우가 존재합니다. d_{i,0}은 i번째 동전까지 모두 던졌는대, 모두다 뒷면이 나올 확률입니다.
주의해야할 점은 위의 일반화된 식에서 j-1이 있기 때문에, + 중심으로 수식의 오른쪽은 j 가 0보다 큰 경우에만 계산해야 합니다. 동전의 앞면이 -1번 나오는 경우는 없기 때문에 앞면이 0개수가 나오는 부분 부터 계산하기 시작합니다.

위의 점화식(recurrence relation)을 유도하기 까지 과정을 살펴 보았습니다. 문제풀기 마지막 단계인, 동전의 앞면의 개수가 뒷면의 개수보다 큰 확률의 합을 구하는 방법을 알아 보겠습니다. 위의 점화식을 2차원 배열로 구현해서 계산하면 아래와 같이 됩니다.

j는 앞면의 개수 이기 때문에 위에서 설명했듯이 j의 최개값은 i와 같습니다. d_{i,j}를 계산할 때는, 위의 노란색의 좌측 하단, 밑금친 영역만 계산하면 되구요, 답은 녹색으로 칠한 부분의 합이 답이 됩니다.

5개의 동전을 예를 들어 설명하면, 5개의 동전을 던졌을 때, 앞면의 개수가 뒷면의 개수보다 큰 경우는,

  • 5개 동전중에 앞면이 3개 나오는 확률
  • 5개 동전중에 앞면이 4개 나오는 확률
  • 5개 동전중에 앞면이 5개 나오는 확률
    \begin{aligned}
    답 = d_{5, 3} + d_{5, 4} + d_{5, 5}
    \end{aligned}

위의 답을 구하는 방법을 식으로 정리해 보겠습니다.
k를 '동전의 앞면의 개수가, 동전의 뒷면의 개수보다 큰 수'라고 정의 하면,
가장 작은 k 는 수식으로는(N-1)/2+1이 됩니다. 코드에서는 (N>>1)+1 로 보다 간략하게 코딩합니다.
가장 큰 K는 N 입니다. 모든 동전이 앞면이 나오는 확률입니다.
\begin{aligned}
\sum_{k=(N-1)/2+1}^N d_{N,k}
\end{aligned}

코딩을 시작하기 전에 마지막 주의할 점,d_{0,0}에 1을 주고, d_{i,j}를 계산을 시작해야합니다.
왜냐하면, d_{1,0}을 계산할 때, d_{0,0}을 참조하게 되는대, 이때, d_{0,0}이 0이면, 잘못된 값을 계산하게 됩니다.

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.Scanner;

// I - Coins / atcoder.jp
// 문제 링크 : https://atcoder.jp/contests/dp/tasks/dp_i
// 문제 해설 : https://jinpyo.kim/EducationalDP-solution
// Submission : https://atcoder.jp/contests/dp/submissions/9282200

public class Main {
    public static int N;
    public static HashSet<Integer> hs;
    public static void main(String[] args) throws IOException {
//        Scanner sc  = new Scanner(System.in);
        Reader sc  = new Reader();
        PrintWriter pw = new PrintWriter(System.out);

        N = sc.nextInt();
        double[] ps = new double[N+1];
        double[][] dp = new double[N+1][N+1];

        for(int i=1; i<=N; ++i) {
            ps[i] = sc.nextDouble();
        }

        dp[0][0] = 1;

        // i동전까지 고려 했을 때, Head가 j번 나올 확률
        // Head가 J번 나온다는 것은 2가지 확률의 합을 의미합니다.
        // i-1 동전까지 j-1번 Head가 나올 확률에 * i동전이 Head가 나올 확률
        // i-1 동전까지 j번 Head가 나올 확률에 * i동전이 Tail이 나올 확률
        // 이 두 확률의 합입니다.
        for(int i=1; i<=N; ++i) {
            for(int j=0; j<=i; ++j) {
                if(j-1>=0)
                    dp[i][j] += dp[i-1][j-1] * ps[i];
                dp[i][j] += dp[i-1][j] * (1-ps[i]);
            }
        }

        d0(dp);

        int k = (N>>1) + 1;
        double p = 0;
        for(int j=k; j<=N; ++j) {
            p += dp[N][j];
        }

        pw.printf("%1.10f\n", p);
        pw.close();
    }

    public static void d0(double[][] dp) {
        System.out.print("    ");

        for(int j=0; j<dp[0].length; ++j){
            System.out.printf("%3d ", j);
        }
        System.out.println();

        for(int i=0; i<dp.length; ++i){
            System.out.printf("%d|", i);
            for(int j=0; j<dp[0].length; ++j){
                System.out.printf("%f ", dp[i][j]);
            }
            System.out.println();
        }
    }

    // https://www.geeksforgeeks.org/fast-io-in-java-in-competitive-programming/
    static class Reader
    {
        final private int BUFFER_SIZE = 1 << 16;
        private DataInputStream din;
        private byte[] buffer;
        private int bufferPointer, bytesRead;

        public Reader()
        {
            din = new DataInputStream(System.in);
            buffer = new byte[BUFFER_SIZE];
            bufferPointer = bytesRead = 0;
        }

        public Reader(String file_name) throws IOException
        {
            din = new DataInputStream(new FileInputStream(file_name));
            buffer = new byte[BUFFER_SIZE];
            bufferPointer = bytesRead = 0;
        }

        public String nextLine() throws IOException
        {
            byte[] buf = new byte[64]; // line length
            int cnt = 0, c;
            while ((c = read()) != -1)
            {
                if (c == '\n')
                    break;
                buf[cnt++] = (byte) c;
            }
            return new String(buf, 0, cnt);
        }

        public int nextInt() throws IOException
        {
            int ret = 0;
            byte c = read();
            while (c <= ' ')
                c = read();
            boolean neg = (c == '-');
            if (neg)
                c = read();
            do
            {
                ret = ret * 10 + c - '0';
            }  while ((c = read()) >= '0' && c <= '9');

            if (neg)
                return -ret;
            return ret;
        }

        public long nextLong() throws IOException
        {
            long ret = 0;
            byte c = read();
            while (c <= ' ')
                c = read();
            boolean neg = (c == '-');
            if (neg)
                c = read();
            do {
                ret = ret * 10 + c - '0';
            }
            while ((c = read()) >= '0' && c <= '9');
            if (neg)
                return -ret;
            return ret;
        }

        public double nextDouble() throws IOException
        {
            double ret = 0, div = 1;
            byte c = read();
            while (c <= ' ')
                c = read();
            boolean neg = (c == '-');
            if (neg)
                c = read();

            do {
                ret = ret * 10 + c - '0';
            }
            while ((c = read()) >= '0' && c <= '9');

            if (c == '.')
            {
                while ((c = read()) >= '0' && c <= '9')
                {
                    ret += (c - '0') / (div *= 10);
                }
            }

            if (neg)
                return -ret;
            return ret;
        }

        private void fillBuffer() throws IOException
        {
            bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
            if (bytesRead == -1)
                buffer[0] = -1;
        }

        private byte read() throws IOException
        {
            if (bufferPointer == bytesRead)
                fillBuffer();
            return buffer[bufferPointer++];
        }

        public void close() throws IOException
        {
            if (din == null)
                return;
            din.close();
        }
    }
}
728x90