[Pytorch 기초 - 3] MNIST data를 활용하여 Pytorch로 CNN모델 구현 기본
2020. 9. 1. 20:42ㆍDL in Python/Pytorch 기초
PyTorch Layer 이해하기¶
예제 불러오기¶
In [6]:
import torch
from torchvision import datasets, transforms
In [7]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [8]:
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor()
])),
batch_size=1)
In [9]:
image, label = next(iter(train_loader))
In [11]:
image.shape, label.shape
Out[11]:
In [12]:
plt.imshow(image[0,0, :, :], 'gray')
plt.show()
각 Layer별 설명¶
- Network 쌓기 위한 준비
In [32]:
import torch
import torch.nn as nn # weight값이 있는 것들
import torch.nn.functional as F # weight값이 없는 것들, pooling, activation func
Convolution¶
- in_channels: 받게 될 channel의 갯수
- out_channels: 보내고 싶은 channel의 갯수
- kernel_size: 만들고 싶은 kernel(weights)의 사이즈
In [14]:
nn.Conv2d(in_channels = 1, out_channels = 20, kernel_size = 5, stride = 1)
Out[14]:
In [15]:
# 위와 동일 Convolution layer
layer = nn.Conv2d(1, 20, 5, 1).to(torch.device('cpu'))
layer
Out[15]:
- weight 확인 & 시각화
In [16]:
weight = layer.weight
weight.shape
Out[16]:
- 여기서 weight는 학습 가능한 상태이기 때문에 바로 numpy로 뽑아낼 수 없음
- detach() method는 그래프에서 잠깐 빼서 gradient에 영향을 받지 않게 함
In [18]:
"weight = weight.numpy() - detach 이전 error"
weight = weight.detach().numpy()
In [19]:
weight.shape
Out[19]:
In [20]:
plt.imshow(weight[0,0, :,:], 'jet')
plt.colorbar()
plt.show()
- output 시각화 준비를 위해 numpy화
In [21]:
output_data = layer(image)
In [22]:
output_data = output_data
In [26]:
output = output_data.cpu().detach().numpy()
In [27]:
output.shape
Out[27]:
- Input으로 들어간 이미지 numpy화
In [28]:
image_arr = image.numpy()
image_arr.shape
Out[28]:
In [31]:
plt.figure(figsize = (15, 30))
plt.subplot(131)
plt.title('input')
plt.imshow(np.squeeze(image_arr), 'gray')
plt.subplot(132)
plt.title('Weight')
plt.imshow(weight[0,0, :, :], 'jet')
plt.subplot(133)
plt.title('Output')
plt.imshow(output[0,0, :, :], 'gray')
plt.show()
# convolution layer의 가중치를 input이 조정받은 것이 output
Pooling¶
input을 먼저 앞에 넣고, 뒤에 kernel 사이즈와 stride를 순서대로 넣음
In [33]:
image.shape
Out[33]:
In [34]:
pool = F.max_pool2d(image, 2, 2)
pool.shape
Out[34]:
- MaxPool Layer는 weight가 없기 때문에 바로 numpy()가 가능
In [36]:
pool_arr = pool.numpy()
pool_arr.shape
Out[36]:
In [37]:
image_arr.shape
Out[37]:
In [38]:
plt.figure(figsize = (10,15))
plt.subplot(121)
plt.title('Input')
plt.imshow(np.squeeze(image_arr), 'gray')
plt.subplot(122)
plt.title('Output')
plt.imshow(np.squeeze(pool_arr), 'gray')
plt.show()
Linear¶
nn.Linear는 2d가 아닌 1d만 들어가기 때문에 .view() 1D로 펼쳐줘야함
In [39]:
flatten = image.view(1, 28 * 28) # (batch_size, flatten_size)
flatten.shape
Out[39]:
In [40]:
lin = nn.Linear(784, 10)(flatten)
lin.shape
Out[40]:
In [41]:
lin
Out[41]:
In [42]:
plt.imshow(lin.detach().numpy(), 'jet')
plt.show()
Softmax¶
결과를 numpy로 꺼내기 위해선 weight가 담긴 Linear에 weight를 꺼줘야함
In [43]:
with torch.no_grad():
flatten = image.view(1, 28 * 28)
lin = nn.Linear(784,10)(flatten)
softmax = F.softmax(lin, dim=1)
In [44]:
softmax
Out[44]:
In [45]:
np.sum(softmax.numpy())
Out[45]:
Layer 쌓기¶
nn 과 nn.functional의 차이점
- nn은 학습 파라미터가 담긴 것
- nn.functional은 학습 파라미터가 없는 것이라 생각하면 간단
In [57]:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1,20,5,1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
# Feature Extraction
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
print(x.shape) ### flatten 할 때, shape을 알아보기 위함
# Fully Connected (Classification)
x = x.view(-1, 4*4*50) # (batch_size, flatten_size)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
- Image를 Model에 넣어서 결과 확인
In [58]:
model = Net()
In [59]:
result = model.forward(image)
In [60]:
result
Out[60]:
In [65]:
image.shape
Out[65]:
In [63]:
model.conv1(image).shape
Out[63]:
'DL in Python > Pytorch 기초' 카테고리의 다른 글
[Pytorch 기초 - 4] MNIST data를 활용하여 CNN모델의 학습과 Optimizer, Evaluation (0) | 2020.09.01 |
---|---|
[Pytorch 기초 - 2] MNIST data를 load와 시각화 작업 구현하기 (0) | 2020.09.01 |
[Pytorch 기초 - 1] Pytorch의 가장 기본적인 함수들 (0) | 2020.09.01 |