RNN , LSTM関数の使い方(pytorch)



Python
Pythonとは
基本的な使い方
IDLE
Jupyter Notebook
Markdown
コマンドラインで実行
ライブラリのインストール
pipの使い方
numpy , matplotlib等
graphviz
pytorch
Mecab
Pythonの関数:一覧
  共通関数
 ・append , extend
 ・class
 ・copy
 ・csv.reader
 ・csv.writer
 ・def , return
 ・dict , defaultdict
 ・enumerate
 ・exit
 ・for
 ・if
 ・import
 ・in
 ・input
 ・lambda
 ・len
 ・list
 ・min/max
 ・OrderedDict
 ・open/close
 ・os
 ・pickle
 ・print
 ・range
 ・re.split
 ・read/readline
 ・round/floor/ceil
 ・split
 ・sys.argv
 ・time
 ・while
 ・zip
 ・特殊メソッド
  ・__name__
  ・__iter__ , __next__
 ・正規表現、メタ文字
 ・データの型の種類
 ・四則演算 (+ , - , * , /)
 ・コメントアウト (# , ''')
  numpy
 ・append
 ・arange
 ・argmax/argmin
 ・array
 ・asfarray
 ・astype , dtype
 ・digitize
 ・dot
 ・hstack/vstack
 ・linspace
 ・mean
 ・meshgrid
 ・mgrid
 ・ndim
 ・ndmin
 ・pad
 ・prod
 ・random
 ・reshape
 ・savetxt/loadtxt
 ・shape
 ・std
 ・transpose
 ・where
 ・zeros/zeros_like
 scipy
 ・expit
 ・imread
 matplotlib
 ・imshow
 ・figure
 ・pcolormesh
 ・plot
 ・scatter
 scikit-learn
 ・GaussianNB
 ・KMeans
 ・KNeighborsClassifier
 ・SVC
 ・tree
 keras
 chainer
 chainerrl
 pytorch
 ・BCELoss , MSELoss
 ・Embedding
 ・device
 ・Sequential
 ・Dataset, Dataloader
 ・RNN, LSTM
 OpenAI gym
 ・Blackjack-v0
 ・CartPole-v0
 目的別
 ・ステップ関数
 ・1 of K 符号化法
 ・線形補間
 ・配列に番号をつける

公開日:2021/8/7         

In English


■説明

再帰型ニューラルネットワーク(RNN:Recurrent Neural Network)、LSTM(Long Short-Term Memory)の計算を行います。

■RNN関数の使い方

import torch
import torch.nn as nn

class CalRNN(nn.Module):
    def __init__(self , in_size , hidden , output , layer):
        super().__init__()
        self.rnn = nn.RNN(in_size, hidden, layer, batch_first=True, nonlinearity='relu') #デフォルトはtanh
        self.fc = nn.Linear(hidden, output)

    def forward(self, x):
        x_rnn, _ = self.rnn(x, None)
        y = self.fc(x_rnn[:, -1, :])
        return y

model_rnn = CalRNN(10, 64, 1, 2) #入力:10, 中間層:64, 出力:1, 縦方向の層2のNN構築

x = torch.randn(5,1,10) #バッチサイズ5、入力を10とする。
ans = model_rnn(x)
print(ans)

⇒tensor([[-0.0328],
               [ 0.0598],
               [-0.1709],
               [ 0.0603],
               [-0.1373]], grad_fn=<AddmmBackward>)


上記 forwardの部分は以下赤文字部分と等価です。下記記述で上記記述と結果が一致することが確認できます。

    def forward(self, x):
        x_rnn, _ = self.rnn(x, None) # 比較用の記述
        y = self.fc(x_rnn[:, -1, :])
# 比較用の記述

        h0 = torch.zeros(2, 5, 64) # layer, バッチサイズ, 中間層の順
        x2_rnn, h = self.rnn(x, h0)
        y2 = self.fc(x2_rnn[:, -1, :])


        return y , y2


■LSTM関数の使い方

上記に対して、RNNの部分をLSTMに変え、nonlinearityの部分を削除します。(tanhしか使えないため)

import torch
import torch.nn as nn

class CalRNN(nn.Module):
    def __init__(self , in_size , hidden , output , layer):
        super().__init__()
        self.rnn = nn.LSTM(in_size, hidden, layer, batch_first=True) #デフォルトはtanh
        self.fc = nn.Linear(hidden, output)

    def forward(self, x):
        x_rnn, _ = self.rnn(x, None)
        y = self.fc(x_rnn[:, -1, :])
        return y

model_rnn = CalRNN(10, 64, 1, 2) #入力:10, 中間層:64, 出力:1, 縦方向の層2のNN構築

x = torch.randn(5,1,10) #バッチサイズ5、入力を10とする。
ans = model_rnn(x)
print(ans)

⇒tensor([[0.0127],
               [0.0224],
               [0.0103],
               [0.0214],
               [0.0134]], grad_fn=<AddmmBackward>)










サブチャンネルあります。⇒ 何かのお役に立てればと

関連記事一覧



Python
Pythonとは
基本的な使い方
IDLE
Jupyter Notebook
Markdown
コマンドラインで実行
ライブラリのインストール
pipの使い方
numpy , matplotlib等
graphviz
pytorch
Mecab
Pythonの関数:一覧
  共通関数
 ・append , extend
 ・class
 ・copy
 ・csv.reader
 ・csv.writer
 ・def , return
 ・dict , defaultdict
 ・enumerate
 ・exit
 ・for
 ・if
 ・import
 ・in
 ・input
 ・lambda
 ・len
 ・list
 ・min/max
 ・OrderedDict
 ・open/close
 ・os
 ・pickle
 ・print
 ・range
 ・re.split
 ・read/readline
 ・round/floor/ceil
 ・split
 ・sys.argv
 ・time
 ・while
 ・zip
 ・特殊メソッド
  ・__name__
  ・__iter__ , __next__
 ・正規表現、メタ文字
 ・データの型の種類
 ・四則演算 (+ , - , * , /)
 ・コメントアウト (# , ''')
  numpy
 ・append
 ・arange
 ・argmax/argmin
 ・array
 ・asfarray
 ・astype , dtype
 ・digitize
 ・dot
 ・hstack/vstack
 ・linspace
 ・mean
 ・meshgrid
 ・mgrid
 ・ndim
 ・ndmin
 ・pad
 ・prod
 ・random
 ・reshape
 ・savetxt/loadtxt
 ・shape
 ・std
 ・transpose
 ・where
 ・zeros/zeros_like
 scipy
 ・expit
 ・imread
 matplotlib
 ・imshow
 ・figure
 ・pcolormesh
 ・plot
 ・scatter
 scikit-learn
 ・GaussianNB
 ・KMeans
 ・KNeighborsClassifier
 ・SVC
 ・tree
 keras
 chainer
 chainerrl
 pytorch
 ・BCELoss , MSELoss
 ・Embedding
 ・device
 ・Sequential
 ・Dataset, Dataloader
 ・RNN, LSTM
 OpenAI gym
 ・Blackjack-v0
 ・CartPole-v0
 目的別
 ・ステップ関数
 ・1 of K 符号化法
 ・線形補間
 ・配列に番号をつける