본문 바로가기

딥러닝

[딥러닝] GNN이란?

반응형

※ 이번 포스팅에서는 추천 시스템, 관심사 분류 등에 사용되는 GNN에 대해서 알아보자. 

※ Introduction to Graph Neural Networks : A Starting Point for Machine Learning Engineers 논문을 참고하여 작성하였다.


목차

  1. GNN이란?
  2. GNN의 구조

GNN이란?

 GNN(Graph Neural Network)이란 그래프로 표현할 수 있는 데이터를 처리하는 데 사용되는 인공 신경망이다. GNN에 대해 관심 있는 사람들은 그래프에 대해 잘 알고 있겠지만 혹시 모르는 분들을 위해 그래프에 대해서 잠깐 설명하겠다.


Graph

 그래프의 정의는 다음과 같다.

 비선형 자료구조로 여러 개의 "노드"와 이들을 연결하는 "간선"으로 이루어진 자료구조이다.

 

 그래프의 구성 요소는 정의에서 보다시피 노드와 간선이다. 여기서 노드는 그래프 내에서 개별 개체를 나타낸다. 예를 들어, 인터넷 소셜 네트워크를 생각해보자. 각각의 사람이 하나의 노드를 나타낸다.

 

 간선은 노드 간의 관계를 나타낸다. 아까와 같이 소셜 네트워크에서 친구 관계, 직장 동료 관계, 부모 자식 관계 등 여러 관계를 간선으로 표현할 수 있다.


 

 다시 GNN으로 돌아오면 결국 이 그래프 구조를 활용하여 노드의 특징을 추출하고 각 노드들의 관계를 간선을 통해 전달함으로써 원하는 목적을 달성하고자 한다.


GNN의 구조

 GNN의 구조는 인코더와 디코더 부분으로 이루어져있다.

  1. 인코더
  2. 디코더

인코더 (Encoder)

 

 인코더 부분은 그래프 데이터를 저차원 임베딩 벡터로 변환하는 역할을 수행한다. 노드가 가지고 있는 정보와 그래프 구조를 통해 의미있는 벡터 표현을 학습하는 과정이다.

 

 인코더는 크게 세 가지 단계로 구성되어 있다.

  1. 입력 변환
  2. 그래프 구조 반영
  3. 메시지 패싱 수행

1. 입력 변환

노드의 원본 특징 벡터를 학습 가능한 형태로 변환하는 과정이다. 

 

$$ H^{0} = XW_{0} + b_{0}$$

초기 한 노드의 상태 및 정보를 나타내는 벡터를 $X$라고 하면 $X$에 가중치 $W_{0}$를 곱하고, 편향$b_{0}를 더해 변환한다. 이를 통해 변환된 노드 특징 벡터 $H^{(0)}$이 생성된다. 여기서 노드 임베딩의 차원을 조정하여 모델이 학습할 수 있또록 만든다. 또한 추가로 활성화함수를 적용하여 비선형성을 추가한다.

 

2. 그래프 구조 반영

 

그래프의 연결 정보를 반영하는 과정으로, 인접 행렬을 사용하여 이웃 노드 정보를 고려한다. 이때, 인접 행렬을 정규화하여 사용한다.

$$\tilde{A} = D^{-\frac{1}{2}}AD^{-\frac{1}{2}} $$

여기서 $A$는 노드가 연결되어 있는지를 나타내는 인접 행렬이고 $D$는 차수 행렬로 각 노드의 연결된 노드의 개수를 표현하는 행렬이다.

 

● 인접 행렬

 

● 차수 행렬

 

☆ 인접 행렬을 정규화하여 사용한다고 했는데 그 이유는 무엇일까?

 

 인접 행렬을 정규화하는 이유는 노드가 균일하게 연결되어 있는 것이 아니라 어떤 노드는 여러 개의 노드와 연결되어 있고 어떤 노드는 연결된 노드의 수가 매우 적으면 연결이 많은 노드에 과도한 정보 반영이 이뤄지고 연결이 적은 노드는 업데이트할 정보가 부족한 현상이 발생한다. 따라서 이 문제를 보완하고자 정규화를 통해 노드 간 연결 강도를 균형있게 조정하여 학습을 안정적으로 수행할 수 있다.

 

☆ 그렇다면 왜 $ D^{-\frac{1}{2}}AD^{-\frac{1}{2}}$ 로 정규화 하는가?

 

 그냥 $D^{-1}A로 정규화하는 방법도 있지만 이 경우, 차수의 영향력을 줄이는 효과가 있긴 하지만 비대칭이 발생한다. 따라서 정규화할 때는 위와 같이 대칭성을 유지하여 정규화한다.

 

3. 메시지 패싱 수행

 

 GNN의 핵심 과정으로 노드가 이웃 노드의 정보를 받아 업데이트하는 과정이다.

 

$$ H^{(k+1)} = \sigma(\tilde{A}H^{k}W_{k})$$

 

● 이웃 노드의 정보를 집계(aggregation)하여 업데이트한다.

● 활성화 함수를 적용하여 비선형성을 추가한다.

