research blog for data science

Junction Tree Variational Autoencoder for Molecular Graph Generation 논문 리뷰

|

Junction Tree Variational Auto-encoder(JT-VAE)는 기존 SMILES string 기반 생성 모델들이 SMILES string을 사용하는 것에 문제를 제기하여 캐릭터가 아니라 molecular graph가 직접 입력으로 들어가는 모델입니다. 또한, 유효한 화합물 구조를 생성하기 위해 Junction Tree Algorithm에서 아이디어를 착안하여 모델을 제시하였습니다.

Problem

SMILES string을 입력으로 하는 것은 크리티컬한 2가지 문제가 발생합니다. 먼저, SMILES 표현은 화합물간 유사도를 담아내지 못합니다.

similarity

그림 1

위 그림 1.을 보면, 두 화합물의 구조는 유사하지만 SMILES으로 나타냈을 땐 전혀 다른 표현이 됩니다. 따라서, SMILES 표현의 한계로 인해 VAE와 같은 생성모델들이 화합물 임베딩 공간을 올바르게 형성하지 못합니다. 두번째로, 그래프 형태의 화합물 표현이 SMILES 화합물 표현보다 분자의 화학적 특성을 더 잘 담아냅니다. 이러한 이유로, 본 논문은 그래프적 표현(molecular graph)을 직접적으로 사용하는 것이 유효한 구조를 만들어 내는 것을 향상시킬 것이라 가정을 하고 있습니다.

Junction Tree Variational Auto-Encoder

Molecular graph를 만든다는 건 일반적으로 원자를 하나씩 순차적으로 생성하는 것으로 생각할 수 있습니다(Li et al., 2018). 그러나, 본 논문에서는 이러한 접근법은 유효하지 않은 구조(chemically invalid)를 만들어 낼 가능성이 높다고 합니다. 원자를 하나씩 붙여나가면서 생성하면 중간 단계 구조들은 invalid하며, 완전한 구조가 나올때 까지 딜레이가 길기 때문입니다.

그림 2

따라서, 본 논문은 원자 단위로 molecular graph를 만들어 나가는 것이 아니라 유효한 분자 단위들의 집합을 미리 정해놓고 이 단위들을 붙여 나가면서 화합물을 구축합니다. 마치 자연어처리에서 문장 생성 문제를 풀 때, 사전을 미리 구축해 놓고 그 속에 존재하는 단어들로 문장을 구축해 나가는 것과 같이 생각하면 될 것 같습니다.

junction tree 라는 이름이 붙게 된 이유는 다음과 같습니다. 유효한 분자 단위는 molecular graph내의 sub-graph로 생각할 수 있고, 이 sub-graph는 이 그래프 자체로도 유효한 화학 분자 구성 요소를 이룹니다. 즉, 마치 junction tree의 node가 complete graph인 clique과 유사합니다. 이 부분에서 junction tree 아이디어를 착안한 것입니다

하나의 분자에 대해서 생성되는 방식은 다음과 같습니다. 어떤 유한한 개수의 유효한 화합물 단위(valid components)들로 구성된 집합에서 해당 분자를 구성할 것 같은 요소들을 선택한 후, 그 요소들을 가지고 제일 그럴듯한 구조가 나오도록 조합하는 것입니다. 이런 식의 접근의 장점은 Li et al.(2018)와 다르게 valid한 화합물을 생성하는 것을 보장할 수 있습니다. 또한 구성요소 간의 상호작용 관계도 고려되기 때문에 더 실용적인 방법입니다.

유효한 화합물 단위는 마치 building block과 같은 역할입니다.

그림 3

그림 3.은 Junction Tree VAE(JT-VAE) 모식도입니다. 첫번째 단계로 한 분자가 입력으로 들어오면, 화합물 단위 사전을 이용하여 Tree Decomposition을 수행합니다. 수행 결과, Junction Tree $\tau$ 가 나옵니다.

하나의 분자가 주어졌을 때, 한 분자는 2가지 종류의 표현을 가지게 됩니다. - Molecular graph 와 Junction Tree 표현

