research blog for data science

Reinforcement Learning 소개[1]

|

이번 포스팅은 강화학습이 기존에 알려진 여러 방법론들과의 비교를 통한 강화학습 특성과 구성요소를 다룹니다. CS234 1강, Deep Mind의 David Silver 강화학습 강의 1강, Richard S. Sutton 교재 Reinforcement Learning: An Introduction의 Chapter 1 기반으로 작성하였습니다.


introRL

그림 1.

아래 그림과 같이, Computer Science, Engineering, Mathematics 등 다양한 분야에서 여러 문제들을 풀기 위한 방법론들이 있습니다. 예를 들어, Pyschology 분야의 Classical/Operant Conditioning은 동물들의 의사결정에 대해 연구하고, Enginnering 분야의 Optimal Control와 Mathematics 분야의 Operation Research는 자연 현상을 일련의 시퀀스로 파악하여 공학적 또는 수학적 관점에서 ‘어떻게 하면 최상의 결과를 얻을까?’ 를 연구합니다. 즉, 이러한 연구들의 공통점은 각 분야에서 정의한 문제 해결을 위해 과학적 의사결정(scientific decision making)을 연구한다는 것입니다.

이와 마찬가지로, 강화학습도 “좋은 의사결정을 내리기 위한 방법”에 관한 연구입니다. 특히, “순차적인 의사결정이 필요한 문제”를 풀기 위한 방법론입니다(Learn to make good sequences of decisions). 여기서, “good decisions”은 결국 최적의 해결책(optimal soltuion)을 찾는 것을 의미하고, “learn”은 학습하는 대상이 처한 상황이 어떤지 모른 채, 직접 부딪혀 나가면서 경험을 통해 배워나가는 것을 의미합니다. 이는 마치, 사람이 학습해 나가는 방법과도 유사하죠.

Sutton 교재에서, 강화학습이 사람의 학습 방법과 유사하다고 기술되어 있습니다. 유아기 때, 걷기까지 걷는 방법을 알려주는 선생님이 존재하지 않고, 아기가 스스로 여러번 시도와 실패 끝에 걷는 방법을 터득합니다. 강화학습도 이러한 측면에서의 특성을 가지고 있습니다.

Characteristics of Reinforcement Learning

강화학습의 특징을 다른 방법론과 비교를 통해 알아봅시다. 강화학습의 특징을 우선 정리하고, 그 후 비교를 통해 구체적으로 알아볼 것입니다. 강화학습은 아래와 같이 4가지 특징을 가지고 있습니다.

  • Optimization
  • Delayed consequences
  • Exploration
  • Generalization

1. Optimization
good decision이란 최적의 해결책(optimal solution)에 해당된다고 하였습니다. 즉, Optimization은 강화학습의 목적에 해당되며, 그 목적은 좋은 결정을 내리기 위한 최적의 방법을 찾는 것입니다.

Goal is to find an optimal way to make decisions

2. Delayed Consequences
순차적인 의사결정 문제에서, 현재 내린 결정은 후에 일어날 상황에 영향을 줄 수 있습니다. 예를 들어, 돈을 저축하는 건 현재 시점에선 마이너스 행위일 수도 있지만, 만기 이후를 생각하면 플러스 행위입니다. 즉, 현재 내린 결정에 대한 영향력을 확실히 알 수가 없고(delayed consequences), 이로 인해 결정의 좋고 나쁨을 평가하는 것이 어렵습니다.

3. Exploration
위에서, 강화'학습'은 에이전트(학습하는 사람, 기계 등을 지칭)가 학습하는 상황/대상에 대한 어떠한 정보가 없기 때문에 스스로 배워나가는 것이라 하였습니다. 따라서, 에이전트는 무수히 많은 의사결정을 통해 탐험을 해야합니다. 자전거를 타는 기술을 익히기 위해 수많은 실패를 하는 것처럼 말입니다. 그러나, 이 '탐험'도 '잘'해야 합니다. 어떠한 탐험을 하느냐에 따라 경험하는 것이 다르기 때문입니다. 이렇게 얻어진 경험이 좋은 경험일 수도 나쁜 경험일 수도 있습니다.

4. Generalization
여러 머신러닝 방법론과 마찬가지로, 강화학습은 특정 문제만 풀수 있는 에이전트가 아니라 일반화된 문제를 풀 수 있는 에이전트를 학습하고 싶습니다. 바둑을 예로 들어봅시다. 강화학습을 통해 바둑게임 에이전트를 만들고자 할 때, 대전을 하는 상대방이 어떠한 전략을 가지고 있던 간에 항상 이길 수 있는 에이전트를 만드는 것이 목표지 특정 전략에만 강한 에이전트를 만들고 싶은 것이 아닙니다.

위와 같은 이유로, rule-based 방식으로 순차적 의사결정문제를 풀기가 어렵습니다. rule-based 기반 해결책은 generalization 특성을 갖지 못하기 때문입니다.

pre-programmed policy is hard to get generalization on the problem we want to tackle.

‌ 강화학습은 위와 같이 4가지 특성을 가지고 있습니다. 이제 여러 방법론(Planning, supervised learning, unsupervised learning, imitation learning)과 비교를 통해 위 4가지 특성에 대해 강화학습이 다른 방법론과 어떻게 다른지 알아 봅시다.

AI Planning vs Reinforcement Learning

planning

그림 2. planning

Planning이란 에이전트가 학습하는 환경에 대한 정보를 완벽히 알고 있는 경우입니다. 그림 2.는 유명한 아타리 게임입니다. 만약에 아타리 게임의 에이전트가 게임 콘솔 안의 모든 게임 알고리즘과 하드웨어 작동 방식등 완벽하게 알고 있다고 해봅시다.(거의 불가능한 상황이긴 합니다.)

Agent just computes good sequence of decisions but given model of how decisions impact world

이 말은 에이전트가 현재 게임 상황에서 왼쪽/오른쪽 움직임에 대해 나올 결과를 완벽히 알 수 있다는 뜻입니다. 즉, 에이전트는 더이상 ‘학습’이 아니라 어떻게 의사 결정을 내릴지 ‘계획’하면 되는 것이지요. 따라서, planning은 4가지 특성 중, exploration은 해당되지 않습니다.

  • Optimization
  • Delayed consequences
  • Exploration
  • Generalization

Supervised/Unsupervised Learning vs Reinforcement Learning

지도학습은 라벨이 있는 데이터 셋 ${{(x_1,y_1), \dots ,(x_i,y_i))}}$ 을 학습 한 후, 입력 $x_i$ 가 들어오면 입력에 대한 라벨 $\hat y_i$ 를 예측하는 문제입니다. 반면에 비지도 학습은 라벨이 없는 데이터 셋에 대하여 ${{x_1, \dots ,x_i}}$ 에 대하여 학습을 통해 데이터 셋의 구조를 파악하는 것입니다. 강화학습과의 차이점은 데이터 셋의 유무에 있습니다. 강화학습은 어떻게 보면, 탐험을 통해 데이터 셋을 스스로 구축해 나가는 것이라 볼 수 있지만, 지도/비지도 학습은 이 경험에 해당되는 데이터 셋이 주어진 것이기 때문에, 탐험을 할 필요가 없습니다. 또다른 차이점은 의사결정에 해당되는 라벨 예측 행위가 추후 또다른 예측 행위에 영향을 주지 않습니다. 따라서, exploration과 delayed consequences가 없습니다.

  • Optimization
  • Delayed consequences
  • Exploration
  • Generalization

