graphical model에서 inference하는 방법에 대해 정리하였다. (미완성)
node들에 대한 posterior를 계산하고 싶다고 하자. 이번 장에서는 exact inference에 대해 집중해서 알아보자.
chain의 모습을 갖는 undirected의 joint distribution에 대해 살펴보자. 각 variable들은 K개의 states를 갖는 discrete variable이라고 가정한다. 그러면 joint disribution은 (N−1)K2개의 parameter들을 갖고 있다.
p(x)=Z1ψ1,2(x1,x2)ψ2,3(x2,x3)…ψN−1,N(xN−1,xN)
이제 marginal distribution p(xn)를 inference해보려고 한다. 가장 쉽게 보이지만 복잡하고 시간이 오래걸리는 방법은 아래처럼 다 summation하는 것이다.
p(xn)=x1∑…xn−1∑xn+1∑…xN∑p(x)
joint는 K개의 state를 갖는 N개의 variable이 있기 때문에 KN개의 값이 존재하고 이를 계산하는 것은 비효율적이다. 그렇다면 chain의 특징을 이용해서 조금 더 효율적인 방법을 이용해보자.
p(xn)=Z1[xn−1∑ψn−1,n(xn−1,xn)…[x2∑ψ2,3(x2,x3)[x1∑ψ1,2(x1,x2)]]…][xn+1∑ψn,n+1(xn,xn+1)…[xN∑ψN−1,N(xN−1,xN)]…]
위의 방법으로 구하면 total cost는 O(NK2)이다. chain처럼 conditional independence를 찾아서 이용하는 것의 장점을 느낄 수 있다. 위와 같은 방법은 local messages를 보내는 것으로 해석할 수 있다. 크게 보면 marginal p(xn)은 두 개의 factor로 나누어서 생각할 수 있다.
p(xn)=Z1μα(xn)μβ(xn)
각각 xn의 앞, 뒤에서 흘러오는 message로 이해할 수 있다.
graph에서 tree는 어떤 두 개의 node를 선택했을 때 오직 하나의 path만 존재하는 것을 의미한다. 그리고 모든 node는 하나의 parent node를 갖고 가장 위에 있는 node는 root라고 부른다. local message passing을 이용한 inference를 이 tree에 이용하는 sum-product algorithm에 대해 배울 것이다.
graph에 node를 추가하여서 decomposition을 explicit하게 하는 것이다. xs를 subset of the variable이라고 하면 joint를 아래와 같이 나타낼 수 있다.
p(x)=s∏fs(xs)
여기서 fs는 a function of a corresponding set of variables 이다. 각 factor fs(xs)는 directed의 경우 local conditional distribution의 역할과 같고 undirected의 경우 potential function이라고 할 수 있다.
undirected, directed 모두 factor graph로 일반화가 가능해지는 것으로 이해할 수 있을 것 같다.
tree-structured graph에서 exact inference를 하기 위한 방법을 알아보자. variable들은 discrete이라고 가정하기에 summation으로 계산을 진행한다. (물론 continuous도 동일하게 가능) loop없는 directed graph에서 exact inference하는 알고리즘은 belief propagation이라 하고 이는 sum-product algorithm의 특별한 경우에 해당한다.
original graph는
- undirected tree, directed tree, polytree
이고 이에 대응되는 factor graph는 tree structure를 가진다. original graph는 factor graph로 바꾸는 과정을 통해 undirected, directed에 동일한 방법을 적용할 수 있게 된다. 우리는 이런 과정을 통해 최종적으로 얻고자 하는 바는 아래와 같다.
- to obtain an efficient, exact inference algorithm for finding marginals
- in situations where several marginals are required to allow computations to be shared efficiently
먼저 marginal을 구하는 것부터 시작해보자.
p(x)=x−x∑p(x)
우리는 tree structure를 다루고 있고 이를 통해 joint distribution의 factor들을 그룹으로 partition할 수 있다.
p(x)=s∈ne(x)∏Fs(x,Xs)
- ne(x) : 이웃 variable을 의미
- Xs : factor node fx를 통해 x와 연결된 subtree에 있는 set of all variables
- Fs(x,Xs) : the product of all the factors in the group associated with factor fs
이를 통해 marginal식을 살펴보면
p(x)=x−x∑s∈ne(x)∏Fs(x,Xs)=s∈ne(x)∏x−x∑Fs(x,Xs)=s∈ne(x)∏μfs→x(x)
여기서 우리는 새로운 a set of functions μfs→x(x)=∑x−xFs(x,Xs)을 만나게 된다. 이는 factor nodes fs에서 x를 향하는 messages라고 볼 수 있다. 그래서 marginal은 node x에 도착하는 message들의 product라고 이해할 수 있다.
Fs(x,Xs)을 조금 더 factorize해보자.
Fs(x,Xs)=fs(x,x1,…,xM)G1(x1,Xs1)…GM(xM,XsM)
μfs→x(x)=x1∑…xM∑fs(x,x1,…,xM)m∈ne(fs)−x∏[Xsm∑Gm(xm,Xsm)]=x1∑…xM∑fs(x,x1,…,xM)m∈ne(fs)−x∏μxm→fs(xm)
이번에는 μxm→fs(xm)=∑XsmGm(xm,Xsm) 이번에는 variable에서 factor로 가는 message를 의미한다.
이처럼 message를 보내는 flow를 이용하여 marginal distribution을 구할 수 있다. 구체적인 예시는 PRML책 409page를 보면 된다. 내용이 꽤 길어서 일부 생략하기로 한다.
이번에는 high probability를 갖는 latent variable을 구하고 싶은 경우를 생각해보자.
xmax=argxmaxp(x) p(xmax)=xmaxp(x)
이를 구하기 위해 먼저 chain 예시를 한 번 살펴보자.
- 아래 식을 전개하는데 이용한 식
- maxxp(x)=maxx1…maxxMp(x)
- max(ab,ac)=amax(b,c),;where;a>0
xmaxp(x)=Z1x1max…xNmax[ψ1,2(x1,x2)…ψN−1,N(xN−1,xN)]=Z1x1max[ψ1,2(x1,x2)[…xNmaxψN−1,N(xN−1,xN)]]
이는 이전의 sum-product algorithm처럼 message를 전달하는 것으로 이해할 수 있다. 이제는 tree-structured factor graph를 통해 일반적인 경우로 알아보자.
sum-product algorithm과 거의 비슷하다. sum이 max로 바뀐 경우라고 이해할 수 있다. 거기에 추가로 product를 log를 씌워서 sum으로 implement한다.
μf→x(x)=x1,…,xMmax[logf(x,x1,…,xM)+m∈ne(fs)−x∑μxm→f(xm)]mux→f(x)=l∈ne(x)−f∑μfl→x(x)