● 이 과정을 K번 반복하면서 점점 더 풍부한 노드 표현을 학습한다.(여기서 K는 어느 정도로 떨어진 노드의 정보까지를 수용할 것인지를 정하는 하이퍼파라미터이다.)

 

메시지 패싱 과정

 

1. 메시지 생성

$$m^{(k)}_{ij} = Wh^{(k)}_{j}$$

  • 노드 $i$는 이웃 노드 $j$로부터 받을 메시지를 생성한다.

2. 메시지 집계

$$a^{(k)}_{i} = \sum_{j\in N(i)}m^{(k)}_{ij}$$

  • 노드 $i$는 이웃으로부터 받은 메시지를 하나로 합친다.

3. 노드 상태 업데이트

$$h^{(k+1)}_{i} = \sigma(Wa^{(k)}_{i})$$

  • 새로운 노드 상태로 업데이트 한다.

이 과정을 K번 반복하면서 점점 더 멀리 있는 이웃들의 정보를 반영한다. 최종적으로는 노드의 학습된 표현 $h^{(K)}_{i}$가 만들어진다.

 

최종적으로 인코더에서 반환되는 값은 노드의 학습된 표현이다. 이 노드 임베딩은 디코더에서 사용되어 다양한 task를 수행한다.


디코더 (Decoder)

 

 디코더는 인코더가 반환한 학습된 임베딩을 바탕으로 Task에 맞는 출력을 생성하는 부분이다. Task에 따라 구조가 다르므로 여기서는 task의 종류에 따라 어떤 작업을 수행하는지 알아보자.

 

1. 노드 분류

  • 개별 노드가 특정 클래스에 속하는지를 예측하는 문제
  • Ex) 소셜 네트워크에서 사용자의 관심사를 분류, 문서 카테고리를 분류

2. 링크 예측

  • 두 노드가 연결될 가능성을 예측하는 문제
  • Ex) 추천 시스템에서 사용자-상품을 추천, 논문 인용 관계를 예측

3. 그래프 분류

  • 그래프 전체가 특정 클래스에 속하는지를 예측하는 문제
  • Ex) 분자의 화학적 특성을 예측, 소셜 네트워크 유형을 분류

1. 노드 분류

 개별 노드가 특정 클래스에 속하는지를 예측하는 문제이다.

Ex) 소셜 네트워크에서 사용자의 관심사를 분류하는 문제, 문서의 카테고리를 분류하는 문제

 

☆ 디코더 설계

  • 메시지 패싱 후, 각 노드의 최종 임베딩을 분류기에 입력한다.
  • 입력된 최종 임베딩으로 분류기는 softmax 활성화 함수를 사용하여 확률값을 입력한다.

$$\hat{y}_{i} = Softmax(W_{out}h^{(K)}_{i})$$

$W_{out}$ : 학습 가능한 가중치 행렬

Softmax를 통해 노드가 특정 클래스에 속할 확률을 예측한다.

 

☆ 손실함수 식

 

$$L = -\sum_{i}y_{i}log\hat{y}_{i}$$

 

2. 링크 예측

 두 노드가 연결될 가능성을 예측하는 문제이다.

Ex) 추천 시스템에서 사용자에게 상품을 추천하는 목적, 논문의 인용 관계를 예측하는 목적

 

☆ 디코더 설계

 

  • 메시지 패싱 후에 두 노드의 임베딩을 결합하여 연결 가능성을 예측한다.

$$\hat{y}_{ij} = \sigma(h^{(K)}_{i}\cdot h^{(K)}_{j})$$

 

두 노드 임베딩의 내적을 계산해서 연결 여부를 예측한다. 이때 시그모이드 함수를 적용하여 0~1범위의 확률값을 생성한다.

 

☆ 손실함수 식 (Binary Cross Entropy Loss)

 

$$L =  -\sum_{(i,j)\in E}[y_{ij}log\hat{y}_{ij}+(1-y_{ij})log(1-\hat{y}_{ij})]$$

 

3. 그래프 분류

 그래프 전체가 특정 클래스에 속하는지를 예측하는 문제이다.

Ex) 분자의 화학적 특성을 예측하는 문제, 소셜 네트워크의 유형을 분류하는 문제

 

☆ 디코더 설계

  • 개별 노드의 임베딩을 Pooling하여 그래프 전체를 하나의 벡터로 변환한다.
  • 변환된 벡터를 분류기에 입력하여 클래스를 예측한다.

$$h_{G} = Pooling(h^{(K)}_{1}, h^{(K)}_{2}, ..., h^{(K)}_{N})$$

$$\hat{y}_{G} = Softmax(W_{out}h_{G})$$

 

 풀링(Pooling) 기법에 대해서 간단하게 설명하면 여러 개의 값을 하나의 대표 값으로 변환하여 데이터의 크기를 줄이면서도 중요한 정보를 유지하는 방법이다. 풀링 기법의 예시로는 평균 풀링, 합 풀링, 최대 풀링이 있다.

 

☆ 손실함수 식

 

$$ L = -\sum_{G}y_{G}log\hat{y}_{G}$$

 


반응형