Imitation Learning vs. Reinforcement Learning

둘의 차이점을 비교하기 전에, imitation learning이 뭔지 간략하게 알아봅시다. 좋은 의사결정을 내리기 위한 에이전트를 강화학습을 통해 만들기 위해선 에이전트가 탐험 시 내린 의사결정에 대한 좋고 나쁨을 알려줘야 합니다. 우리는 이를 ‘보상’이라 합니다. 그러나 현실 문제에서 정확한 보상함수를 정의내리기가 어렵습니다. 따라서 이를 해결하기 위해 나온 방법이 imitation learning입니다. 에이전트가 직접 탐험하는 것이 아니라 모방하고 싶은 에이전트의 행동을 지도학습 방식으로 해결하는 것입니다.

imitation learning

그림 3. Imitation Learning

따라서, imitation learning은 모방하고 싶은 에이전트의 경험을 데이터 셋으로서 활용하기 때문에 exploration 요소가 없습니다.

  • Optimization
  • Delayed consequences
  • Exploration
  • Generalization

The Reinforcement Learning Problem

이번 섹션에서는 강화학습 문제를 정의하기 위해 필요한 요소들을 알아봅시다.

Rewards & Sequential Decision Making

좋은 행동을 하는 에이전트를 강화학습을 통해 만들기 위해선 에이전트가 탐험할 때 결정한 행동에 대한 좋고 나쁨을 알려줘야 합니다. 우리는 이를 ‘보상’이라 합니다.

강화학습에서 '의사 결정'을 '행동'이라 부릅니다.

보상 $R_t$ 는 $t$ 스텝에서 에이전트가 의사결정을 잘 내리고 있는지에 대해 환경이 주는 즉각적인 피드백 지표(imediate reward)입니다. 즉, 에이전트의 목표는 매 스텝마다 받는 보상을 누적했을 때, 이 누적값이 최대화가 되도록 의사결정을 하는 것입니다. 이 말은 에이전트는 당장 받는 보상이 아니라 앞으로 받을 보상을 고려해서 행동한다는 것입니다. 이러한 아이디어는 가설로 구축할 수 있습니다.

reward hypothesis
That all of what we mean by goals and purposes can be well thought of as the maximization of the expected value of the cumulative sum of a received scalar signal.

예를 들어, 게임같은 경우 이기면 (+)보상을, 질 경우 (-)보상으로 보상을 정의내릴 수 있습니다. 또한 발전소를 컨트롤 하는 경우, 전력을 생산하는 경우엔 (+)보상을 줄 수 있지만 만약에 안전 임계치를 초과한 경우 (-)보상으로 정의내릴 수 있습니다.

sequential-decision-making

그림 4. Sequential Decision Making

따라서, 강화학습의 목표는 순차적인 의사결정을 통해 누적 보상의 기댓값을 최대화하는 것입니다. 현재 행동은 앞으로 할 행동들에 영향을 줄 수 있으며 더 많은 보상을 나중에 받기 위해 현재 당장 받을 즉각적인 보상은 포기할 수도 있습니다.

그러면 보상은 누가 주는 걸까요 ? 바로 에이전트가 놓여있는 환경입니다. 에이전트 이외 그 밖의 요소가 환경이 될 수 있습니다. 예를 들어, 주식시장을 살펴봅시다. 주식의 매매 여부를 결정하는 주체는 에이전트이고, 팔기, 사기, 그대로 두기는 행동(의사결정)입니다. 만약에 에이전트가 주식을 파는 행동을 하였다면, 이 행동의 좋고 나쁨은 어떻게 결정될까요? 우리는 흔히, 내가 판 주식이 올랐다면 이 행동은 나쁜거라고 생각할 수 있습니다(왜냐면, 더 있다가 팔면 좋았을 테니깐요.). 그럼 내가 판 주식이 오르게 하는 건 어떤걸까요? 바로 주체 이외의 여러 요소들에 의해 결정됩니다. 그 주식과 관련된 여러 기업일 수도 있고, 정치도 해당될 수 있고, 여러 가지 요소가 주식에 영향을 줍니다. 바로 이러한 부분이 강화학습에서의 ‘환경’입니다.

예상할 수 있듯이, 환경을 정의내리기엔 매우 어렵습니다.

따라서, 강화학습의 목표를 다시 정의내리자면, 주체는 “환경과의 상호작용”을 통해서 누적 보상을 최대화하는 것입니다. 그림 4와 그림5는 주체가 환경과의 상호작용하는 일련의 과정을 보여줍니다.

agent-env-interaction

그림 5. 주체와 환경간의 상호작용

매 스텝에서, 에이전트가 누적보상이 최대가 되도록 행동 $A_t$ (action) 을 결정하면, 환경은 선택한 행동에 대한 결과 $O_{t+1}$ (observation) 와 행동에 대한 실제 보상 $R_{t+1}$ (imediate reward) 을 알려줍니다. 주식시장을 다시 예로 들어보면, 주체가 주식을 파는 것이 행동 결정이고, 추후에 그 기업의 주식이 오르는 것이 행동에 대한 결과이며, 이에 대해 돈을 잃는 것이 실제 보상입니다.

History and State

연속적인 의사 결정(sequential decision making)이기 때문에, 매스텝마다 행동 $A$, 관찰 $O$, 보상 $R$이 발생합니다. 따라서 발생한 모든 행동, 관찰, 보상에 대한 시퀀스를 히스토리(history)라고 합니다.

[H_t = O_1, R_1, A_1, \dots, A_{t-1}, O_t, R_t]

The history is the sequence of observations, actions, rewards. In other words, It is all observable variables up to time t

따라서, 에이전트는 히스토리 기반으로 다음에 취할 행동을 선택합니다. 왜냐하면, 히스토리는 이전에 발생한 모든 일들을 다 기록하기 때문에 꽤나 행동 선택에 꽤나 괜찮은 근거가 될 수 있기 때문입니다.

그러나, 행동을 선택할 때마다, 매번 이전 과거 정보를 파악하는 건 힘든 일입니다. 따라서 에이전트는 다음 행동을 선택하는데 상태(state)정보를 이용합니다. 상태정보가 행동 선택의 근거가 되기 위해선 상태는 과거 히스토리 정보를 담고 있어야 합니다. 따라서, 수학적으로 표현하면 상태는 히스토리의 함수입니다.

[S_t = f(H_t)]

State is information assumed to determine what happens next