2가지 표현을 가지고 있는 것과 같이 Graph Encoder/Decoder와 Tree Encoder/Decoder로 구성됩니다. Molecular Graph는 Graph Encoder에 의해 $z_{G}$ 로 인코딩됩니다. 마찬가지로, Tree Encoder에 의해 Junction Tree $\tau$ 는 $z_{\tau}$ 으로 인코딩됩니다. 그런 후 Tree Decoder와 화합물 사전을 이용하여 제일 가능성이 높은 화합물 단위를 조합하여 Junction Tree $\hat{\tau}$ 를 생성합니다. 이 때, Junction Tree $\hat{\tau}$ 에서 화합물 단위간 연결(edge)는 화합물 간 결합 방향 및 결합 종류(단일결합, 이중결합 등등)에 관한 정보를 포함하고 있지 않고, 상대적인 배열에 관한 정보만을 담고 있습니다. 그 다음, graph decoder에 의해, Junction Tree $\hat{\tau}$ 와 $z_{G}$ 는 Molecular graph 표현으로 나타내지게 됩니다. 이 과정에서 화합물 단위 사이의 결합이 정해지게 됩니다. 화합물 단위 간 결합될 수 있는 후보들을 나열 한 후, 각 후보 군에 대한 결합점수를 매기고, 점수가 가장 높은 결합이 화합물 단위 사이의 결합으로 결정됩니다.

Junction Tree Algorithm

junction tree algorithm은 probabilistic graphical model에서 inference problem를 효율적으로 풀기 위한 알고리즘입니다.

  • inference 문제 2가지 종류
  1. Marginal inference : what is the probability of a given variable in our model after we sum everything else out?
  2. $$p(y=1) = \sum_{x_1}\sum_{x_2}\sum_{x_3}\dots\sum_{x_n}p(y=1,x_1,x_2,x_3,\dots,x_n)$$
  3. Maximum a posteriori(MAP) inference : what is the most likely assignment to the variables in the model(possibly conditioned on evidence)?
  4. $$\max_{x1,\dots,x_n} p(y=1,x_1,x_2,x_3,\dots,x_n)$$

변수 간 dependence가 표현된 directed acyclic graph $G$ 를 undirected graph로 변환한 뒤, 정해진 변수 간 order에 의해 변수들의 cluster를 하나의 single node로 구성하고, 특정 규칙 아래(변수 간 dependence가 잘 반영될 수 있도록), cluster간 edge를 연결하여 Junction Tree $\tau_{G}$=($\nu$, $\varepsilon$), $\nu$ : nodes, $\varepsilon$ : edges 를 구축합니다.

변수간 order에 의해 cluster를 구성해 나가면서 tree를 구축하는 건 variable elimination과 관련이 있습니다.

구축된 tree는 cycle-free이고, 변수들의 cluster는 graph 상에서 complete graph(clique)를 이루고 있어야 합니다. 또한 서로 이웃된 cluster 간에 공통되는 변수들이 있을 때, cluster 간 연결된 path 위에 해당 변수들로 구성된 cluster가 있어야 합니다. Junction Tree는 아래 그림에서와 같이 세가지 특성을 가져야 합니다.

그림. 4. Junction Tree Properties

belief propagation as message passing

Junction tree $\tau_{G}$ 를 가지고, inference problem을 푸는 방법 중 하나가 message-passing algorithm을 이용한 belief propagation입니다. 아래 그림과 같이 variable elimination을 통해 marginal inference 문제를 해결해 나갈 수 있습니다. 이 때, 정해진 변수 순서에 따라 summing out 되면서 변수가 순차적으로 제거됩니다(marginalized out).

그림 5
아래 그림과 같이 variable elimination이 되는 과정이 마치 tree 상에서, 한 node가 marginalization이 되면서 연결되어 있는 다른 노드로 message를 전달하는 과정으로 볼 수 있습니다. 아래 그림에서, $x_1$ 을 summing out에서 제거하기 위해선 우선적으로 $x_2$ 가 summing out되어 $x_1$ 으로 message인 $m_{21}(x_1)$ 이 전달되어야 합니다. 마찬가지로, 우리가 구하고 싶은 marginalized distribution인 $p(x_3)$ 를 구하기 위해선 $x_1, x_4, x_5$ 에서 오는 message가 모두 올 때까지 기다렸다가 계산을 할 수 있습니다.

그림 6

