Graph Convolutional Network

$$\gdef \sam #1 {\mathrm{softargmax}(#1)}$$ $$\gdef \vect #1 {\boldsymbol{#1}} $$ $$\gdef \matr #1 {\boldsymbol{#1}} $$ $$\gdef \E {\mathbb{E}} $$ $$\gdef \V {\mathbb{V}} $$ $$\gdef \R {\mathbb{R}} $$ $$\gdef \N {\mathbb{N}} $$ $$\gdef \relu #1 {\texttt{ReLU}(#1)} $$ $$\gdef \D {\,\mathrm{d}} $$ $$\gdef \deriv #1 #2 {\frac{\D #1}{\D #2}}$$ $$\gdef \pd #1 #2 {\frac{\partial #1}{\partial #2}}$$ $$\gdef \set #1 {\left\lbrace #1 \right\rbrace} $$ % My colours $$\gdef \aqua #1 {\textcolor{8dd3c7}{#1}} $$ $$\gdef \yellow #1 {\textcolor{ffffb3}{#1}} $$ $$\gdef \lavender #1 {\textcolor{bebada}{#1}} $$ $$\gdef \red #1 {\textcolor{fb8072}{#1}} $$ $$\gdef \blue #1 {\textcolor{80b1d3}{#1}} $$ $$\gdef \orange #1 {\textcolor{fdb462}{#1}} $$ $$\gdef \green #1 {\textcolor{b3de69}{#1}} $$ $$\gdef \pink #1 {\textcolor{fccde5}{#1}} $$ $$\gdef \vgrey #1 {\textcolor{d9d9d9}{#1}} $$ $$\gdef \violet #1 {\textcolor{bc80bd}{#1}} $$ $$\gdef \unka #1 {\textcolor{ccebc5}{#1}} $$ $$\gdef \unkb #1 {\textcolor{ffed6f}{#1}} $$ % Vectors $$\gdef \vx {\pink{\vect{x }}} $$ $$\gdef \vy {\blue{\vect{y }}} $$ $$\gdef \vb {\vect{b}} $$ $$\gdef \vz {\orange{\vect{z }}} $$ $$\gdef \vtheta {\vect{\theta }} $$ $$\gdef \vh {\green{\vect{h }}} $$ $$\gdef \vq {\aqua{\vect{q }}} $$ $$\gdef \vk {\yellow{\vect{k }}} $$ $$\gdef \vv {\green{\vect{v }}} $$ $$\gdef \vytilde {\violet{\tilde{\vect{y}}}} $$ $$\gdef \vyhat {\red{\hat{\vect{y}}}} $$ $$\gdef \vycheck {\blue{\check{\vect{y}}}} $$ $$\gdef \vzcheck {\blue{\check{\vect{z}}}} $$ $$\gdef \vztilde {\green{\tilde{\vect{z}}}} $$ $$\gdef \vmu {\green{\vect{\mu}}} $$ $$\gdef \vu {\orange{\vect{u}}} $$ % Matrices $$\gdef \mW {\matr{W}} $$ $$\gdef \mA {\matr{A}} $$ $$\gdef \mX {\pink{\matr{X}}} $$ $$\gdef \mY {\blue{\matr{Y}}} $$ $$\gdef \mQ {\aqua{\matr{Q }}} $$ $$\gdef \mK {\yellow{\matr{K }}} $$ $$\gdef \mV {\lavender{\matr{V }}} $$ $$\gdef \mH {\green{\matr{H }}} $$ % Coloured math $$\gdef \cx {\pink{x}} $$ $$\gdef \ctheta {\orange{\theta}} $$ $$\gdef \cz {\orange{z}} $$ $$\gdef \Enc {\lavender{\text{Enc}}} $$ $$\gdef \Dec {\aqua{\text{Dec}}}$$
🎙️ Alfredo Canziani

Graph Convolutional Network (GCN)の紹介

GCN(Graph Convolutional Network)は、データの構造を利用したアーキテクチャの一種です。 GCNとself-attentionは概念的に関係があるので、詳細に入る前に、self-attentionについて簡単に復習しておきましょう。

Self-attentionのおさらい

  • Self-attentionでは、入力の集合$\lbrace\boldsymbol{x}{i}\rbrace^{t}{i=1}$があります。 系列データと違って、順番がありません。
  • 隠れベクトル $\boldsymbol{h}$ は、この集合の中のベクトルの線形結合で与えられます。
  • これを行列ベクトルの掛け算を使って $\boldsymbol{X}\boldsymbol{a}$ と表現します。ここで$\boldsymbol{a}$ には入力ベクトル $\boldsymbol{x}_{i}$ をスケーリングする係数が含まれています。

