TensorDataset , Dataloaderの使い方(pytorch)



Python
Pythonとは
基本的な使い方
IDLE
Jupyter Notebook
Markdown
コマンドプロンプトで実行
仮想環境の構築
仮想環境でIDLEを実行
ライブラリのインストール
pipの使い方
numpy , matplotlib等
graphviz
pytorch
Mecab
Pythonの関数:一覧
共通関数
append , extend
class
copy
csv.reader
csv.writer
def , return
dict , defaultdict
enumerate
exit
for
if
import
in
input
lambda
len
list
min/max
OrderedDict
open/close
os
pickle
print
range
re.split
read/readline
round/floor/ceil
split
sys.argv
time
while
write
zip
・特殊メソッド
 ・__name__
 ・__iter__ , __next__
正規表現、メタ文字
データの型の種類
四則演算 (+ , - , * , /)
コメントアウト (# , ''')
numpy
append
arange
argmax/argmin
array
asfarray
astype , dtype
digitize
dot
hstack/vstack
linalg.solve
linspace
mean
meshgrid
mgrid
ndim
ndmin
pad
poly1d
polyfit
prod
random
reshape
savetxt/loadtxt
shape
std
transpose
where
zeros/zeros_like
scipy
expit
imread
interpolate
matplotlib
imshow
figure
pcolormesh
plot
scatter
scikit-learn
GaussianNB
KMeans
KNeighborsClassifier
SVC
tree
keras
chainer
chainerrl
pytorch
BCELoss , MSELoss
Embedding
device
Sequential
Dataset, Dataloader
RNN, LSTM
OpenAI gym
Blackjack-v0
CartPole-v0
tkinter
frame, grid
画像表示
画像を切り取り表示
画像を保存
目的別
ステップ関数
1 of K 符号化法
線形補間
配列に番号をつける

公開日:2021/8/4         

In English


■説明

TensorDatasetと、Dataloaderは、主に機械学習用に入力データと正解(教師)データを扱う場合に使用する関数です。

■TensorDatasetの説明

入力データと正解データをセットにします。

import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

input = np.random.rand(4, 2) # 入力データ
correct = np.random.rand(4, 1) # 正解データ

input = torch.FloatTensor(input) # pytorchで扱える配列に変更
correct = torch.FloatTensor(correct) # pytorchで扱える配列に変更

print(input)

⇒ tensor([[0.7752, 0.9332],
              [0.5186, 0.1956],
              [0.1267, 0.1171],
              [0.3495, 0.5235]])


print(correct)

⇒ tensor([[0.2506],
              [0.9407],
              [0.9416],
              [0.8879]])


dataset = TensorDataset(input, correct) # データをセットにする
print(vars(dataset)) # varsはオブジェクトの中身を出力する

⇒{'tensors': (tensor([[0.7752, 0.9332], # 入力データ
                            [0.5186, 0.1956],
                            [0.1267, 0.1171],
                            [0.3495, 0.5235]]),
                    tensor([[0.2506], # 正解データ
                            [0.9407],
                            [0.9416],
                            [0.8879]]))}


■Dataloader

上記でセットにした入力データと正解データを読み出します。バッチサイズを入力することで、1回で読み出す数を指定することができます。

train_load = DataLoader(dataset, batch_size=2, shuffle=False) # shuffle=Trueでデータシャッフル

for x, t in train_load:
    print(x)
    print(t)

⇒ tensor([[0.7752, 0.9332], # xの1回目の読み出し
              [0.5186, 0.1956]])
 tensor([[0.2506], # tの1回目の読み出し
              [0.9407]])

   tensor([[0.1267, 0.1171], # 2回目の読み出し
              [0.3495, 0.5235]])
 tensor([[0.9416],
              [0.8879]])










サブチャンネルあります。⇒ 何かのお役に立てればと

関連記事一覧



Python
Pythonとは
基本的な使い方
IDLE
Jupyter Notebook
Markdown
コマンドプロンプトで実行
仮想環境の構築
仮想環境でIDLEを実行
ライブラリのインストール
pipの使い方
numpy , matplotlib等
graphviz
pytorch
Mecab
Pythonの関数:一覧
共通関数
append , extend
class
copy
csv.reader
csv.writer
def , return
dict , defaultdict
enumerate
exit
for
if
import
in
input
lambda
len
list
min/max
OrderedDict
open/close
os
pickle
print
range
re.split
read/readline
round/floor/ceil
split
sys.argv
time
while
write
zip
・特殊メソッド
 ・__name__
 ・__iter__ , __next__
正規表現、メタ文字
データの型の種類
四則演算 (+ , - , * , /)
コメントアウト (# , ''')
numpy
append
arange
argmax/argmin
array
asfarray
astype , dtype
digitize
dot
hstack/vstack
linalg.solve
linspace
mean
meshgrid
mgrid
ndim
ndmin
pad
poly1d
polyfit
prod
random
reshape
savetxt/loadtxt
shape
std
transpose
where
zeros/zeros_like
scipy
expit
imread
interpolate
matplotlib
imshow
figure
pcolormesh
plot
scatter
scikit-learn
GaussianNB
KMeans
KNeighborsClassifier
SVC
tree
keras
chainer
chainerrl
pytorch
BCELoss , MSELoss
Embedding
device
Sequential
Dataset, Dataloader
RNN, LSTM
OpenAI gym
Blackjack-v0
CartPole-v0
tkinter
frame, grid
画像表示
画像を切り取り表示
画像を保存
目的別
ステップ関数
1 of K 符号化法
線形補間
配列に番号をつける