$i$ node 에서 $j$ node로 가는 message $m_{i \rightarrow j}$ 는 아래와 같이 정의할 수 있습니다. $$m_{i \rightarrow j} = \sum_{x_i}\phi(x_i)\phi(x_i,x_j)\prod_{l \in N(I) \setminus j}m_{l \rightarrow j}(x_i)$$

수식. belief propagation

$i$ node 에서 $j$ node로 가는 message $m_{i \rightarrow j}$ 는 **j node를 제외하고 i로 가는 모든 node의 메세지를 기다렸다가** i node와 연관된 distribution function을 다 계산한 후 summing out하는 것입니다. 특정 node와 연결된 message가 모두 올 때까지 기다렸다가 계산하는 방식을 belief propagation이라 합니다. Loopy belief propagation은 이를 기다리지 않고 계산하고, 모든 노드에서 수렴할 때까지 반복하여 inference 문제를 해결하는 방식입니다. $$m_{i \rightarrow j}^{t+1} = \sum_{x_i}\phi(x_i)\phi(x_i,x_j)\prod_{l \in N(I) \setminus j}m_{l \rightarrow j}^{t}(x_i)$$

수식. Loopy belief propagation

message passing network

[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212)은 기존에 존재하는 그래프 모델들을 message passing algorithm을 학습하는 모델로 다시 해석하였습니다. 아래와 같이 세가지 함수를 정의하여 그래프 모델들을 분자의 화학적 특성을 예측하는 등 Quantum Chemistry에 적용하는 연구를 하였습니다. - A Message Passing function : $m_{v}^{t+1}=\sum_{w \in N(v)}M_t(h_v^t,h_u^t,e_{uv})$ - A Node Update function : $h_{v}^{l+1}=U_t(h_v^t,m_v^{t+1})$ - A Readout function(ex. classification) : $\hat y = R({h_v^T|v \in G})$ 즉, 한 원자의 특성을 결정짓는 건 원자와 연결된 다른 원자로 부터 오는 정보와 자기 자신에 의해 결정됨을 의미합니다. 본 논문에서도 위와 같은 아이디어를 사용하였습니다.

Tree Decomposition of Molecules

Molecule Junction Tree 는 junction tree $\tau_{G} = (\nu, \varepsilon)$ 에서 $\chi$ 가 추가된 $\tau_{G} = (\nu, \varepsilon, \chi)$ 입니다. $\chi$ 는 junction tree의 node 후보가 될 수 있는 화합물 구성 단위들의 집합 사전을 나타냅니다. 화합물 단위 사전은 ring결합으로 이뤄진 화합물(ex. aromatic compound), 일반 결합(?)(a single edges ex. single bond, double bond, triple bond..)으로만 구성됩니다. 여기서 사용된 집합 사전의 크기 |$\chi$|=$780$ 입니다. 여기서, 집합 사전의 크기가 한정적이기 때문에, 다양한 종류의 분자를 표현하는 것이 가능한 것에 대해 의문이 들 수 있습니다. 본 논문에서는 training set에 존재하는 화합물들 기반으로 분자 집합 사전을 구축했으며, test set에 있는 분자들을 대부분 커버했기 때문에 크게 문제 삼지 않고 넘어갔습니다.

tree-decomp

그림 7. Tree Decomposition of Molecules

위 그림은 tree decomposition을 나타낸 그림입니다. 집합 사전 $\chi$ 를 가지고, cycle구조와 edge 구조로 분해합니다. cycle은 고리형 화합물이고, edge는 단순 결합으로 이뤄진 화합물입니다. 위 그림에서 색칠된 노드가 화합물 단위로 분해된 것을 가르킵니다.

Graph Encoder