詳しい説明は、第12週のノート(/NYU-DLSP20/ja/week12/12-3/)を参照してください。

ノーテーション


図1: グラフ畳み込みニューラルネット

図1において、頂点 $v$ は、入力 $\boldsymbol{x}$ とその隠れ表現 $\boldsymbol{h}$ の2つのベクトルで構成されています。 また、複数の頂点 $v_{j}$ があり、これは $\boldsymbol{x}_j$ と $\boldsymbol{h}_j$ で構成されています。 このグラフでは、頂点は有向辺で結ばれています。

この有向辺を、隣接ベクトル $\boldsymbol{a}$ で表現し、各要素 $\alpha_{j}$ は、$v_{j}$ から $v$ への有向辺があれば $1$ とします。 \(\alpha_{j} \stackrel{\tiny \downarrow}{=} 1 \Leftrightarrow v_{j} \rightarrow v \tag{Eq. 1}\) 次数(入ってくる辺の数)$d$は、この隣接関係ベクトルのノルム、すなわち、$\boldsymbol{a}$の中の1の数である$\Vert\boldsymbol{a}\Vert_{1}$として定義されます。

\[d = \Vert\boldsymbol{a}\Vert_{1} \tag{Eq. 2}\]

隠れベクトル $\boldsymbol{h}$ は次の式で与えられます。

\[\boldsymbol{h}=f(\boldsymbol{U}\boldsymbol{x} + \boldsymbol{V}\boldsymbol{X}\boldsymbol{a}d^{-1}) \tag{Eq. 3}\]

ここで、$f(\cdot)$は、ReLU $(\cdot)^{+}$, Sigmoid $\sigma(\cdot)$, hyperbolic tangent $\tanh(\cdot)$などの非線形関数です。

この$\boldsymbol{U}\boldsymbol{x}$という項は、入力$v$に回転$\boldsymbol{U}$を適用することで、頂点$v$自体を考慮しています。

Self-attentionでは、隠れベクトル $\boldsymbol{h}$ は、$\boldsymbol{X}\boldsymbol{a}$ で計算され、$\boldsymbol{X}$ の列は、$\boldsymbol{a}$の要素でスケーリングされることを覚えておいてください。 GCNの文脈では、これは、隣接ベクトルの中に複数の辺が入ってくると(例えば隣接行列$\boldsymbol{a}$のなかの複数の要素)$\boldsymbol{X}\boldsymbol{a}$が大きくなることを意味します。 一方、入ってくる辺が1つしかない場合は、この値は小さくなります。 この値が流入する辺の数に比例するという問題を解決するために、流入する辺の数$d$で割ってみましょう。 そして、$\boldsymbol{X}\boldsymbol{a}d^{-1}$に回転$\boldsymbol{V}$をかけます。

この隠れ表現$\boldsymbol{h}$を、入力の全集合 $\boldsymbol{x}$に対して、次のような行列記法で表現することができます。

\[\{\boldsymbol{x}_{i}\}^{t}_{i=1}\rightsquigarrow \boldsymbol{H}=f(\boldsymbol{UX}+ \boldsymbol{VXAD}^{-1}) \tag{Eq. 4}\]

ただし$\vect{D} = \text{diag}(d_{i})$です。

Residual Gated GCNの理論とコード

Residual Gatedグラフ畳み込みニューラルネットは、図2に示すようなGCNの一種です。


図2: Residual Gatedグラフ畳み込みニューラルネット

標準的な GCN と同様に、頂点 $v$ は、入力 $\boldsymbol{x}$ とその隠れ表現 $\boldsymbol{h}$ の 2つのベクトルで構成されています。ただし、この場合、エッジにも特徴表現があり、入力辺の表現を $\boldsymbol{e_{j}^{x}}$、隠れ辺の表現を $\boldsymbol{e_{j}^{h}}$ とします。

頂点 $v$ の隠れ表現 $\boldsymbol{h}$ は、次式で求められます。

\[\boldsymbol{h}=\boldsymbol{x} + \bigg(\boldsymbol{Ax} + \sum_{v_j→v}{\eta(\boldsymbol{e_{j}})\odot \boldsymbol{Bx_{j}}}\bigg)^{+} \tag{Eq. 5}\]

