ころがる狸

ころがる狸のデータ解析ブログ

【Graph Attention Networks解説】実装から読み解くGAT

こんにちは。機械学習の適用先としては、自然言語処理、画像解析、時系列解析など幅広い分野があるわけですが、今日はグラフ構造に対する機械学習モデルを紹介したいと思います。グラフで表現出るものは多く、例えば人間関係だとか、論文の引用・被引用関係、さらには化合物の構造なども当てはまります。近年のグラフニューラルネットワークの多くはグラフの頂点や辺を何らかの特徴量で表現し、それらを周囲の情報を取り込みながら更新していくという仕組みを取っています。数多くの事例が報告されていますが、特に注目されているGraph Attention Networks(GAT)について取り上げます。
原著論文はこちら。これを理解するための鍵は、グラフの頂点を表す特徴量をどのように更新するか、そしてグラフの頂点と頂点の「つながり」の重要度をどのように計算するか、という2点にあると思います。
arxiv.org

Graph Attention Networks

GithubでTensorFlowやPyTorchでの実装が公開されていますが、ここではPyTorch実装をクローンします。以下のURLをご参照下さい。本稿ではGATの仕組みだけでなく実装も理解して自分で組めるところまでもっていくことを目指したいです。そのためGATモデル構築部分に関しては、ソースコードを適宜参照しながら解説します。
GitHub - Diego999/pyGAT: Pytorch implementation of the Graph Attention Network model by Veličković et. al (2017, https://arxiv.org/abs/1710.10903)

学習データ

Coraデータセットの解説

この論文ではCora, Citeseer, Pubmed, PPIという4つのデータセットに対して様々なグラフモデルとの精度比較を行い、GATの高い精度を報告しています。ここではCoraデータセットについてのみ説明しましょう。
Coraデータセットには機械学習系の2708報の論文の引用・被引用関係と論文のラベルの情報が格納されています。このデータセットはcora.citescora.contentという2つのファイルから構成されており、cora.contentには
(paper_id) (word_attributes) + (class_label)
というフォーマットで論文のID, word_attribute(その論文の特徴量に相当, 1433次元), ラベルが並んでいます。ここでword_attributeは学習データの全論文の頻出単語として選び出した1433単語について、その論文に含まれていれば1、なければ0を格納したベクトルになります。ラベルは、以下の7つのクラスのいずれかが割り当てられています。

  • Case_Based,
  • Generic_Algorithms,
  • Neural_Networks,
  • Probabilistic_Methods,
  • Reinforcement_Learning,
  • Rule_Learning,
  • Theory

このデータの束が2708個入っているわけです。
続いてcora.citesには論文の引用関係が以下のようなフォーマットで格納されています。
(ID of cited paper) (ID of citing paper)
引用された・引用した論文IDの順です。これは全部で5429個あります。つまり論文同士のネットワークに、計5429個の関係が存在するということですね。NetworkXでCoraデータセットを可視化したのがこちらの図です。このようなグラフをニューラルネットワークで読み込んで、各論文の引用・被引用関係を学習して論文のジャンル(7つのクラスのうちいずれか)を予測するのが今回のタスクになります。

f:id:Dajiro:20200512212413p:plain
Coraグラフデータの可視化。関係性から未知のラベルを予測する。

Coraデータセットの読み込みと行列化

この2つのファイルを読み込んで、グラフ構造を抽出します。グラフは行列の形式で表現でき、各論文(ノードと呼びましょう)の関係性を表した隣接行列、ノードの特徴量を格納した特徴行列の2つの行列を生成します。それに加え、正解データとなるラベルのベクトルと、訓練・検証・テスト用データのインデックスも準備します。イメージとしては以下のような図になります。

f:id:Dajiro:20200509192212p:plain
学習データをグラフを表す行列(隣接行列、特徴行列)へと加工する。
このままでは隣接行列と特徴行列のイメージがつかないと思うので、これも図にしてみました。隣接行列の各行と各列はノード(論文)を表しており、引用関係があれば1、なければ0と表示されています。ほとんどの論文は関係がなく、ほとんどの行列要素が0のスパース行列となっています。一方で特徴行列の各行はノード(論文)を表しており、列は特徴量ベクトルに相当しています。これがGATモデルの入力として使われるわけですね。
(データセットから行列への加工には、スパース行列の計算を効率的に扱えるscipy.sparceというライブラリが使われているのですが、本稿では説明を省略します。長くなるので・・・汗)
f:id:Dajiro:20200509194148p:plain
隣接行列と特徴行列のイメージ図。これでグラフが持つ情報を数値化できる。

Graph Attention Networks(GAT)モデル

特徴行列の更新

つづいて、この論文の中核であるモデルの説明に移りましょう。意外とシンプルなモデルであることが後々分かると思いますが、以下の式が理解できればGATはほぼ理解できたといっていいと思います。これは、ノード特徴量 \vec{h}_{i}を周囲の情報を取り込んで更新することを意味しています。ここで Wは重み、 \alpha_{ij}は隣接するノードの重要度を表す係数で、この係数がかかっているが故にAttention(注意)という名前がついています。
 \vec{h}_{i}^{\prime} = \sigma (\sum_{j\in{N_i}}\alpha_{ij}W\vec{h}_{i})

ここでやっていることは要するに、

  1. 自分自身を含む、隣接するノード特徴量に重みをかける
  2. 重要度を表す係数をかける(重要度の計算は後ほど紹介します)
  3. それらの値を足す
  4. 適当な非線形活性化関数をかけてノード特徴量を更新する

という作業なので意外とシンプルではあります。さらに、論文中では結果を安定させるために別の重みや係数を用意して複数計算し、それらの結果を結合しようと主張しています。式で書くとこんな感じ。
 \vec{h}_{i}^{\prime} = \|_{k=1}^{K}\sigma (\sum_{j\in{N_i}}\alpha_{ij}^{k}W^{k}\vec{h}_{i})

ここで\|_{k=1}^{K}は、K個の計算を行って特徴量ベクトルの方向に結合しようという演算を意味しています。論文中ではマルチヘッドと呼ばれています。この場合、特徴量がN次元でK個結合すると、特徴量はN*Kになります。こうしてできた拡大したノード特徴量について、N*K次元の特徴量に対応し、かつ出力がラベルの次元数と等しくなるような適切な重みをもったアテンションレイヤーをもう一発かませて、最終的な出力を得ます。この場合は結合すると伸びてしまうので、平均を取ります。この演算に関しては論文中の図が分かりやすいので、張っておきましょう。

f:id:Dajiro:20200509210927p:plain
マルチヘッドアテンションのイメージ図。論文より引用。

重要度の計算

長くなりましたが、さっきから重要度といっている係数\alpha_{ij}の求め方を数式と実装ベースで見てみましょう。これは、ラベルを予測するうえで重要な論文間の関係性を評価したもので、予測精度向上に寄与するような関係性に対しては大きい値、ほとんど無関係な関係性には小さい値がかかります。式で書くとこのようになります。
\alpha_{ij} = \frac{{\rm exp}(e_{ij})}{\sum_{k\in{N_i}}{\rm exp}(e_{ik})}

softmax関数をかけてるので確率的な意味を持った数値になっていますね。このe_{ik}がアテンション係数と言われるもので、i,jノード間の関係性の重要度を表す指標になります。一般的にはアテンション係数は以下のように書くことが出来ます。
e_{ij} = a({\bf W}\vec {h}_{i}, {\bf W}\vec {h}_{j})

この具体的な計算方法には色々な流儀があります。たとえば2つのベクトルの内積を取るなど。そうすると似たベクトルは内積1、つまり大きな値を取りますし遠いベクトルだと小さい値になります。しかしここでは、アテンション係数の計算に単層のニューラルネットワークを用いています。式が抽象的ですがこのように書き下せます。
 e_{ij} = {\rm LeakyReLU} (\vec {a}^{T}({\bf W}\vec{h}_{i}\|{\bf W}\vec{h}_{j}))

数式ばかりで非常につかれますね。アテンションの計算から特徴行列の更新までの一式の処理を図で見た方が分かりやすいかもしれません。以下はアテンションレイヤーの処理内容の図解です。

f:id:Dajiro:20200509220258p:plain
アテンションレイヤーの処理内容のイメージ。

この図では重要な部分を抜粋していますが、どのような処理をしているのかのイメージが湧きやすいのではないでしょうか。より詳細な処理を見たい場合には以下のソースコードのforward部分の実装をご覧ください。__init__は行列の形状の定義などを行っていますがGATの根幹を理解する上ではそこまで重要ではありません。

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        ###
        省略 
        ###
    def forward(self, input, adj):
        h = torch.mm(input, self.W)
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

上の図と対応させると、以下で特徴行列の圧縮を行っています。

h = torch.mm(input, self.W)

続いてアテンション係数の計算ですが、ここでやっているのは第i成分と第j成分の特徴量を結合して二倍の大きさ(つまり16)にする処理です。PyTorchの機能をうまくつかった、テクニカルな処理ですね。self.aは重みでa_inputとかけると(2708×2708×1)の行列ができますが第2軸が無駄なのでsqueeze(2)でこの軸を消去し(2708×2708)の行列にしているわけです。

a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

そして以下の処理で \alpha_{ij}を計算しています。zero_vecには負の数がかかっていますが、これはsoftmax関数の内部にある指数関数にかけたときにゼロになるようにするためです。2行目のwhereは、ノード間に関係がある場合にはアテンション係数をそのまま入れて、関係がない場合にはゼロを設定するための処理です。

zero_vec = -9e15*torch.ones_like(e)
attention = torch.where(adj > 0, e, zero_vec)
attention = F.softmax(attention, dim=1)

このモジュールがGATの実装の根幹部分です。あとはdropoutをかけて過学習を抑制したり、このアテンションレイヤーを組み合わせて一つのネットワークを構築しているのです。

結果

試しに計算を走らせてみました。10エポックおきにデータを取りましたが、損失関数がは訓練・検証用データ双方に対して良く下がっていますね。ある程度汎化性能も獲得できていそうです。

f:id:Dajiro:20200509223747p:plain
学習曲線。10個おきにデータを取得。
続いて予測精度です。完璧にラベルを予測出来たら1になる指標ですね。この場合は検証用データの正解率が80%くらいで飽和しているのが分かります。なかなか高い精度ですね。論文でもこれくらいの精度が出ているので再現性もあり、文句なしの結果と言ってもいいのではないでしょうか。実装も理解してしまえばそこまで複雑でもないので、いろいろなところに応用が利きそうな技術ですね。業務でも是非試してみたいところ。
f:id:Dajiro:20200509224056p:plain
正解率のエポックに対する推移。80%を超えたあたりで飽和。