먼저, molecular graph 표현을 graph message passing network을 통해 인코딩합니다. Graph에서 원자에 해당하는 vertex는 특징 벡터 $\mathrm x_{v}$(원자 종류, 공유가 등과 같은 특성을 나타내는 벡터), 원자 간 결합을 나타내는 edge $(u,v) \in E$ 는 특징 벡터 $\mathrm x_{uv}$(결합 종류) 로 표현합니다. 또한 message passing algorithm에 의해 두 원자 간 주고받는 message hidden vector $\nu_{uv}$(u에서 v로 가는 message) 와 $\nu_{vu}$(v에서 u로 가는 message)로 표현합니다. Graph encoder에서 message가 전달되는 방식은 loopy belief propagation을 따릅니다. 아래 식에 의하면, 한 원자의 hidden vector는 결국 자신의 특징과 더불어 자신과 결합하고 있는 원자들로부터 오는 정보로 표현되는 것입니다. $$\nu_{uv}^{(t)} = \tau(W_1^g\mathrm x_u + W_2^g\mathrm x_{uv} + W_3^g\sum_{w \in N(u) \setminus v} \nu_{wu}^{(t-1)})$$ $$\mathbf {h}_u=\tau(\mathrm U_1^g\mathrm x_u + \sum_{v \in N(u)}\mathrm U_2^g\nu_{vu}^{(T)})$$ 한 분자에 대하 최종적인 표현은 $\mathbf h_G = \sum_i \mathbf h_i/|V|$ 로, 분자에 포함된 모든 원자들의 hidden vector들을 평균낸 것입니다. 그런 다음, VAE처럼 평균 $\mu_G$ 와 분산 $\sigma_G$ 를 계산하는 레이어가 각각 연결되고, 잠재 벡터 $z_G$ 는 $\mathcal N(\mu_G, \sigma_G)$ 에서 샘플링됩니다.

Tree Encoder

Tree Encoder도 message passing network방식으로 분자의 junction tree 표현을 인코딩합니다. 각 클러스터 $C_i$ 는 해당 라벨에 대한 정보를 담고 있는 원핫 벡터 $\mathrm x_i$ 로 표현되고, 클러스터 $(C_i,C_j)$ 간 주고받는 메세지 정보는 $\mathbf m_{ij}, \mathbf m_{ji}$ 로 표현합니다. Tree에서,임의로 root node 를 정하고 메세지는 GRU unit을 이용해 $\mathbf m_{ij} = \mathrm GRU(\mathbf x_i, {\mathbf m_{k \in N(i) \setminus j}})$ 와 같이 업데이트됩니다. $$\mathbf s_{ij} = \sum_{k \in N(I) \setminus j}\mathbf m_{kj}$$ $$\mathbf z_{ij} = \sigma(\mathbf W^z \mathbf x_i +\mathbf U^z\mathbf s_{ij}+\mathbf b^z)$$ $$\mathbf r_{kj} = \sigma(\mathbf W^r \mathbf x_i +\mathbf U^r \mathbf m_{ki}+\mathbf b^r)$$ $$\mathbf {\tilde m_{ij}} = tanh(\mathbf W \mathbf x_i+\mathbf U \sum_{k \in N(I) \setminus j}\mathbf r_{ki} \odot \mathbf m_{ki})$$ $$\mathbf m_{ij} = (1-\mathbf z_{ij}) \odot \mathbf s_{ij}+\mathbf z_{ij} \odot \mathbf {\tilde m_{ij}}$$ Graph Encoder와 다르게 Tree Encoder는 loopy belief propagation이 아니라, 특정 노드에서 다른 노드로 메세지를 전달하기 전에, 특정 노드와 연결된 노드들의 메세지가 다 올 때까지 기다리다가 전달하는 belief propagation 방식을 따릅니다. 특정 노드에서 message passing이 완료된 후, 해당 노드의 hidden feature $\mathrm h_i =\tau(\mathbf W^o \mathbf x_i+\sum_{k \in N(i)}\mathbf U^o \mathbf m_{kj})$ 로 계산됩니다. 즉 노드 i로 오는 모든 메세지와 노드 i의 label feature $x_i$ 를 이용하여 hidden feature vector i 를 계산합니다. tree의 최종적인 표현 $\mathbf h_{\mathcal T}=\mathbf h_{root}$ 입니다. Graph encoder에서 평균값으로 계산한 것과는 달리, root node의 hidden feature를 최종 표현으로 둡니다. 그 이유는 tree decoder에서 tree를 생성할 때, 어느 노드에서 시작할 지에 대한 정보가 있어야 하기 때문입니다. 다음으로 graph encoder와 마찬가지로 $\mu_\mathcal T$ 와 $\sigma_\mathcal T$ 를 출력하는 레이어가 각각 연결되고, $\mathcal N(\mu_\mathcal T, \sigma_\mathcal T)$ 에서 latent vector $z_\mathcal T$ 를 샘플링합니다.