ここで、$\boldsymbol{x}$ は入力表現です。$\boldsymbol{Ax}$ は、入力$\boldsymbol{x}$の回転を表し、$\sum_{v_j→v}{\eta(\boldsymbol{e_{j}})\odot \boldsymbol{Bx_{j}}}$ は、回転させた入力特徴量$\boldsymbol{Bx_{j}}$とゲート$\eta(\boldsymbol{e_{j}})$との要素毎の積の和を表します。上で紹介した標準的なGCNでは入力表現を平均化するのに対し、ゲート項は、エッジ表現に基づいて入力表現を変化させることができるので、Residual GatedGCNの実装には重要です。

なお、和は、頂点${v}$へ入力されるエッジを持つ頂点${v_j}$のみを対象としていることに注意してください。Residual Gated GCNにおけるresidualという用語は、隠れ表現 $\boldsymbol{h}$ を計算するために、入力表現 $\boldsymbol{x}$ を加えることに由来します。ゲート項$\eta(\boldsymbol{e_{j}})$は,次のように計算されます。

\[\eta(\boldsymbol{e_{j}})=\sigma(\boldsymbol{e_{j}})\bigg(\sum_{v_k→v}\sigma(\boldsymbol{e_{k}})\bigg)^{-1} \tag{Eq. 6}\]

ゲート値 $\eta(\boldsymbol{e_{j}})$ は、入ってくる辺のシグモイドをすべての入ってくる辺のシグモイドの和で割った正規化シグモイドです。ゲート項を計算するためには、次の式を用いて計算できるエッジの表現 $\boldsymbol{e_{j}}$ が必要です。

\[\boldsymbol{e_{j}} = \boldsymbol{Ce_{j}^{x}} + \boldsymbol{Dx_{j}} + \boldsymbol{Ex} \tag{Eq. 7}\] \[\boldsymbol{e_{j}^{h}}=\boldsymbol{e_{j}^{x}}+(\boldsymbol{e_{j}})^{+} \tag{Eq. 8}\]

辺の隠れ表現 $\boldsymbol{e_{j}^{h}}$ は、辺の表現の初期値 $\boldsymbol{e_{j}^{x}}$ と$\boldsymbol{e_{j}}$に適用された$\texttt{ReLU}(\cdot)$の和です。ここで、$\boldsymbol{e_{j}}$は、$\boldsymbol{e_{j}^{x}}$に適用される回転の和で与えられます。この回転は頂点 $v_{j}$ の入力表現 $\boldsymbol{x_{j}}$に対する回転と、頂点 $v$ の入力表現 $\boldsymbol{x}$ に対する回転です。

*注:下流の隠れ表現(例えば $2^\text{nd}$ 層の隠れ表現)を計算するには,入力特徴表現を上式の $1^\text{st}$ 層の特徴表現に置き換えれば大丈夫です。

グラフ分類とResidual Gated GCN Layer

ここでは、グラフ分類の問題点を紹介し、Residual Gated GCN層をコードに落としこみます。通常のimport文に加えて、以下を追加します。

os.environ['DGLBACKEND'] = 'pytorch'
import dgl
from dgl import DGLGraph
from dgl.data import MiniGCDataset
import networkx as nx

最初の行では、DGLにPyTorchをバックエンドとして使うように指示しています。Deep Graph Library (DGL) はグラフに関する様々な機能を提供しています。

このnotebookでは、与えられたグラフ構造を8つのグラフタイプのうちの1つに分類することを課題としています。dgl.data.MiniGCDataset から得られるデータセットには、min_num_v から max_num_v の間にノードを持つグラフ (num_graphs) がいくつか含まれています。したがって、得られたグラフはすべて同じ数のノード/頂点を持つわけではありません。

注: DGLGraphs の基本を理解するためには、こちらのチュートリアルを参照することをお勧めします。

グラフを作成したら、次の作業はドメインにシグナルを追加することです。特徴表現は、名前(文字列)とテンソル(fields)の辞書として表現されます。ndataedata は、すべてのノードとエッジの特徴データにアクセスするための糖衣構文です。

以下のコードスニペットは、特徴量がどのように生成されるかを示しています。各ノードには入射エッジの数に等しい値が割り当てられ、各エッジには値1が割り当てられます。

def create_artificial_features(dataset):
    for (graph, _) in dataset:
        graph.ndata['feat'] = graph.in_degrees().view(-1, 1).float()
        graph.edata['feat'] = torch.ones(graph.number_of_edges(), 1)
    return dataset

訓練データとテストデータが作成され、特徴量が次のように割り当てられます。

trainset = MiniGCDataset(350, 10, 20)
testset = MiniGCDataset(100, 10, 20)