상태에는 크게 Environment State(World State), Agent State, Information State가 있습니다.

Environment State(World State)

environment state

그림 6. Environment State

환경 상태 $S^e_t$ 는 주체 이외의 환경에 대한 상태로, 예를 들어 게임에서는 게임의 콘솔 내부일수도 있고, 주식시장에서는 주식에 영향을 주는 모든 요소일수 있습니다. 에이전트는 사실 환경을 볼 수 없으며 볼 수 있다 하더라도 불필요한 정보들이 많을 것입니다.

Agent State

agent state

그림 7. Agent State

에이전트 상태 $S^a_t$ 는 행동하는 주체의 상태를 표현한 것입니다(the agent’s internal representation). 에이전트는 이 상태를 기반으로 다음 행동을 선택하고, 강화학습 시 사용되는 상태 정보입니다. 또한, 히스토리의 함수 $S_t^a=f(H_t)$ 로 표현될 수 있습니다.

Information State
정보 상태(information state)는 마코브 성질을 가지는 마코브 상태입니다.

그림 8. Markov State

$S_t$ 가 주어졌을 때 $S_{t+1}$ 의 확률은 $t$ 시점까지의 모든 상태가 주어졌을 때 $S_{t+1}$ 의 확률과 같으면 마코브 상태입니다. 즉, 현재상태 이전의 과거정보들은 미래정보에 대해 아무런 영향을 주지 않는다는 것입니다. 그 이유는 이미 현재 상태는 과거 정보를 충분히 포함하고 있기 때문에, 이 정보만으로 미래를 파악하기에 충분하다는 것입니다.

The state is sufficient statistic of the future.
The future is independent of the past given the present $H_{1:t} \to S_t \to H_{t+1:\infty}$

마코브 상태로는 환경상태 $S^e_t$ 와 히스토리 $H_t$ 입니다. $S^e_t$ 는 주체한테 미치는 영향을 모두 포함하고 있기 때문에 마코브 상태이고 마찬가지로 $H_t$도 관찰가능한 일련의 모든 시퀀스를 포함하고 있기 때문에 역시 마코브 상태입니다.

MDP and POMDP

강화학습 문제를 정의하기 위해서, 상태, 행동, 보상에 대한 정의가 필요합니다. 하지만 상태는 마코브 상태일수도 있고 아닐 수도 있습니다. 그럼 각각에 따라 강화학습 문제를 접근하는 방법도 달라집니다. 아래에서 더 살펴봅시다.

Fully Observable Environments
에이전트가 환경 상태를 직접적으로 관찰할 수 있을 때, 에이전트는 Fully Observability를 가집니다. 이는 결국 에이전트 상태가 환경상태와 동일한 경우입니다.

[O_t = S^a_t = S^e_t]

일반적으로, 에이전트가 Fully Observability를 가질 때, Markov Decision Process(MDP)를 따른다고 합니다.

아래와 같이 MDP의 예로 화성탐사 문제를 정의해봅시다.<p align='center'>

그림 9. 화성탐사기 예제

</p> 위의 그림과 같이, 화성탐사기가 도달할 수 있는 상태는 총 7가지 상태이고, 각 상태에서 취할 수 있는 행동은 왼쪽/오른쪽 두가지입니다. s1 상태에 있으면, +1보상을, s7상태에 있으면 +10보상을 받고, 나머지 상태에서는 0의 보상을 받습니다. 이처럼, 화성탐사기가 어느 상태에 있을 때 어떤 보상을 받을지 다 알고 있는 상황이 화성탐사기 에이전트가 환경상태임을 의미합니다.

Partially Observable Environments
에이전트가 환경을 간접적으로 관찰할 수 밖에 없을 때, 에이전트는 Partial observability를 가집니다. 예를 들어, 로봇은 카메라 센서로만 인식할 수 있는 장애물만 파악할 수 있습니다. 카드게임에서는 상대방의 카드패는 알 수 없고, 본인이 가진 카드만 알 수 있습니다.

[O_t = S^a_t \neq S^e_t]

이 경우 에이전트 상태는 환경 상태와 동일하지 않으며, 에이전트가 partially observable Markov Decision process(POMDP)를 따릅니다.

에이전트는 본인의 상태를 반드시 정의내려야 합니다. 전체 히스토리를 에이전트 상태로 둘 수도 있고, ‘환경에 대한 정보가 ~할 것이다’라는 믿음으로 구축할 수도 있습니다. 아니면, RNN을 이용하여 인코딩된 벡터로 에이전트 상태를 나타낼 수도 있습니다.

  • use history : $S_t^a = H_t$
  • Beliefs of environment state : $S_t^a = (P[S^e_t = s^1], \dots, P[s^e_t = s^n])$
  • Recurrent neural network : $S^a_t = \sigma(S^a_{t-1}W_s+O_tW_o)$

Deterministic ? Stochastic?

그림 10. Deterministic vs. Stochastic

강화학습 문제를 MDP인지 POMDP로 보는 관점말고도 deterministic한지 stochastic한지 보는 관점도 있습니다. deterministic한 강화학습 문제는 환경이 에이전트의 행동에 따라 변할 때 그 결과 오로지 하나의 결과만 보여줍니다(single observation and reward). 그러나, stochastic한 강화학습 문제는 에이전트의 행동에 따라 환경이 변할 때, 가능성 있는 여러 결과를 보여줍니다. 물론 그 결과가 나올 확률과 함께 말이죠.


이번 포스팅은 여기까지 마치겠습니다. Reinforcement Learning 소개(2)에 이어서 포스팅하도록 하겠습니다.


  1. CS234 Winter 2019 course Lecture 1
  2. Richard S. Sutton and Andre G. Barto : Reinforcement Learning : An Introduction
  3. David Silver Lecture 1
  4. Imitation learning : a brief over view of imitation learning, https://medium.com/@SmartLabAI/a-brief-overview-of-imitation-learning-8a8a75c44a9c

Junction Tree Variational Autoencoder for Molecular Graph Generation 논문 리뷰

|

Junction Tree Variational Auto-encoder(JT-VAE)는 기존 SMILES string 기반 생성 모델들이 SMILES string을 사용하는 것에 문제를 제기하여 캐릭터가 아니라 molecular graph가 직접 입력으로 들어가는 모델입니다. 또한, 유효한 화합물 구조를 생성하기 위해 Junction Tree Algorithm에서 아이디어를 착안하여 모델을 제시하였습니다.

Problem

SMILES string을 입력으로 하는 것은 크리티컬한 2가지 문제가 발생합니다. 먼저, SMILES 표현은 화합물간 유사도를 담아내지 못합니다.

similarity

그림 1

위 그림 1.을 보면, 두 화합물의 구조는 유사하지만 SMILES으로 나타냈을 땐 전혀 다른 표현이 됩니다. 따라서, SMILES 표현의 한계로 인해 VAE와 같은 생성모델들이 화합물 임베딩 공간을 올바르게 형성하지 못합니다. 두번째로, 그래프 형태의 화합물 표현이 SMILES 화합물 표현보다 분자의 화학적 특성을 더 잘 담아냅니다. 이러한 이유로, 본 논문은 그래프적 표현(molecular graph)을 직접적으로 사용하는 것이 유효한 구조를 만들어 내는 것을 향상시킬 것이라 가정을 하고 있습니다.

