Graph Convolutional Network
🎙️ Alfredo CanzianiGraph Convolutional Network (GCN)の紹介
GCN(Graph Convolutional Network)は、データの構造を利用したアーキテクチャの一種です。 GCNとself-attentionは概念的に関係があるので、詳細に入る前に、self-attentionについて簡単に復習しておきましょう。
Self-attentionのおさらい
- Self-attentionでは、入力の集合$\lbrace\boldsymbol{x}{i}\rbrace^{t}{i=1}$があります。 系列データと違って、順番がありません。
- 隠れベクトル は、この集合の中のベクトルの線形結合で与えられます。
- これを行列ベクトルの掛け算を使って と表現します。ここで には入力ベクトル をスケーリングする係数が含まれています。
詳しい説明は、第12週のノート(/NYU-DLSP20/ja/week12/12-3/)を参照してください。
ノーテーション

図1: グラフ畳み込みニューラルネット
図1において、頂点 は、入力 とその隠れ表現 の2つのベクトルで構成されています。 また、複数の頂点 があり、これは と で構成されています。 このグラフでは、頂点は有向辺で結ばれています。
この有向辺を、隣接ベクトル で表現し、各要素 は、 から への有向辺があれば とします。 \(\alpha_{j} \stackrel{\tiny \downarrow}{=} 1 \Leftrightarrow v_{j} \rightarrow v \tag{Eq. 1}\) 次数(入ってくる辺の数)は、この隣接関係ベクトルのノルム、すなわち、の中の1の数であるとして定義されます。
隠れベクトル は次の式で与えられます。
ここで、は、ReLU , Sigmoid , hyperbolic tangent などの非線形関数です。
このという項は、入力に回転を適用することで、頂点自体を考慮しています。
Self-attentionでは、隠れベクトル は、 で計算され、 の列は、の要素でスケーリングされることを覚えておいてください。 GCNの文脈では、これは、隣接ベクトルの中に複数の辺が入ってくると(例えば隣接行列のなかの複数の要素)が大きくなることを意味します。 一方、入ってくる辺が1つしかない場合は、この値は小さくなります。 この値が流入する辺の数に比例するという問題を解決するために、流入する辺の数で割ってみましょう。 そして、に回転をかけます。
この隠れ表現を、入力の全集合 に対して、次のような行列記法で表現することができます。
ただしです。
Residual Gated GCNの理論とコード
Residual Gatedグラフ畳み込みニューラルネットは、図2に示すようなGCNの一種です。

