I - Coins / atcoder.jp
- 문제 링크 : https://atcoder.jp/contests/dp/tasks/dp_i
- 문제 해설 :
- Submission : https://atcoder.jp/contests/dp/submissions/9313650
- Java Source : https://github.com/skysign/WSAPT/blob/master/atcoder.jp/I%20-%20Coins/src/Main.java
동전이 홀수개가 있고, 각 동전이 앞면이 나올 확률이 주어짐니다. 앞면개수가 뒷면이 개수보다 많이 나오는, 확률의 합을 구하는 문제입니다.
i는 i번째 코인을 의미합니다. 따라서, N(동전의 숫자)와 i의 관계는 아래와 같구요.
\begin{aligned}
1 \leq i \leq N \:,\: 단\: N은\: 홀수
\end{aligned}
Sample Input 1 에는 동전 3개의 확률이 주어졌습니다.
3
0.30 0.60 0.80
동전의 앞면의 나올 확률이 서로 다르기 때문에, 확률변수와 확률을 정의하는 것 부터 시작해 보겠습니다.
pi를 i번째 동전이 앞면이 나올 확률이 라고 하면, 아래와 같 p_i가 정의 됩니다.
\begin{aligned}
p_1 = 0.3 \\
p_2 = 0.6 \\
p_3 = 0.8 \\
\end{aligned}
1번째 동전이 앞면이 나올 확률을 아래와 같이 임의로 d1이라고 정의해 보겠습니다.
\begin{aligned}
d_1 = p_1 = 0.3
\end{aligned}
2번째 동전까지 고려했을 때, 앞면이 1개 나올 확률을 d2라고 정의해 보면
- 1번째 동전이 앞면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 뒷면이 나올 확률
- 1번째 동전이 뒷면이 나오고나서, 종속사건(depedent event)으로 2번째 동전이 앞면이 나올 확률
depedent event이기 때문에, 2개의 확률을 서로 곱합니다.
\begin{aligned}
d_2 = p_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
\end{aligned}
위의 식에서 + 를 중심으로 왼쪽의 식만 따로 분리하면, 아래와 같구요.
\begin{aligned}
p_1 \cdot (1-p_2)
\end{aligned}
이것은 위의 d1 * (1-p2)로 아래와 같이 다시 쓸수 있습니다.
\begin{aligned}
d_2 = d_1 \cdot (1-p_2) + (1-p_1) \cdot p_2
\end{aligned}
위의 식의 아래 부분만, 다시 생각해 보면, 첫번째 동전이 뒷면이 나올 확률입니다.
\begin{aligned}
(1-p_1)
\end{aligned}
이제 d에 동전의 개수, 동전이 앞면이 나오는 개수 2가지 의미를 모두 부여하면,
위에서는 d에 동전의 개수만 주었습니다.즉 di 는 1번째 동전부터 i번째 동전까지 고려한 것을 의미했습니다.
\begin{aligned}
d_i = d_1부터\: d_i\: 번째\: 동전까지\: 고려함
\end{aligned}
여기에, 앞면의 숫자 j를 추가하면, di,j가 되구요, j는 동전의 앞면이 나오는 개수입니다.
\begin{aligned}
d_{i,j} = d_1부터\: d_i\: 번째\: 동전까지\: 고려했을 때, 앞면이 j 번 나오는 확률
\end{aligned}
위와 같이 di,j를 정의 하면, d2,1, 2번째 동전까지 고려했을 때, 동전의 앞면이 1개 나오는 확률을 식으로 풀어보면
\begin{aligned}
d_{2,1} = d_{1,1} \cdot (1-p_2) + d_{1,0} \cdot (p_2)
\end{aligned}
위의 식을 2, 1이 아닌, i, j를 사용해서, 좀더 일반화 시켜 보면 아래와 같습니다.
\begin{aligned}
d_{i,j} = d_{i-1,1} \cdot (1-p_i) + d_{i-1,j-1} \cdot (p_i)
\end{aligned}
di,j가 배열이라고 할 때, 배열의 순서대로 식이 진행되도록하러면 {i-1,j} 보다 {i-1, j-1}이 앞쪽입니다. 따라서, +를 중심으로 좌우식의 위치를 아래와 같이 바꿔주겠습니다.
\begin{aligned}
d_{i,j} = d_{i-1,j-1} \cdot (p_i) + d_{i-1,1} \cdot (1-p_i)
\end{aligned}
일반화된 식의 의미를 풀어보면,
- i번째 동전에서, 앞면이 j번 나올 확률을 구하려면
- i-1번째 동전까지 앞면이 j-1번나올 확률에서, i번째 동전이 앞면이 나올 확률을 곱한다
- i-1번째 동전까지 앞면이 j번 나올 확률에서, i번째 동전이 뒷면이 나올 확률을 곱한다
- 그리고, 위의 2개의 확률의 합이 d_{i,j} 입니다.
문제에서 각 동전은 한번씩만 던지기 때문에, 위에서 정의한, 동전의 앞면의 개수 j의 범위는 아래와 같습니다.
\begin{aligned}
0 \leq j \leq i
\end{aligned}
즉, j가 0이 되는 경우가 존재합니다. d_{i,0}은 i번째 동전까지 모두 던졌는대, 모두다 뒷면이 나올 확률입니다.
주의해야할 점은 위의 일반화된 식에서 j-1이 있기 때문에, + 중심으로 수식의 오른쪽은 j 가 0보다 큰 경우에만 계산해야 합니다. 동전의 앞면이 -1번 나오는 경우는 없기 때문에 앞면이 0개수가 나오는 부분 부터 계산하기 시작합니다.
위의 점화식(recurrence relation)을 유도하기 까지 과정을 살펴 보았습니다. 문제풀기 마지막 단계인, 동전의 앞면의 개수가 뒷면의 개수보다 큰 확률의 합을 구하는 방법을 알아 보겠습니다. 위의 점화식을 2차원 배열로 구현해서 계산하면 아래와 같이 됩니다.
j는 앞면의 개수 이기 때문에 위에서 설명했듯이 j의 최개값은 i와 같습니다. d_{i,j}를 계산할 때는, 위의 노란색의 좌측 하단, 밑금친 영역만 계산하면 되구요, 답은 녹색으로 칠한 부분의 합이 답이 됩니다.
5개의 동전을 예를 들어 설명하면, 5개의 동전을 던졌을 때, 앞면의 개수가 뒷면의 개수보다 큰 경우는,
- 5개 동전중에 앞면이 3개 나오는 확률
- 5개 동전중에 앞면이 4개 나오는 확률
- 5개 동전중에 앞면이 5개 나오는 확률
\begin{aligned}
답 = d_{5, 3} + d_{5, 4} + d_{5, 5}
\end{aligned}
위의 답을 구하는 방법을 식으로 정리해 보겠습니다.
k를 '동전의 앞면의 개수가, 동전의 뒷면의 개수보다 큰 수'라고 정의 하면,
가장 작은 k 는 수식으로는(N-1)/2+1이 됩니다. 코드에서는 (N>>1)+1 로 보다 간략하게 코딩합니다.
가장 큰 K는 N 입니다. 모든 동전이 앞면이 나오는 확률입니다.
\begin{aligned}
\sum_{k=(N-1)/2+1}^N d_{N,k}
\end{aligned}
코딩을 시작하기 전에 마지막 주의할 점,d_{0,0}에 1을 주고, d_{i,j}를 계산을 시작해야합니다.
왜냐하면, d_{1,0}을 계산할 때, d_{0,0}을 참조하게 되는대, 이때, d_{0,0}이 0이면, 잘못된 값을 계산하게 됩니다.
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.Scanner;
// I - Coins / atcoder.jp
// 문제 링크 : https://atcoder.jp/contests/dp/tasks/dp_i
// 문제 해설 : https://jinpyo.kim/EducationalDP-solution
// Submission : https://atcoder.jp/contests/dp/submissions/9282200
public class Main {
public static int N;
public static HashSet<Integer> hs;
public static void main(String[] args) throws IOException {
// Scanner sc = new Scanner(System.in);
Reader sc = new Reader();
PrintWriter pw = new PrintWriter(System.out);
N = sc.nextInt();
double[] ps = new double[N+1];
double[][] dp = new double[N+1][N+1];
for(int i=1; i<=N; ++i) {
ps[i] = sc.nextDouble();
}
dp[0][0] = 1;
// i동전까지 고려 했을 때, Head가 j번 나올 확률
// Head가 J번 나온다는 것은 2가지 확률의 합을 의미합니다.
// i-1 동전까지 j-1번 Head가 나올 확률에 * i동전이 Head가 나올 확률
// i-1 동전까지 j번 Head가 나올 확률에 * i동전이 Tail이 나올 확률
// 이 두 확률의 합입니다.
for(int i=1; i<=N; ++i) {
for(int j=0; j<=i; ++j) {
if(j-1>=0)
dp[i][j] += dp[i-1][j-1] * ps[i];
dp[i][j] += dp[i-1][j] * (1-ps[i]);
}
}
d0(dp);
int k = (N>>1) + 1;
double p = 0;
for(int j=k; j<=N; ++j) {
p += dp[N][j];
}
pw.printf("%1.10f\n", p);
pw.close();
}
public static void d0(double[][] dp) {
System.out.print(" ");
for(int j=0; j<dp[0].length; ++j){
System.out.printf("%3d ", j);
}
System.out.println();
for(int i=0; i<dp.length; ++i){
System.out.printf("%d|", i);
for(int j=0; j<dp[0].length; ++j){
System.out.printf("%f ", dp[i][j]);
}
System.out.println();
}
}
// https://www.geeksforgeeks.org/fast-io-in-java-in-competitive-programming/
static class Reader
{
final private int BUFFER_SIZE = 1 << 16;
private DataInputStream din;
private byte[] buffer;
private int bufferPointer, bytesRead;
public Reader()
{
din = new DataInputStream(System.in);
buffer = new byte[BUFFER_SIZE];
bufferPointer = bytesRead = 0;
}
public Reader(String file_name) throws IOException
{
din = new DataInputStream(new FileInputStream(file_name));
buffer = new byte[BUFFER_SIZE];
bufferPointer = bytesRead = 0;
}
public String nextLine() throws IOException
{
byte[] buf = new byte[64]; // line length
int cnt = 0, c;
while ((c = read()) != -1)
{
if (c == '\n')
break;
buf[cnt++] = (byte) c;
}
return new String(buf, 0, cnt);
}
public int nextInt() throws IOException
{
int ret = 0;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do
{
ret = ret * 10 + c - '0';
} while ((c = read()) >= '0' && c <= '9');
if (neg)
return -ret;
return ret;
}
public long nextLong() throws IOException
{
long ret = 0;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do {
ret = ret * 10 + c - '0';
}
while ((c = read()) >= '0' && c <= '9');
if (neg)
return -ret;
return ret;
}
public double nextDouble() throws IOException
{
double ret = 0, div = 1;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do {
ret = ret * 10 + c - '0';
}
while ((c = read()) >= '0' && c <= '9');
if (c == '.')
{
while ((c = read()) >= '0' && c <= '9')
{
ret += (c - '0') / (div *= 10);
}
}
if (neg)
return -ret;
return ret;
}
private void fillBuffer() throws IOException
{
bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
if (bytesRead == -1)
buffer[0] = -1;
}
private byte read() throws IOException
{
if (bufferPointer == bytesRead)
fillBuffer();
return buffer[bufferPointer++];
}
public void close() throws IOException
{
if (din == null)
return;
din.close();
}
}
}