Junction Tree Variational Auto-Encoder

Molecular graph를 만든다는 건 일반적으로 원자를 하나씩 순차적으로 생성하는 것으로 생각할 수 있습니다(Li et al., 2018). 그러나, 본 논문에서는 이러한 접근법은 유효하지 않은 구조(chemically invalid)를 만들어 낼 가능성이 높다고 합니다. 원자를 하나씩 붙여나가면서 생성하면 중간 단계 구조들은 invalid하며, 완전한 구조가 나올때 까지 딜레이가 길기 때문입니다.

그림 2

따라서, 본 논문은 원자 단위로 molecular graph를 만들어 나가는 것이 아니라 유효한 분자 단위들의 집합을 미리 정해놓고 이 단위들을 붙여 나가면서 화합물을 구축합니다. 마치 자연어처리에서 문장 생성 문제를 풀 때, 사전을 미리 구축해 놓고 그 속에 존재하는 단어들로 문장을 구축해 나가는 것과 같이 생각하면 될 것 같습니다.

junction tree 라는 이름이 붙게 된 이유는 다음과 같습니다. 유효한 분자 단위는 molecular graph내의 sub-graph로 생각할 수 있고, 이 sub-graph는 이 그래프 자체로도 유효한 화학 분자 구성 요소를 이룹니다. 즉, 마치 junction tree의 node가 complete graph인 clique과 유사합니다. 이 부분에서 junction tree 아이디어를 착안한 것입니다

하나의 분자에 대해서 생성되는 방식은 다음과 같습니다. 어떤 유한한 개수의 유효한 화합물 단위(valid components)들로 구성된 집합에서 해당 분자를 구성할 것 같은 요소들을 선택한 후, 그 요소들을 가지고 제일 그럴듯한 구조가 나오도록 조합하는 것입니다. 이런 식의 접근의 장점은 Li et al.(2018)와 다르게 valid한 화합물을 생성하는 것을 보장할 수 있습니다. 또한 구성요소 간의 상호작용 관계도 고려되기 때문에 더 실용적인 방법입니다.

유효한 화합물 단위는 마치 building block과 같은 역할입니다.

그림 3

그림 3.은 Junction Tree VAE(JT-VAE) 모식도입니다. 첫번째 단계로 한 분자가 입력으로 들어오면, 화합물 단위 사전을 이용하여 Tree Decomposition을 수행합니다. 수행 결과, Junction Tree $\tau$ 가 나옵니다.

하나의 분자가 주어졌을 때, 한 분자는 2가지 종류의 표현을 가지게 됩니다. - Molecular graph 와 Junction Tree 표현

2가지 표현을 가지고 있는 것과 같이 Graph Encoder/Decoder와 Tree Encoder/Decoder로 구성됩니다. Molecular Graph는 Graph Encoder에 의해 $z_{G}$ 로 인코딩됩니다. 마찬가지로, Tree Encoder에 의해 Junction Tree $\tau$ 는 $z_{\tau}$ 으로 인코딩됩니다. 그런 후 Tree Decoder와 화합물 사전을 이용하여 제일 가능성이 높은 화합물 단위를 조합하여 Junction Tree $\hat{\tau}$ 를 생성합니다. 이 때, Junction Tree $\hat{\tau}$ 에서 화합물 단위간 연결(edge)는 화합물 간 결합 방향 및 결합 종류(단일결합, 이중결합 등등)에 관한 정보를 포함하고 있지 않고, 상대적인 배열에 관한 정보만을 담고 있습니다. 그 다음, graph decoder에 의해, Junction Tree $\hat{\tau}$ 와 $z_{G}$ 는 Molecular graph 표현으로 나타내지게 됩니다. 이 과정에서 화합물 단위 사이의 결합이 정해지게 됩니다. 화합물 단위 간 결합될 수 있는 후보들을 나열 한 후, 각 후보 군에 대한 결합점수를 매기고, 점수가 가장 높은 결합이 화합물 단위 사이의 결합으로 결정됩니다.

Junction Tree Algorithm

junction tree algorithm은 probabilistic graphical model에서 inference problem를 효율적으로 풀기 위한 알고리즘입니다.

  • inference 문제 2가지 종류
  1. Marginal inference : what is the probability of a given variable in our model after we sum everything else out?
  2. $$p(y=1) = \sum_{x_1}\sum_{x_2}\sum_{x_3}\dots\sum_{x_n}p(y=1,x_1,x_2,x_3,\dots,x_n)$$
  3. Maximum a posteriori(MAP) inference : what is the most likely assignment to the variables in the model(possibly conditioned on evidence)?
  4. $$\max_{x1,\dots,x_n} p(y=1,x_1,x_2,x_3,\dots,x_n)$$

변수 간 dependence가 표현된 directed acyclic graph $G$ 를 undirected graph로 변환한 뒤, 정해진 변수 간 order에 의해 변수들의 cluster를 하나의 single node로 구성하고, 특정 규칙 아래(변수 간 dependence가 잘 반영될 수 있도록), cluster간 edge를 연결하여 Junction Tree $\tau_{G}$=($\nu$, $\varepsilon$), $\nu$ : nodes, $\varepsilon$ : edges 를 구축합니다.

변수간 order에 의해 cluster를 구성해 나가면서 tree를 구축하는 건 variable elimination과 관련이 있습니다.

구축된 tree는 cycle-free이고, 변수들의 cluster는 graph 상에서 complete graph(clique)를 이루고 있어야 합니다. 또한 서로 이웃된 cluster 간에 공통되는 변수들이 있을 때, cluster 간 연결된 path 위에 해당 변수들로 구성된 cluster가 있어야 합니다. Junction Tree는 아래 그림에서와 같이 세가지 특성을 가져야 합니다.

그림. 4. Junction Tree Properties

belief propagation as message passing

Junction tree $\tau_{G}$ 를 가지고, inference problem을 푸는 방법 중 하나가 message-passing algorithm을 이용한 belief propagation입니다. 아래 그림과 같이 variable elimination을 통해 marginal inference 문제를 해결해 나갈 수 있습니다. 이 때, 정해진 변수 순서에 따라 summing out 되면서 변수가 순차적으로 제거됩니다(marginalized out).

그림 5
아래 그림과 같이 variable elimination이 되는 과정이 마치 tree 상에서, 한 node가 marginalization이 되면서 연결되어 있는 다른 노드로 message를 전달하는 과정으로 볼 수 있습니다. 아래 그림에서, $x_1$ 을 summing out에서 제거하기 위해선 우선적으로 $x_2$ 가 summing out되어 $x_1$ 으로 message인 $m_{21}(x_1)$ 이 전달되어야 합니다. 마찬가지로, 우리가 구하고 싶은 marginalized distribution인 $p(x_3)$ 를 구하기 위해선 $x_1, x_4, x_5$ 에서 오는 message가 모두 올 때까지 기다렸다가 계산을 할 수 있습니다.