trainset = create_artificial_features(trainset)
testset = create_artificial_features(testset)

訓練データ集合のサンプルグラフは以下のような表現です。ここで、このグラフは15個のノードと45個のエッジを持ち、ノードもエッジも期待通りの形状 (1,) の特徴表現を持っていることがわかります。さらに、0はこのグラフがクラス0に属していることを示しています。

(DGLGraph(num_nodes=15, num_edges=45,
         ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}
         edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}), 0)

DGLのMessage関数とReduce関数に関する注意

DGLでは、メッセージ関数Edge UDF(ユーザー定義関数)として表現されます。Edge UDFは単一の引数 edges を持ちます。エッジUDFは、ソースノードの特徴、到達ノードの特徴、エッジの特徴、のそれぞれアクセスするために、src, dst, data の3つのメンバを持ちます。 Node UDFは、reduce関数です。ノードUDFは単一の引数 nodes を持ち、その引数には2つのメンバ datamailbox を持ちます。dataにはノードの特徴が含まれ、mailboxにはすべての受信メッセージの特徴が含まれ、2番目の次元に沿って積み上げられています(そのため、dim=1引数が指定されています←次元は0から始まるので)。 update_all(message_func, reduce_func) は、すべてのエッジを経由してメッセージを送信し、すべてのノードを更新します。

Gated Residual GCN層の実装

Gated Residual GCN層は、以下のコードスニペットのように実装されています。

まず、__init__関数内に、nn.Linear層を定義して、入力表現 h, eforward 関数内の線形層を介して順伝搬させることで、入力特徴 $\boldsymbol{Ax}$, $\boldsymbol{Bx_{j}}$, $\boldsymbol{Ce_{j}^{x}}$, $\boldsymbol{Dx_{j}}$, $\boldsymbol{Ex}$ のすべての回転を計算します。

class GatedGCN_layer(nn.Module):

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.A = nn.Linear(input_dim, output_dim)
        self.B = nn.Linear(input_dim, output_dim)
        self.C = nn.Linear(input_dim, output_dim)
        self.D = nn.Linear(input_dim, output_dim)
        self.E = nn.Linear(input_dim, output_dim)
        self.bn_node_h = nn.BatchNorm1d(output_dim)
        self.bn_node_e = nn.BatchNorm1d(output_dim)

次に、エッジの表現を計算します。これは message_func 関数の中で行われ、すべての辺を反復処理して辺の表現を計算します。具体的には、e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh'] という行で (Eq. 7) を計算します。関数 message_func は、Bh_j (これは (Eq. 5)の$\boldsymbol{Bx_{j}}$ です) と e_ij (Eq. 7) を、エッジを経由して宛先ノードのメールボックスに送信します。

def message_func(self, edges):
    Bh_j = edges.src['Bh']
    # e_ij = Ce_ij + Dhi + Ehj
    e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
    edges.data['e'] = e_ij
    return {'Bh_j' : Bh_j, 'e_ij' : e_ij}

第三に、reduce_func関数は、message_func関数によって配送されたメッセージを収集します。メールボックスからノードデータ Ah と配送されたデータ Bh_j, e_ij を収集した後、h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / torch.sum(sigma_ij, dim=1) 行により、(Eq. 5)で示されるように、各ノードの隠れ表現を計算します。ただし、これは$\texttt{ReLU}(\cdot)$とresidual connectionなしで $(\boldsymbol{Ax} + \sum_{v_j→v}{\eta(\boldsymbol{e_{j}})\odot \boldsymbol{Bx_{j}}})$項だけを表しています。

def reduce_func(self, nodes):
    Ah_i = nodes.data['Ah']
    Bh_j = nodes.mailbox['Bh_j']
    e = nodes.mailbox['e_ij']
    # sigma_ij = sigmoid(e_ij)
    sigma_ij = torch.sigmoid(e)
    # hi = Ahi + sum_j eta_ij * Bhj
    h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / torch.sum(sigma_ij, dim=1)
    return {'h' : h}

関数 forward の中で g.update_all を呼び、グラフの畳み込みの結果 he を得ます。これは (Eq.5)より、$(\boldsymbol{Ax} + \sum_{v_j→v}{\eta(\boldsymbol{e_{j}})\odot \boldsymbol{Bx_{j}}})$を、(Eq.7)より$\boldsymbol{e_{j}}$を表しています。そして、グラフのノードサイズとグラフのエッジサイズを基準にして he を正規化します。その後、ネットワークを効率的に学習できるように、バッチ正規化を行います。最後に、$\texttt{ReLU}(\cdot)$を適用し、residual connectionを加えて、ノードとエッジの隠れ表現を得ます。