Tree Decoder

Tree Decoder를 이용해 $z_\mathcal T$ 를 통해 junction tree $\hat {\mathcal T}$ 를 생성합니다. root node부터 시작해서 top-down 방식과 깊이 우선 탐색(depth-first order) 순서로 나머지 node들을 생성해 나갑니다. 깊이 우선 탐색이란, 루트 노드에서 시작해서 다음 분기로 넘어가기 전에 해당 분기를 완벽하게 탐색하는 방법입니다. 노드를 방문할 때마다 두가지 일을 수행합니다.
  1. topological prediction : 해당 노드가 자식 노드를 가지고 있는지에 대한 여부
  2. label prediction : 해당 클러스터의 라벨 예측(화합물 단위 집합 사전에 있는 라벨 예측)

만약 자식 노드가 더 이상 없다면 해당 노드를 탐색하기 직전 노드로 거슬러 올라갑니다. 각 노드를 방문하여 두가지 일을 수행하기 위해선 연결된 다른 노드로부터 정보를 받아야 합니다. Tree decoder에서 전달되는 정보는 message vector $\mathbf h_{ij}$ 를 이용합니다. tree의 node별로 하나씩 순서대로 생성해 나가면서 방문하는 edge들마다 번호를 매기면 최대 번호는 엣지갯수 x 2 가 됩니다. Molecular junction tree $\mathcal T=(\mathcal V, \mathcal E)$ 에 대해, 처음 시작해서 t step이 될 때까지 방문한 엣지들을 $\tilde {\mathcal E}$ = { $(i_1, j_1), \dots,(i_t,j_t)$ } 라 하고, t step일 때 방문한 노드를 $i_t$ 라 한다면, 노드 i에서 노드j로 가는 메세지 $\mathbf h_{i_t, j_t}$ 는 i 노드로 향하는 메세지와 t step에서의 노드 특징 벡터(여기서는 라벨 벡터)에 의해 GRU unit을 이용해 업데이트 됩니다. $$\mathbf h_{i_t, j_t} = \mathrm {GRU}(\mathbf x_{i_t}, \{\mathbf h_{k,i_t}\}_{(k,i_t) \in \mathcal {\tilde E}_t, k \neq j_t})$$

Topological Prediction & Label Prediction

노드 i에 방문했을 때, 자식 노드 j 존재 여부는 아래와 같이 확률을 계산하여 판단합니다. $$p_t = \sigma (\mathbf u^d \bullet \tau(\mathbf W_1^d\mathbf x_{i_t}+\mathbf W_2^d\mathbf z_{\tau}+\mathbf W_3^d\sum_{(k,i_t) \in \mathcal {\tilde E_t}}\mathbf h_{k,i_t}))$$ 자식 노드 j가 있다면, 자식 노드 j의 라벨 예측은 아래와 같습니다. $\mathbf q_j$ 는 화합물 단위 집합 사전에 대한 분포를 나타냅니다. 라벨 예측 후, 자식 노드 j의 특징벡터 $\mathbf x_j$ 는 분포 $\mathbf q_j$ 에서 샘플링됩니다. 샘플링 시, 부모노드와 연결되는 자식노드가 유효하지 않은 화합물 단위가 오면 안되기 때문에 미리 invalid한 화합물 단위들은 분포에서 masking을 하고 샘플링을 진행합니다. $$\mathbf q_j =\mathrm {softmax(\mathbf U_\tau^l(\mathbf W_1^l\mathbf z_{\tau}+ \mathbf W_2^l\mathbf h_{ij}))}$$ Tree decoder가 작동하는 알고리즘과 자세한 설명은 아래와 같습니다.

그림 8

Tree decoder의 목표는 우도 $p(\mathcal T|\mathbf z_{\mathcal T})$ 를 최대화하는 것입니다. 따라서 아래와 같이 크로스 엔트로피 손실함수를 최소화하는 방향으로 학습이 이뤄집니다. $$L_c(\mathcal T) = \sum_tL^d(p_t, \hat p_t) + \sum_jL^l(\mathbf q_j, \mathbf {\hat q_j})$$ 또한 teacher forcing을 이용하여 학습합니다. Teacher forcing이란 매 스텝마다 prediction한 후, 추 후 해당 스텝의 값을 이용할 때 prediction 값이 아니라 ground truth을 이용하는 것입니다.