그림 6

$i$ node 에서 $j$ node로 가는 message $m_{i \rightarrow j}$ 는 아래와 같이 정의할 수 있습니다. $$m_{i \rightarrow j} = \sum_{x_i}\phi(x_i)\phi(x_i,x_j)\prod_{l \in N(I) \setminus j}m_{l \rightarrow j}(x_i)$$

수식. belief propagation

$i$ node 에서 $j$ node로 가는 message $m_{i \rightarrow j}$ 는 **j node를 제외하고 i로 가는 모든 node의 메세지를 기다렸다가** i node와 연관된 distribution function을 다 계산한 후 summing out하는 것입니다. 특정 node와 연결된 message가 모두 올 때까지 기다렸다가 계산하는 방식을 belief propagation이라 합니다. Loopy belief propagation은 이를 기다리지 않고 계산하고, 모든 노드에서 수렴할 때까지 반복하여 inference 문제를 해결하는 방식입니다. $$m_{i \rightarrow j}^{t+1} = \sum_{x_i}\phi(x_i)\phi(x_i,x_j)\prod_{l \in N(I) \setminus j}m_{l \rightarrow j}^{t}(x_i)$$

수식. Loopy belief propagation

message passing network

[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212)은 기존에 존재하는 그래프 모델들을 message passing algorithm을 학습하는 모델로 다시 해석하였습니다. 아래와 같이 세가지 함수를 정의하여 그래프 모델들을 분자의 화학적 특성을 예측하는 등 Quantum Chemistry에 적용하는 연구를 하였습니다. - A Message Passing function : $m_{v}^{t+1}=\sum_{w \in N(v)}M_t(h_v^t,h_u^t,e_{uv})$ - A Node Update function : $h_{v}^{l+1}=U_t(h_v^t,m_v^{t+1})$ - A Readout function(ex. classification) : $\hat y = R({h_v^T|v \in G})$ 즉, 한 원자의 특성을 결정짓는 건 원자와 연결된 다른 원자로 부터 오는 정보와 자기 자신에 의해 결정됨을 의미합니다. 본 논문에서도 위와 같은 아이디어를 사용하였습니다.

Tree Decomposition of Molecules

Molecule Junction Tree 는 junction tree $\tau_{G} = (\nu, \varepsilon)$ 에서 $\chi$ 가 추가된 $\tau_{G} = (\nu, \varepsilon, \chi)$ 입니다. $\chi$ 는 junction tree의 node 후보가 될 수 있는 화합물 구성 단위들의 집합 사전을 나타냅니다. 화합물 단위 사전은 ring결합으로 이뤄진 화합물(ex. aromatic compound), 일반 결합(?)(a single edges ex. single bond, double bond, triple bond..)으로만 구성됩니다. 여기서 사용된 집합 사전의 크기 |$\chi$|=$780$ 입니다. 여기서, 집합 사전의 크기가 한정적이기 때문에, 다양한 종류의 분자를 표현하는 것이 가능한 것에 대해 의문이 들 수 있습니다. 본 논문에서는 training set에 존재하는 화합물들 기반으로 분자 집합 사전을 구축했으며, test set에 있는 분자들을 대부분 커버했기 때문에 크게 문제 삼지 않고 넘어갔습니다.

tree-decomp

그림 7. Tree Decomposition of Molecules

위 그림은 tree decomposition을 나타낸 그림입니다. 집합 사전 $\chi$ 를 가지고, cycle구조와 edge 구조로 분해합니다. cycle은 고리형 화합물이고, edge는 단순 결합으로 이뤄진 화합물입니다. 위 그림에서 색칠된 노드가 화합물 단위로 분해된 것을 가르킵니다.

Graph Encoder

먼저, molecular graph 표현을 graph message passing network을 통해 인코딩합니다. Graph에서 원자에 해당하는 vertex는 특징 벡터 $\mathrm x_{v}$(원자 종류, 공유가 등과 같은 특성을 나타내는 벡터), 원자 간 결합을 나타내는 edge $(u,v) \in E$ 는 특징 벡터 $\mathrm x_{uv}$(결합 종류) 로 표현합니다. 또한 message passing algorithm에 의해 두 원자 간 주고받는 message hidden vector $\nu_{uv}$(u에서 v로 가는 message) 와 $\nu_{vu}$(v에서 u로 가는 message)로 표현합니다. Graph encoder에서 message가 전달되는 방식은 loopy belief propagation을 따릅니다. 아래 식에 의하면, 한 원자의 hidden vector는 결국 자신의 특징과 더불어 자신과 결합하고 있는 원자들로부터 오는 정보로 표현되는 것입니다. $$\nu_{uv}^{(t)} = \tau(W_1^g\mathrm x_u + W_2^g\mathrm x_{uv} + W_3^g\sum_{w \in N(u) \setminus v} \nu_{wu}^{(t-1)})$$ $$\mathbf {h}_u=\tau(\mathrm U_1^g\mathrm x_u + \sum_{v \in N(u)}\mathrm U_2^g\nu_{vu}^{(T)})$$ 한 분자에 대하 최종적인 표현은 $\mathbf h_G = \sum_i \mathbf h_i/|V|$ 로, 분자에 포함된 모든 원자들의 hidden vector들을 평균낸 것입니다. 그런 다음, VAE처럼 평균 $\mu_G$ 와 분산 $\sigma_G$ 를 계산하는 레이어가 각각 연결되고, 잠재 벡터 $z_G$ 는 $\mathcal N(\mu_G, \sigma_G)$ 에서 샘플링됩니다.

Tree Encoder