def forward(self, g, h, e, snorm_n, snorm_e):

    h_in = h # residual connection
    e_in = e # residual connection

    g.ndata['h']  = h
    g.ndata['Ah'] = self.A(h)
    g.ndata['Bh'] = self.B(h)
    g.ndata['Dh'] = self.D(h)
    g.ndata['Eh'] = self.E(h)
    g.edata['e']  = e
    g.edata['Ce'] = self.C(e)

    g.update_all(self.message_func, self.reduce_func)

    h = g.ndata['h'] # result of graph convolution
    e = g.edata['e'] # result of graph convolution

    h = h * snorm_n # normalize activation w.r.t. graph node size
    e = e * snorm_e # normalize activation w.r.t. graph edge size

    h = self.bn_node_h(h) # batch normalization
    e = self.bn_node_e(e) # batch normalization

    h = torch.relu(h) # non-linear activation
    e = torch.relu(e) # non-linear activation

    h = h_in + h # residual connection
    e = e_in + e # residual connection

    return h, e

次に、複数の全結合層(FCN)を含む MLP_Layer モジュールを定義します。全結合層のリストを作成し、順伝播を定義します。

最後に、GatedGCN_layerMLP_layer という先に定義したクラスからなる GatedGCN モデルを定義します。GatedGCNモデルの定義を以下に示します。

 class GatedGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, L):
        super().__init__()
        self.embedding_h = nn.Linear(input_dim, hidden_dim)
        self.embedding_e = nn.Linear(1, hidden_dim)
        self.GatedGCN_layers = nn.ModuleList([
            GatedGCN_layer(hidden_dim, hidden_dim) for _ in range(L)
        ])
        self.MLP_layer = MLP_layer(hidden_dim, output_dim)
    def forward(self, g, h, e, snorm_n, snorm_e):
        # input embedding
        h = self.embedding_h(h)
        e = self.embedding_e(e)
        # graph convnet layers
        for GGCN_layer in self.GatedGCN_layers:
            h, e = GGCN_layer(g, h, e, snorm_n, snorm_e)
        # MLP classifier
        g.ndata['h'] = h
        y = dgl.mean_nodes(g,'h')
        y = self.MLP_layer(y)
        return y

私たちのコンストラクタでは、ehの埋め込み(self.embedding_e self.embedding_h)、self.GatedGCN_layersを定義しています。これは以前定義したモデルGatedGCN_layerからなるサイズ$L$のリストです。また、self.MLP_layerも定義します。これも以前定義したものです。次に、これらの初期化を使って、順伝播して y を出力します。

モデルをより理解しやすくするために、モデルのオブジェクトを初期化してしてprintします。

# instantiate network
model = GatedGCN(input_dim=1, hidden_dim=100, output_dim=8, L=2)
print(model)

モデルの主な構造をいかに示します。

GatedGCN(
  (embedding_h): Linear(in_features=1, out_features=100, bias=True)
  (embedding_e): Linear(in_features=1, out_features=100, bias=True)
  (GatedGCN_layers): ModuleList(
    (0): GatedGCN_layer(...)
    (1): GatedGCN_layer(... ))
  (MLP_layer): MLP_layer(
    (FC_layers): ModuleList(...))

驚くことではありませんが、GatedGCN_layerの2つの層(L=2なので)と、MLP_layerの2つの層があり、最終的には8つの値が出力されます。

次に、trainevaluate 関数を定義します。train 関数では、dataloader からサンプルを取得する汎用コードを定義しています。 次に、batch_graphs, batch_x, batch_e, batch_snorm_n, batch_snorm_e をモデルに入力し、batch_scores (サイズ8) を返します。予測されたスコアは、損失関数 loss(batch_scores, batch_labels) を用いて真の値と比較されます。次に、勾配をゼロにし (optimizer.zero_grad())、逆伝播を行い (J.backward())、重みを更新します (optimizer.step())。最後に、各エポックの損失と精度を計算します。さらに、同様のコードを evaluate 関数にも用います。

最後に、訓練の準備ができました!40エポックの学習を行った結果、我々のモデルはテスト精度$87$%でグラフを分類することを学習したことがわかりました。


📝 Go Inoue, Muhammad Osama Khan, Muhammad Shujaat Mirza, Muhammad Muneeb Afzal
Shiro Takagi
28 Apr 2020