GAN(Generative Adversarial Network)とは



機械学習

分類(classification)

ニューラルネットワーク(NN)

クラスタリング

強化学習

敵対的生成ネットワーク

公開日:2021/6/9         

前提知識
 ・ニューラルネットワークとは
 ・Pythonとは
 ・MNISTとは


■GAN(Generative Adversarial Network)とは
GANとは敵対的生成ネットワークといい、ニューラルネットワークを応用して、絵や文章などオリジナルの創作物を作ることができる機械学習のアルゴリズムです。 ディープフェイクと呼ばれているものはこのGANが利用されています。強化学習の枠組みに位置づけされる場合がありますが、強化学習とは異なり、学習器に対して人間が設計した報酬を与えるということはありません。

■考え方
以下図のとおりとなります。生成器は偽画像を生成し、識別器は偽画像と本物画像を比較して、偽画像がどれだけ本物画像に近いかを識別します。 そしてその結果を生成器にフィードバックすることで生成器は偽画像を生成する精度を向上させていきます。一方で識別器自身も識別結果をもとに、識別精度を向上させていきます。

この動作は、偽札づくりの犯人(生成器)と警察官(識別器)によく例えられます。犯人は警察官に見破られないように偽札づくりの精度を上げていき、警察官としても偽札を見破ることができる様に識別能力を高めていきます。



■具体例
MNISTの画像をもとにMNISTに似た画像を生成器に生成させます。生成器には32x68のノイズデータを入力し、そこから32枚の28x28の偽画像を生成します。 次に偽画像と本物の画像を入れ、それぞれの画像が本物である確率(識別率)を算出します。本物の画像は当然本物なのですから識別率は1と思うかもしれませんが、識別器自身も最初は精度は悪く、 本物の画像を学習していく必要があります。偽物画像に対する識別率は生成器にフィードバックし、生成器自身も学習を重ねていき、次第に精度の高い画像を生成することができます。



なお損失関数には、Binary Cross Entropy Loss(交差エントロピー誤差)を用います。

<pythonで実装>
実際にpythonで動かしてみます。ニューラルネットワークはpytorchで動かします。必要なファイル、環境は以下のとおり。

 ・python : 3.9.5 , pytorch:1.8.1
 ・画像データ:mnist_data.zip
 ・プログラムファイル:gan.zip

シミュレーション結果は以下のとおり。200エポック分の学習の推移を表していますが、うまく学習していっているのが分かります。











サブチャンネルあります。⇒ 何かのお役に立てればと

関連記事一覧



機械学習

分類(classification)

ニューラルネットワーク(NN)

クラスタリング

強化学習

敵対的生成ネットワーク