Tree Encoder도 message passing network방식으로 분자의 junction tree 표현을 인코딩합니다. 각 클러스터 $C_i$ 는 해당 라벨에 대한 정보를 담고 있는 원핫 벡터 $\mathrm x_i$ 로 표현되고, 클러스터 $(C_i,C_j)$ 간 주고받는 메세지 정보는 $\mathbf m_{ij}, \mathbf m_{ji}$ 로 표현합니다. Tree에서,임의로 root node 를 정하고 메세지는 GRU unit을 이용해 $\mathbf m_{ij} = \mathrm GRU(\mathbf x_i, {\mathbf m_{k \in N(i) \setminus j}})$ 와 같이 업데이트됩니다. $$\mathbf s_{ij} = \sum_{k \in N(I) \setminus j}\mathbf m_{kj}$$ $$\mathbf z_{ij} = \sigma(\mathbf W^z \mathbf x_i +\mathbf U^z\mathbf s_{ij}+\mathbf b^z)$$ $$\mathbf r_{kj} = \sigma(\mathbf W^r \mathbf x_i +\mathbf U^r \mathbf m_{ki}+\mathbf b^r)$$ $$\mathbf {\tilde m_{ij}} = tanh(\mathbf W \mathbf x_i+\mathbf U \sum_{k \in N(I) \setminus j}\mathbf r_{ki} \odot \mathbf m_{ki})$$ $$\mathbf m_{ij} = (1-\mathbf z_{ij}) \odot \mathbf s_{ij}+\mathbf z_{ij} \odot \mathbf {\tilde m_{ij}}$$ Graph Encoder와 다르게 Tree Encoder는 loopy belief propagation이 아니라, 특정 노드에서 다른 노드로 메세지를 전달하기 전에, 특정 노드와 연결된 노드들의 메세지가 다 올 때까지 기다리다가 전달하는 belief propagation 방식을 따릅니다. 특정 노드에서 message passing이 완료된 후, 해당 노드의 hidden feature $\mathrm h_i =\tau(\mathbf W^o \mathbf x_i+\sum_{k \in N(i)}\mathbf U^o \mathbf m_{kj})$ 로 계산됩니다. 즉 노드 i로 오는 모든 메세지와 노드 i의 label feature $x_i$ 를 이용하여 hidden feature vector i 를 계산합니다. tree의 최종적인 표현 $\mathbf h_{\mathcal T}=\mathbf h_{root}$ 입니다. Graph encoder에서 평균값으로 계산한 것과는 달리, root node의 hidden feature를 최종 표현으로 둡니다. 그 이유는 tree decoder에서 tree를 생성할 때, 어느 노드에서 시작할 지에 대한 정보가 있어야 하기 때문입니다. 다음으로 graph encoder와 마찬가지로 $\mu_\mathcal T$ 와 $\sigma_\mathcal T$ 를 출력하는 레이어가 각각 연결되고, $\mathcal N(\mu_\mathcal T, \sigma_\mathcal T)$ 에서 latent vector $z_\mathcal T$ 를 샘플링합니다.

Tree Decoder

Tree Decoder를 이용해 $z_\mathcal T$ 를 통해 junction tree $\hat {\mathcal T}$ 를 생성합니다. root node부터 시작해서 top-down 방식과 깊이 우선 탐색(depth-first order) 순서로 나머지 node들을 생성해 나갑니다. 깊이 우선 탐색이란, 루트 노드에서 시작해서 다음 분기로 넘어가기 전에 해당 분기를 완벽하게 탐색하는 방법입니다. 노드를 방문할 때마다 두가지 일을 수행합니다.
  1. topological prediction : 해당 노드가 자식 노드를 가지고 있는지에 대한 여부
  2. label prediction : 해당 클러스터의 라벨 예측(화합물 단위 집합 사전에 있는 라벨 예측)

만약 자식 노드가 더 이상 없다면 해당 노드를 탐색하기 직전 노드로 거슬러 올라갑니다. 각 노드를 방문하여 두가지 일을 수행하기 위해선 연결된 다른 노드로부터 정보를 받아야 합니다. Tree decoder에서 전달되는 정보는 message vector $\mathbf h_{ij}$ 를 이용합니다. tree의 node별로 하나씩 순서대로 생성해 나가면서 방문하는 edge들마다 번호를 매기면 최대 번호는 엣지갯수 x 2 가 됩니다. Molecular junction tree $\mathcal T=(\mathcal V, \mathcal E)$ 에 대해, 처음 시작해서 t step이 될 때까지 방문한 엣지들을 $\tilde {\mathcal E}$ = { $(i_1, j_1), \dots,(i_t,j_t)$ } 라 하고, t step일 때 방문한 노드를 $i_t$ 라 한다면, 노드 i에서 노드j로 가는 메세지 $\mathbf h_{i_t, j_t}$ 는 i 노드로 향하는 메세지와 t step에서의 노드 특징 벡터(여기서는 라벨 벡터)에 의해 GRU unit을 이용해 업데이트 됩니다. $$\mathbf h_{i_t, j_t} = \mathrm {GRU}(\mathbf x_{i_t}, \{\mathbf h_{k,i_t}\}_{(k,i_t) \in \mathcal {\tilde E}_t, k \neq j_t})$$

Topological Prediction & Label Prediction

노드 i에 방문했을 때, 자식 노드 j 존재 여부는 아래와 같이 확률을 계산하여 판단합니다. $$p_t = \sigma (\mathbf u^d \bullet \tau(\mathbf W_1^d\mathbf x_{i_t}+\mathbf W_2^d\mathbf z_{\tau}+\mathbf W_3^d\sum_{(k,i_t) \in \mathcal {\tilde E_t}}\mathbf h_{k,i_t}))$$ 자식 노드 j가 있다면, 자식 노드 j의 라벨 예측은 아래와 같습니다. $\mathbf q_j$ 는 화합물 단위 집합 사전에 대한 분포를 나타냅니다. 라벨 예측 후, 자식 노드 j의 특징벡터 $\mathbf x_j$ 는 분포 $\mathbf q_j$ 에서 샘플링됩니다. 샘플링 시, 부모노드와 연결되는 자식노드가 유효하지 않은 화합물 단위가 오면 안되기 때문에 미리 invalid한 화합물 단위들은 분포에서 masking을 하고 샘플링을 진행합니다. $$\mathbf q_j =\mathrm {softmax(\mathbf U_\tau^l(\mathbf W_1^l\mathbf z_{\tau}+ \mathbf W_2^l\mathbf h_{ij}))}$$ Tree decoder가 작동하는 알고리즘과 자세한 설명은 아래와 같습니다.

그림 8

Tree decoder의 목표는 우도 $p(\mathcal T|\mathbf z_{\mathcal T})$ 를 최대화하는 것입니다. 따라서 아래와 같이 크로스 엔트로피 손실함수를 최소화하는 방향으로 학습이 이뤄집니다. $$L_c(\mathcal T) = \sum_tL^d(p_t, \hat p_t) + \sum_jL^l(\mathbf q_j, \mathbf {\hat q_j})$$ 또한 teacher forcing을 이용하여 학습합니다. Teacher forcing이란 매 스텝마다 prediction한 후, 추 후 해당 스텝의 값을 이용할 때 prediction 값이 아니라 ground truth을 이용하는 것입니다.

Graph Decoder

JT-VAE는 마지막으로 graph decoder를 통해 molecular graph를 생성합니다. 그러나 하나의 molecular junction tree는 하나의 molecular graph에 대응하는 것이 아니라 화합물 단위인 subgraph를 어떻게 조합하느냐에 따라 여러 개의 molecular graph를 나타낼 수 있습니다. junction tree의 edge는 단순히 subgraph들의 상대적인 배열만을 나타낸다고 했습니다. 이렇기 때문에 Graph Decoder의 역할은 올바른 molecular graph를 만들기 위해 subgraph를 잘 조합하는 것입니다. $$\hat G = argmax_{G' \in \mathcal g(\mathcal {\hat T})}f^a(G')$$ Tree Decoder에서 root node에서 하나씩 node를 붙여나가듯이, 마찬가지로 subgraph를 하나씩 붙여나가는 것입니다. 그러나 이 때, subgraph를 붙여나갈 때 여러가지 경우의 수가 나오기 때문에 scoring function $f^a$ 을 이용해서 각 경우의 수에 대해 점수를 매깁니다. 가장 높은 점수가 나온 subgraph 간 조합을 두 subgraph간 조합으로 보고 다음 subgraph를 붙여나갑니다. subgraph를 붙여나가는 순서는 tree decoder에서 디코딩된 노드 순을 따릅니다. 그림 8의 예를 보면, 생성되는 tree node 순서에 따라, subgraph를 1->2->3->4->5 순으로 붙여나가는 것입니다.