Graph Decoder

JT-VAE는 마지막으로 graph decoder를 통해 molecular graph를 생성합니다. 그러나 하나의 molecular junction tree는 하나의 molecular graph에 대응하는 것이 아니라 화합물 단위인 subgraph를 어떻게 조합하느냐에 따라 여러 개의 molecular graph를 나타낼 수 있습니다. junction tree의 edge는 단순히 subgraph들의 상대적인 배열만을 나타낸다고 했습니다. 이렇기 때문에 Graph Decoder의 역할은 올바른 molecular graph를 만들기 위해 subgraph를 잘 조합하는 것입니다. $$\hat G = argmax_{G' \in \mathcal g(\mathcal {\hat T})}f^a(G')$$ Tree Decoder에서 root node에서 하나씩 node를 붙여나가듯이, 마찬가지로 subgraph를 하나씩 붙여나가는 것입니다. 그러나 이 때, subgraph를 붙여나갈 때 여러가지 경우의 수가 나오기 때문에 scoring function $f^a$ 을 이용해서 각 경우의 수에 대해 점수를 매깁니다. 가장 높은 점수가 나온 subgraph 간 조합을 두 subgraph간 조합으로 보고 다음 subgraph를 붙여나갑니다. subgraph를 붙여나가는 순서는 tree decoder에서 디코딩된 노드 순을 따릅니다. 그림 8의 예를 보면, 생성되는 tree node 순서에 따라, subgraph를 1->2->3->4->5 순으로 붙여나가는 것입니다.

$G_i$ 를 특정 sub graph cluster인 $C_i$ 와 그것의 neighbor clusters $C_j, j \in N_{\mathcal {\hat T}}(i)$ 을 조합해서 나온 그래프라 한다면, $G_i$ 에 대한 점수는 $f^a_i(G_i) =\mathbf h_{G_i}\bullet\mathbf z_G$ 입니다. $\mathbf h_{G_i}$ 는 subgraph $G_i$ 에 대한 hidden vector representation 입니다. 즉, Graph decoder의 역할은 조합해 가면서 나오는 각 subgraph에 대하여 hidden vector representation을 message passing algorithm을 통해 구하는 것입니다. Junction tree message passing network는 아래와 같습니다. $$\mu_{uv}^t = \tau(\mathbf W_1^a\mathbf x_u + \mathbf W_2^a\mathbf x_{uv} + \mathbf W^a_3 \tilde \mu_{uv}^{(t-1)})$$ $$\tilde \mu_{uv}^{(t-1)} = \begin{cases} \sum_{w \in N(u) \setminus v}\mu_{wu}^{(t-1)} & \quad \alpha_u = \alpha_v \\ \mathbf {\hat m}_{\alpha_u,\alpha_v} + \sum_{w \in N(u) \setminus v}\mu_{wu}^{(t-1)} & \quad \alpha_u \neq \alpha_v \end{cases}$$ 위 수식을 보면, message 계산 과정이 graph encoder와 비슷합니다. 하나 차이점은 u 원자와 v 원자가 다른 cluster라면 즉, 다른 subgraph라면 전달되는 메세지에 다른 subgraph에서 온 것을 추가적으로 알려준다는 점입니다(provides a tree dependent positional context for bond (u, v)). 이 때 메세지 $\mathbf {\hat m}_{\alpha_u,\alpha_v}$ 는 sub-graph $G_i$ 를 graph encoder를 통과시켜 계산된 메세지입니다. Graph decoder 학습은 $\mathcal {L_g}(G) = \sum_i [f^a(G_i) - log \sum_{G' \in \mathcal g_i} exp(f^a(G'_i))]$ 을 최대화하는 과정입니다. Tree decoder와 마찬가지로 teacher forcing을 이용해 학습합니다.
Complexity
Tree decomposition에 의해, 두 클러스터 간에 공유되는 원자 갯수가 최대 4개이며 또한 tree decoder 과정에서 invalid한 화합물 단위가 나오지 않도록 masking을 통해 sampling 하기 때문에 복잡도는 그리 높지 않습니다. 따라서 JT-VAE의 계산복잡도는 molecular graph의 sub-graph 수의 비례합니다.

Comments