[Pytorch 기초 - 2] MNIST data를 load와 시각화 작업 구현하기
2020. 9. 1. 18:49ㆍDL in Python/Pytorch 기초
PyTorch Data Preprocess¶
In [3]:
import torch
from torchvision import datasets, transforms
Data Loader 부르기¶
파이토치는 DataLoader를 불러 model에 넣음
In [4]:
batch_size = 32
test_batch_size = 32
In [10]:
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset/',train = True,download = True,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5,), std = (0.5,))
])),
batch_size = batch_size,
shuffle = True)
In [12]:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset',train=False,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0,5))
])),
batch_size = test_batch_size,
shuffle = True)
첫번재 iteration에서 나오는 데이터 확인¶
In [13]:
images, labels = next(iter(train_loader))
In [17]:
images.shape
# tensorflow에서는 (batch_size, height, width, channel) - (32, 28, 28, 1)
# pytorch에서는 (batch_size, channel, height, width) - (32, 1, 28, 28)
# rgb였으면 1이 아니라 3
# pytorch와 tensorflow의 차이점 !!
Out[17]:
In [18]:
labels.shape
Out[18]:
PyTorch는 TensorFlow와 다르게 [Batch Size, Channel, Height, Width] 임을 명시해야함
데이터 시각화¶
In [19]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [20]:
torch_image = torch.squeeze(images[0])
torch_image.shape
Out[20]:
In [22]:
image = torch_image.numpy()
image.shape
Out[22]:
In [23]:
label = labels[0].numpy()
In [24]:
label.shape
Out[24]:
In [25]:
label
Out[25]:
In [27]:
plt.title(label)
plt.imshow(image,'gray')
plt.show()
'DL in Python > Pytorch 기초' 카테고리의 다른 글
[Pytorch 기초 - 4] MNIST data를 활용하여 CNN모델의 학습과 Optimizer, Evaluation (0) | 2020.09.01 |
---|---|
[Pytorch 기초 - 3] MNIST data를 활용하여 Pytorch로 CNN모델 구현 기본 (0) | 2020.09.01 |
[Pytorch 기초 - 1] Pytorch의 가장 기본적인 함수들 (0) | 2020.09.01 |