$G_i$ 를 특정 sub graph cluster인 $C_i$ 와 그것의 neighbor clusters $C_j, j \in N_{\mathcal {\hat T}}(i)$ 을 조합해서 나온 그래프라 한다면, $G_i$ 에 대한 점수는 $f^a_i(G_i) =\mathbf h_{G_i}\bullet\mathbf z_G$ 입니다. $\mathbf h_{G_i}$ 는 subgraph $G_i$ 에 대한 hidden vector representation 입니다. 즉, Graph decoder의 역할은 조합해 가면서 나오는 각 subgraph에 대하여 hidden vector representation을 message passing algorithm을 통해 구하는 것입니다. Junction tree message passing network는 아래와 같습니다. $$\mu_{uv}^t = \tau(\mathbf W_1^a\mathbf x_u + \mathbf W_2^a\mathbf x_{uv} + \mathbf W^a_3 \tilde \mu_{uv}^{(t-1)})$$ $$\tilde \mu_{uv}^{(t-1)} = \begin{cases} \sum_{w \in N(u) \setminus v}\mu_{wu}^{(t-1)} & \quad \alpha_u = \alpha_v \\ \mathbf {\hat m}_{\alpha_u,\alpha_v} + \sum_{w \in N(u) \setminus v}\mu_{wu}^{(t-1)} & \quad \alpha_u \neq \alpha_v \end{cases}$$ 위 수식을 보면, message 계산 과정이 graph encoder와 비슷합니다. 하나 차이점은 u 원자와 v 원자가 다른 cluster라면 즉, 다른 subgraph라면 전달되는 메세지에 다른 subgraph에서 온 것을 추가적으로 알려준다는 점입니다(provides a tree dependent positional context for bond (u, v)). 이 때 메세지 $\mathbf {\hat m}_{\alpha_u,\alpha_v}$ 는 sub-graph $G_i$ 를 graph encoder를 통과시켜 계산된 메세지입니다. Graph decoder 학습은 $\mathcal {L_g}(G) = \sum_i [f^a(G_i) - log \sum_{G' \in \mathcal g_i} exp(f^a(G'_i))]$ 을 최대화하는 과정입니다. Tree decoder와 마찬가지로 teacher forcing을 이용해 학습합니다.
Complexity
Tree decomposition에 의해, 두 클러스터 간에 공유되는 원자 갯수가 최대 4개이며 또한 tree decoder 과정에서 invalid한 화합물 단위가 나오지 않도록 masking을 통해 sampling 하기 때문에 복잡도는 그리 높지 않습니다. 따라서 JT-VAE의 계산복잡도는 molecular graph의 sub-graph 수의 비례합니다.

Grammar Variational Autoencoder 논문 리뷰

|

Grammar Variational Auto-Encoder(GVAE)는 Gomez-Bomb barelli et al.(2016)1의 VAE기반으로 생성된 신약 후보 물질들이 대부분 유효하지 않는 것에 문제를 제기하여, SMILES string 생성 문법을 직접적으로 제약 조건으로 걸어 유효한 신약 후보 물질을 생성하는 모델입니다.

Problem

신약 후보 물질의 SMILES string 생성 모델들(RNN, VAE)의 단점은 유효하지 않은 string을 생성하는 경우가 많이 발생한다는 것입니다.

  • Valid String : c1ccccc1 (benzene)
  • Invalid String : c1ccccc2 (??)

Gomez-Bomb barelli et al.(2016)가 제안한 VAE는 decoder를 통해 연속적인 latent space에서 discrete한 SMILES string을 생성합니다. 그러나, 유효하지 않는 string에 대한 확률이 높아지도록 학습이 된 경우, 학습이 완료가 된 후에도 계속 올바르지 않은 SMILES string이 생성되는 문제가 발생합니다.

따라서 본 논문에서는 이러한 이슈를 완화하기 위해 SMILES string을 생성하는 문법에 관한 정보를 모델에게 직접적으로 알려줌으로써 유효한 SMILES string을 생성하도록 하는 모델(Grammar VAE)을 제안하였습니다.

Context-free grammars2

SMILES string을 생성하는 문법은 문맥 자유 문법(Context-free grammar, CFG)을 따릅니다. 다른 나라 언어를 이해하거나 그 나라 언어로 대화나 글을 쓰기 위해선 문법을 이해하고 있어야 합니다. 마찬가지로 프로그래밍 언어를 이해하기 위해선 그 언어를 정의한 문법을 이해하고 있어야 합니다. 대다수 프로그래밍 언어들이 CFG기반입니다. CFG은 촘스키 위계의 type-2에 해당하는 문법입니다.

그림1. 촘스키 위계

문맥 자유 문법은 $G=(V, \Sigma, R, S)$ 4개의 순서쌍으로 구성됩니다.

  • V : non-terminal 심볼들의 유한집합
  • $\Sigma$ : terminal 심볼들의 유한집합
  • R : 생성규칙(production rules)의 유한집합
  • S : 시작(start) 심볼

예를 들어, 문법 G=({A}, {a,b,c}, P, A), P : A $\rightarrow$ aA, A $\rightarrow$ abc가 있다면, 문법아래 생성될 수 있는 string은 aabc입니다. 위의 예는 단순하지만 생성 규칙에 따라 나올 수 있는 string의 경우의 수는 매우 많습니다. 이렇게 생성된 string을 tree구조로도 표현할 수 있습니다.

문법 G=({S}, {a,b}, P, S), P : S $\rightarrow$ SS | aSb | $\epsilon$ 이라면, 생성규칙에 따라 생성된 string중 하나는 $S \rightarrow SS \rightarrow aSbS \rightarrow abS \rightarrow abaSb \rightarrow abaaSbb \rightarrow abaabb$ 입니다. 이를 tree구조로 나타내면 아래 그림과 같습니다.

그림2. CFG grammar 예시

GVAE는 CFG의 아이디어를 이용한 모델입니다. CFG기반의 SMILES grammar가 있으며, encoder의 입력값은 SMILES string이 아니라 각 화합물 SMILES string을 생성하기 위해 사용된 생성규칙들입니다. 마찬가지로, decoding 결과는 SMILES string이 아니라, SMILES string에 관한 생성 규칙입니다. 시퀀스 별로 그 다음으로 나올 가능성이 높은 생성 규칙이 결과로 나옵니다. 세부적인 모델 설명은 아래와 같습니다.