図2: Residual Gatedグラフ畳み込みニューラルネット
標準的な GCN と同様に、頂点 は、入力 とその隠れ表現 の 2つのベクトルで構成されています。ただし、この場合、エッジにも特徴表現があり、入力辺の表現を 、隠れ辺の表現を とします。
頂点 の隠れ表現 は、次式で求められます。
ここで、 は入力表現です。 は、入力の回転を表し、 は、回転させた入力特徴量とゲートとの要素毎の積の和を表します。上で紹介した標準的なGCNでは入力表現を平均化するのに対し、ゲート項は、エッジ表現に基づいて入力表現を変化させることができるので、Residual GatedGCNの実装には重要です。
なお、和は、頂点へ入力されるエッジを持つ頂点のみを対象としていることに注意してください。Residual Gated GCNにおけるresidualという用語は、隠れ表現 を計算するために、入力表現 を加えることに由来します。ゲート項は,次のように計算されます。
ゲート値 は、入ってくる辺のシグモイドをすべての入ってくる辺のシグモイドの和で割った正規化シグモイドです。ゲート項を計算するためには、次の式を用いて計算できるエッジの表現 が必要です。
辺の隠れ表現 は、辺の表現の初期値 とに適用されたの和です。ここで、は、に適用される回転の和で与えられます。この回転は頂点 の入力表現 に対する回転と、頂点 の入力表現 に対する回転です。
*注:下流の隠れ表現(例えば 層の隠れ表現)を計算するには,入力特徴表現を上式の 層の特徴表現に置き換えれば大丈夫です。
グラフ分類とResidual Gated GCN Layer
ここでは、グラフ分類の問題点を紹介し、Residual Gated GCN層をコードに落としこみます。通常のimport文に加えて、以下を追加します。
os.environ['DGLBACKEND'] = 'pytorch'
import dgl
from dgl import DGLGraph
from dgl.data import MiniGCDataset
import networkx as nx
最初の行では、DGLにPyTorchをバックエンドとして使うように指示しています。Deep Graph Library (DGL) はグラフに関する様々な機能を提供しています。
このnotebookでは、与えられたグラフ構造を8つのグラフタイプのうちの1つに分類することを課題としています。dgl.data.MiniGCDataset
から得られるデータセットには、min_num_v
から max_num_v
の間にノードを持つグラフ (num_graphs
) がいくつか含まれています。したがって、得られたグラフはすべて同じ数のノード/頂点を持つわけではありません。
注: DGLGraphs
の基本を理解するためには、こちらのチュートリアルを参照することをお勧めします。
グラフを作成したら、次の作業はドメインにシグナルを追加することです。特徴表現は、名前(文字列)とテンソル(fields)の辞書として表現されます。ndata
と edata
は、すべてのノードとエッジの特徴データにアクセスするための糖衣構文です。
以下のコードスニペットは、特徴量がどのように生成されるかを示しています。各ノードには入射エッジの数に等しい値が割り当てられ、各エッジには値1が割り当てられます。
def create_artificial_features(dataset):
for (graph, _) in dataset:
graph.ndata['feat'] = graph.in_degrees().view(-1, 1).float()
graph.edata['feat'] = torch.ones(graph.number_of_edges(), 1)
return dataset
訓練データとテストデータが作成され、特徴量が次のように割り当てられます。
trainset = MiniGCDataset(350, 10, 20)
testset = MiniGCDataset(100, 10, 20)
trainset = create_artificial_features(trainset)
testset = create_artificial_features(testset)
訓練データ集合のサンプルグラフは以下のような表現です。ここで、このグラフは15個のノードと45個のエッジを持ち、ノードもエッジも期待通りの形状 (1,)
の特徴表現を持っていることがわかります。さらに、0
はこのグラフがクラス0に属していることを示しています。
(DGLGraph(num_nodes=15, num_edges=45,
ndata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}
edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}), 0)
DGLのMessage関数とReduce関数に関する注意
DGLでは、メッセージ関数はEdge UDF(ユーザー定義関数)として表現されます。Edge UDFは単一の引数
edges
を持ちます。エッジUDFは、ソースノードの特徴、到達ノードの特徴、エッジの特徴、のそれぞれアクセスするために、src
,dst
,data
の3つのメンバを持ちます。 Node UDFは、reduce関数です。ノードUDFは単一の引数nodes
を持ち、その引数には2つのメンバdata
とmailbox
を持ちます。data
にはノードの特徴が含まれ、mailbox
にはすべての受信メッセージの特徴が含まれ、2番目の次元に沿って積み上げられています(そのため、dim=1
引数が指定されています←次元は0から始まるので)。update_all(message_func, reduce_func)
は、すべてのエッジを経由してメッセージを送信し、すべてのノードを更新します。
Gated Residual GCN層の実装
Gated Residual GCN層は、以下のコードスニペットのように実装されています。
まず、__init__
関数内に、nn.Linear
層を定義して、入力表現 h
, e
を forward
関数内の線形層を介して順伝搬させることで、入力特徴 , , , , のすべての回転を計算します。
class GatedGCN_layer(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.A = nn.Linear(input_dim, output_dim)
self.B = nn.Linear(input_dim, output_dim)
self.C = nn.Linear(input_dim, output_dim)
self.D = nn.Linear(input_dim, output_dim)
self.E = nn.Linear(input_dim, output_dim)
self.bn_node_h = nn.BatchNorm1d(output_dim)
self.bn_node_e = nn.BatchNorm1d(output_dim)
次に、エッジの表現を計算します。これは message_func
関数の中で行われ、すべての辺を反復処理して辺の表現を計算します。具体的には、e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
という行で (Eq. 7) を計算します。関数 message_func
は、Bh_j
(これは (Eq. 5)の です) と e_ij
(Eq. 7) を、エッジを経由して宛先ノードのメールボックスに送信します。
def message_func(self, edges):
Bh_j = edges.src['Bh']
# e_ij = Ce_ij + Dhi + Ehj
e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
edges.data['e'] = e_ij
return {'Bh_j' : Bh_j, 'e_ij' : e_ij}
第三に、reduce_func
関数は、message_func
関数によって配送されたメッセージを収集します。メールボックスからノードデータ Ah
と配送されたデータ Bh_j
, e_ij
を収集した後、h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / torch.sum(sigma_ij, dim=1)
行により、(Eq. 5)で示されるように、各ノードの隠れ表現を計算します。ただし、これはとresidual connectionなしで 項だけを表しています。
def reduce_func(self, nodes):
Ah_i = nodes.data['Ah']
Bh_j = nodes.mailbox['Bh_j']
e = nodes.mailbox['e_ij']
# sigma_ij = sigmoid(e_ij)
sigma_ij = torch.sigmoid(e)
# hi = Ahi + sum_j eta_ij * Bhj
h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / torch.sum(sigma_ij, dim=1)
return {'h' : h}
関数 forward
の中で g.update_all
を呼び、グラフの畳み込みの結果 h
と e
を得ます。これは (Eq.5)より、を、(Eq.7)よりを表しています。そして、グラフのノードサイズとグラフのエッジサイズを基準にして h
と e
を正規化します。その後、ネットワークを効率的に学習できるように、バッチ正規化を行います。最後に、を適用し、residual connectionを加えて、ノードとエッジの隠れ表現を得ます。
def forward(self, g, h, e, snorm_n, snorm_e):
h_in = h # residual connection
e_in = e # residual connection
g.ndata['h'] = h
g.ndata['Ah'] = self.A(h)
g.ndata['Bh'] = self.B(h)
g.ndata['Dh'] = self.D(h)
g.ndata['Eh'] = self.E(h)
g.edata['e'] = e
g.edata['Ce'] = self.C(e)
g.update_all(self.message_func, self.reduce_func)
h = g.ndata['h'] # result of graph convolution
e = g.edata['e'] # result of graph convolution
h = h * snorm_n # normalize activation w.r.t. graph node size
e = e * snorm_e # normalize activation w.r.t. graph edge size
h = self.bn_node_h(h) # batch normalization
e = self.bn_node_e(e) # batch normalization
h = torch.relu(h) # non-linear activation
e = torch.relu(e) # non-linear activation
h = h_in + h # residual connection
e = e_in + e # residual connection
return h, e
次に、複数の全結合層(FCN)を含む MLP_Layer
モジュールを定義します。全結合層のリストを作成し、順伝播を定義します。
最後に、GatedGCN_layer
と MLP_layer
という先に定義したクラスからなる GatedGCN
モデルを定義します。GatedGCN
モデルの定義を以下に示します。
class GatedGCN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, L):
super().__init__()
self.embedding_h = nn.Linear(input_dim, hidden_dim)
self.embedding_e = nn.Linear(1, hidden_dim)
self.GatedGCN_layers = nn.ModuleList([
GatedGCN_layer(hidden_dim, hidden_dim) for _ in range(L)
])
self.MLP_layer = MLP_layer(hidden_dim, output_dim)
def forward(self, g, h, e, snorm_n, snorm_e):
# input embedding
h = self.embedding_h(h)
e = self.embedding_e(e)
# graph convnet layers
for GGCN_layer in self.GatedGCN_layers:
h, e = GGCN_layer(g, h, e, snorm_n, snorm_e)
# MLP classifier
g.ndata['h'] = h
y = dgl.mean_nodes(g,'h')
y = self.MLP_layer(y)
return y
私たちのコンストラクタでは、e
とh
の埋め込み(self.embedding_e
と self.embedding_h
)、self.GatedGCN_layers
を定義しています。これは以前定義したモデルGatedGCN_layer
からなるサイズのリストです。また、self.MLP_layer
も定義します。これも以前定義したものです。次に、これらの初期化を使って、順伝播して y
を出力します。
モデルをより理解しやすくするために、モデルのオブジェクトを初期化してしてprint
します。
# instantiate network
model = GatedGCN(input_dim=1, hidden_dim=100, output_dim=8, L=2)
print(model)
モデルの主な構造をいかに示します。
GatedGCN(
(embedding_h): Linear(in_features=1, out_features=100, bias=True)
(embedding_e): Linear(in_features=1, out_features=100, bias=True)
(GatedGCN_layers): ModuleList(
(0): GatedGCN_layer(...)
(1): GatedGCN_layer(... ))
(MLP_layer): MLP_layer(
(FC_layers): ModuleList(...))
驚くことではありませんが、GatedGCN_layer
の2つの層(L=2
なので)と、MLP_layer
の2つの層があり、最終的には8つの値が出力されます。
次に、train
と evaluate
関数を定義します。train
関数では、dataloader
からサンプルを取得する汎用コードを定義しています。 次に、batch_graphs
, batch_x
, batch_e
, batch_snorm_n
, batch_snorm_e
をモデルに入力し、batch_scores
(サイズ8) を返します。予測されたスコアは、損失関数 loss(batch_scores, batch_labels)
を用いて真の値と比較されます。次に、勾配をゼロにし (optimizer.zero_grad()
)、逆伝播を行い (J.backward()
)、重みを更新します (optimizer.step()
)。最後に、各エポックの損失と精度を計算します。さらに、同様のコードを evaluate
関数にも用います。
最後に、訓練の準備ができました!40エポックの学習を行った結果、我々のモデルはテスト精度%でグラフを分類することを学習したことがわかりました。
📝 Go Inoue, Muhammad Osama Khan, Muhammad Shujaat Mirza, Muhammad Muneeb Afzal
🇯🇵 Shiro Takagi
28 Apr 2020