AI.doll

このブログは僕のためのメモです。

PyTorchでGraph Neural Network

PyTorch Geometricの紹介

概要

M.DefferrardさんやT.Kipfさんによるグラフ信号処理を基にした、GCNなどのグラフを入力できるニューラルネットワークが数年前から注目されている。Chainerを使っている人であれば、Chainer Chemistryというフレームワークでこれらのモデルを構築できる。僕は基本的にPyTorchを使っているため、PyTorch Geometric という PyTorchの拡張ライブラリを使いはじめたので、これを用いてGNNモデルや食わせるデータの作り方を簡単にメモしておく。(基本的に公式のドキュメントの自分が使った部分だけを抜き出した感じ)

インストール方法

インストール方法は公式のInstallationを見れば分かるが、僕がインストールしたときのコマンドを一応ここにも書いておく。
PyTorchのバージョンは1.0.0以上が必要。
あとは以下のコマンドでインストールすれば使える。

pip install --verbose --no-cache-dir torch-scatter
pip install --verbose --no-cache-dir torch-sparse
pip install --verbose --no-cache-dir torch-cluster
pip install torch-geometric

使い方

Data

Dataクラスのオブジェクトを作るときの引数のうちよく使おうであろうものについて説明する。
x: グラフ上の信号のtensor (各ノードに対して定義された特徴量ベクトル)
edge_index : グラフ構造を示すtensorで[[始点のリスト], [終点のリスト]]になっている。( [[始点, 終点], [始点, 終点], ...]ではないので注意 )
edge_attr : エッジ上の特徴量 (辺の重みなど)、[[エッジ数], [特徴量ベクトルの次元]]の形
y : 教師信号などがある場合yとして持っておく

Dataset

まず, 抽象クラスとしてDatasetクラスと、InMemoryDatasetクラスがある。後者はデータが全部メモリに乗り切る場合に使われるクラスで、Datasetを継承している。
ここではDatasetクラスを用いたデータセットの作り方のみ説明する。
大まかなクラスの雛形は以下の通り。

class DatasetClass(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(DatasetClass, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ 元ファイル名のリストを作る """

    @property
    def processed_file_names(self):
        """加工済みファイル名のリストを作る"""

    def __len__(self):
        """データの数"""
        return len(self.processed_file_names)

    def _download(self):
        pass

    def process(self):
        """データの加工を行い保存する."""

    def get(self, idx):
        """idx番目のデータを取り出して返す"""

親クラスのinitを呼び出すことで定義される変数と、propertyとして呼び出せるものは以下の通り。 (データを保存しているディレクトリの最上位を'data/'とする)

変数名 内容
self.root initに渡したrootでデータディレクトリの最上位 ('data')
self.raw_dir rootの下のrawディレクトリへ ('data/raw')
self.processed_dir rootの下のprocessedディレクトリ('data/processed')
self.transform initに渡したtransform、データをgetするときに通す関数
self.pre_transform initに渡したpre_transform、データを保存する前に通す関数
self.raw_paths raw_file_namesの各要素の前にself.raw_dirをくっつけたもの(生データへのパス)
self.processed_paths processed_file_namesの各要素の前にself.processed_dirをくっつけたもの(加工して保存したデータへのパス)

_downloadの代わりにdownloadを定義して、self.raw_dirに保存する手順を書くこともできる. データが手許にある場合は上記のように_downloadをoverrideすれば良い。 ここで、保存するデータは前節で説明したDataオブジェクトを想定している。

定義済みのデータセット

ここに書いてあるものはimportすればすぐに使える。

実装されている手法

ChebConv(in_channels, out_channels, K, bias=True)
チェビシェフ多項式でグラフラプラシアンの累乗を効率よく計算するモデル。 in_channel: 入力するグラフのノード上の特徴量ベクトルの次元
out_channel: 出力するグラフのノード上の特徴量ベクトルの次元
K: グラフラプラシアンの何乗まで使うか。各ノードに対してKホップ以内のノードについて畳み込まれる。

GCNConv(in_channels, out_channels, improved=False, cached=False, bias=True)
ChebConvのK=1にしたほか、複数の近似等を含むもの。詳しくはKipfさんの論文を読むと分かる。

気が向いたらsurvey読みながら詳細に書くけどドキュメント読んだほうが正確で早いよ。