Methods

본 논문에서 사용된 VAE의 encoder와 decoder는 Gomez-Bomb barelli et al.(2016)와 동일한 구조를 사용하였습니다.

encoding

gvae

그림3. GVAE encoder

위 그림은 모델의 encoder가 SMILES grammar와 함께 구현되는 과정에 관한 것입니다. 그림3의 1번은 SMILES grammar의 일부입니다. 전체 SMILES grammars는 논문 참고하시기 바랍니다. 예를 들어 벤젠 SMILES string인 c1ccccc1을 encoding 한다고 했을 때(2번), SMILES grammar에 따라 벤젠 SMILEs string의 parse tree를 구축합니다. 그런 뒤, 이 parse tree를 위에서부터 아래, 왼쪽에서 오른쪽 방향으로 생성 규칙들로 다시 분해된 후(3번), 분해된 각 규칙들은 원핫벡터로 변환됩니다(4번 그림). 이 때, 원핫벡터의 차원 $K$ 은 SMILES grammar 생성 규칙의 개수입니다. $T(X)$ 를 분해된 생성규칙들의 개수라 할 때, 벤젠의 생성규칙을 인코딩한 행렬의 차원은 $T(X)\times K$ 가 됩니다. 그 후, Deep CNN을 거쳐서, 벤젠에 대한 생성규칙을 latent space 상에 $z$ 로 맵핑합니다.

decoding

gvae-decoding

그림4. GVAE decoder

다음은 latent space상에 $z$ 로 맵핑된 벤젠 생성규칙이 어떻게 다시 discrete한 시퀀스를 가진 생성규칙들로 이뤄진 string으로 변환되는지에 관한 과정에 대한 설명입니다. GVAE에서 Decoder의 핵심은 항상 유효(valid)한 string이 나오도록 생성규칙들을 선택하는 것입니다. 먼저, 잠재 벡터 $z$ 를 RNN layer를 통과하여 시퀀스 별로 logit이 출력됩니다(그림4의 2번). logit 벡터의 각 차원은 하나의 SMILES grammar 생성규칙에 대응됩니다. 타임 시퀀스의 최대 길이는 $T_{max}$ 이며 따라서 최대로 나올 수 있는 logit 벡터의 갯수도 $T_{max}$ 입니다.

Masking Vector

decoding 결과로 출력된 일련의 생성규칙 시퀀스들이 유효하기 위해서 last-in first-out(LIFO) stack과 masking vector가 등장합니다.

decoding process

그림5. decoding process

그림 5.는 stack과 masking vector를 이용한 decoding과정을 나타낸 그림입니다. 제일 첫 심볼은 항상 smiles가 나와야 하므로, (1)처럼 smiles를 stack합니다. 그 다음, smiles를 뽑은 후, smiles으로 시작하는 생성규칙은 1 그 외 나머지는 0으로 구성된 masking vector를 구성한 뒤 첫번째 시퀀스 logit가 element-wise 곱을 합니다((3)). 그런 다음, 아래 mask된 분포에 따라 sampling을 하면 (4)와 같이, smiles $\rightarrow$ chain 생성규칙이 출력됩니다.

$$p(\text{x}_t = k|\alpha,\text{z}) = \frac{m_{\alpha, \,k}exp(f_{tk})}{\sum_{j=1}^Km_{\alpha, \,k}exp(f_{tj})}$$

수식1. masked distribution at timestep t

위와 같은 방법으로 $t \rightarrow T_{max}$ 가 될 때까지, sampling을 하면 한 화합물을 구성하는 생성규칙들을 출력하였고, 결국 문법적으로 유효한 화합물을 생성한 것입니다.

Bayesian Optimization

Gomez-Bomb barelli et al.(2016)1를 보면, 약물 특성을 포함한 latent space를 구축하기 위해, latent layer에 약물 특성을 예측하는 MLP layer를 추가하여 학습을 진행합니다. 마찬가지로, 약물 특성이 담긴 latent space를 구축하기 위해 VAE 학습 완료 후, Sparse Gaussian Process(SGP)를 이용하여 예측모델을 학습합니다. 여기서 사용된 약물 특성은 penalized logP 입니다.

Experiment Result

GVAE 모델의 성능은 Gomez-Bomb barelli et. al.(2016)1와 유사한 VAE인 Character VAE(CVAE) 성능과 비교하였습니다. 정성적인 평가를 위해, 임의의 수학표현식을 생성하여 유효한 수학표현식을 만드는지를 확인하였습니다.

gvae-result1

그림6. 결과 1

<그림 6.>의 Table 1.은두 모델의 embedding 공간의 smoothness를 보여주는 결과입니다. 각각 두 개의 식(볼드체)을 encoding 한 뒤, latent space 상의 두 점을 선형보간법(linear interpolation)을 한 것입니다. 보시면, GVAE는 100% 유효한 수학식이 공간 위에 있지만 CVAE는 그렇지 않은 것을 확인하실 수 있습니다. 이처럼, GVAE가 유효한 string으로 구성된 latent space를 더 잘 구축한다는 것입니다.

Table 2.은 각 latent space에서 z를 여러 번 샘플링 한 후, decoding 결과 유효한 수학표현식 또는 분자 string의 비율을 나타낸 것입니다. GVAE가 CVAE보다 문법적으로 의미있는 string을 더 잘 출력함을 확인할 수 있습니다.

gvae-result2

그림 7. 결과 2

그림 7의 결과는 penalized logP와 latent vector를 가지고 Bayesian Optimization한 결과, penalized logP score가 높은 순으로 3개를 뽑은 결과입니다. CVAE와 GVAE 모두 유효한 string을 내놨을 때, GVAE의 약물 특성에 관한 score가 더 높습니다. 즉, GVAE를 Bayesian Optimization까지 완료 후 형성된 latent space는 valid한 string을 내뱉는 공간을 형성했을 뿐만 아니라 약물 특성을 잘 포함하는 공간을 구축했음을 의미합니다.

gvae-result2

그림 8. 결과 3

그림 8의 결과는 약물 특성(penalized logP)에 관한 예측 성능에 관한 표입니다. Loss function을 Log-Likelihood와 RMSE를 모두 사용했을 때, GVAE가 CVAE보다 성능이 더 낫습니다. 다만, 개인적인 의견으로는 두개의 성능 차이는 거의 나지 않는 것으로 보입니다.

논문 한줄 리뷰평

GVAE는 이제까지 여러 Generative Model이 상당 수가 SMILES 문법에 어긋나는 string을 출력한다라는 단점을 보완하는 논문입니다. 하지만 valid한 string이 얼마나 신약개발에 적합한 string인지에 대한 결과는 논문에 실리지 않았습니다(이 부분은 대부분의 신약개발 논문들이 나와있지 않는 것으로 보입니다.). 하지만 valid한 string을 내놓는 것에 있어서 제일 실용적인(practical)하지도 않을까 생각이 듭니다.