昨日は中国で受講している「PyTorchを使ったディープラーニング」(深度学习之PyTorch实践篇)セミナーの最終回でした。テーマは敵対生成ネットワーク(GAN)。AIの中でも今一番注目されている技術で、画像生成や文章作成などで使われています。
敵対生成ネットワークで出来ること
敵対生成ネットワークで出来ることは、回帰、分類、クラスタリングなどの今までの機械学習で行ってきた内容のほか
- 画像→文章:(例)画像から内容を説明する文章を作成
- 画像→画像:(例)白黒の画像に色を着色する
- 文章→文章:(例)単語から、その単語が答えになる問題文を作成
など複雑な処理を実現することができます。
「過去10年の機械学習の世界で1番面白いアイディア」と言われています。下記動画はGANによる画像ですが、AIで生成された人間とは全く分かりません。
Progressive Growing of GANs for Improved Quality, Stability, and Variation
敵対生成ネットワークの仕組み
敵対生成ネットワークはgeneratorとdiscriminatorという2つのネットワークから構成されています。
基本思想は次のようなイメージです。
- 偽造者(generator)が警察をだまそうと本物に似せた偽札を作る
- 警察(discriminator)はお札が本物か偽物かを見分けようとする
上記を複数回繰り返すことにより、偽造者(generator)と警察(discriminator)ともに能力があがり、最終的には本物そっくりの偽札ができる。
PyTorchでの実装
プログラム実習ではGANで正規分布(ガウス分布)の乱数を作成します。
np.random.normalで作成するsize=500の正規分布データに近づくようgenerator、discriminatorの訓練を行います。
エポック数=1000
エポック数=8000
回数を重ねることで正規分布に近い値が出力されるようになりました。