본문 바로가기
AI_basic/Pytorch

[Pytorch Part.5] Augmentation과 CNN

by hits_gold 2022. 1. 14.
반응형

DATA Augmentation이란?

 복잡한 모델을 만들기 위해서는 다량의 데이터가 필요하지만 우리가 갖고 있는 데이터는 한정적입니다. DATA Augmentation은 이를 보완하기 위해 데이터를 임의로 변형해 데이터의 수를 늘려 다양한 feature를 뽑는 방법입니다.

Data Augmentation 예시

 위 사진과 같이 여러 기법을 사용하는데 대표적인 기법은 다음과 같습니다.

  • Flip : 이미지를 랜덤하게 좌우 또는 상하 반전시킵니다.
  • Scaling : 이미지를 확대 또는 축소시킵니다.
  • Rotation : 이미지를 회전시킵니다.
  • Crop : 이미지의 일정 부분을 잘라 사용합니다.
  • Cutout : 이미지의 일부를 사각형 모양으로 검은색을 칠합니다.
  • Cutmix : 두 이미지를 합쳐 놓고 이미지의 Label을 학습시킬 때 각각의 이미지가 차지하는 비율만큼 학습시키는 방              법입니다.

 

DATA Augmentation in pytorch

이미지 데이터에 DATA Augmentation과 CNN모델을 적용시키기 위해 CIFAR10 데이터를 사용했습니다.

## 패키지 설치 및 gpu 확인
!pip3 install torch
!pip3 install torchvision

import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torchvision import transforms, datasets

print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print(torch.cuda.is_available())
device = torch.device("cuda")  # gpu device 정의

batch_size = 32
epochs = 10

trans = transforms.Compose([            # 전처리 및 Augmentation 적용 메소드
    transforms.RandomHorizontalFlip(),  # 50%확률로 랜덤하게 좌우 반전
    transforms.ToTensor(),              # 0과 1사이의 값으로 정규화하고 Tensor형태로 변환
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 
])             # red, green, blue 순으로 평균과 표준편차가 0.5가 되게 정규화

train_dataset = datasets.CIFAR10(root = "../data/CIFAR_10",
                              train = True,
                              download = True,
                              transform = trans)
test_dataset = datasets.CIFAR10(root = "../data/CIFAR_10",
                              train = False,
                              download = True,
                              transform = trans)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

 

간단한 CNN모델 구성

 

class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Conv2d(
        in_channels = 3,      # input의 채널 수 설정 -> r, g, b
        out_channels = 8,     # output의 채널 수, 즉 필터의 개수
        kernel_size = 3,      # 필터의 사이즈 -> 3*3
        padding = 1           # padding
    )
    self.conv2 = nn.Conv2d(
        in_channels = 8,      # conv1에서 output의 채널 수가 8 -> conv2 input 채널 수 8
        out_channels = 16,    
        kernel_size = 3,
        padding = 1
    )
    self.pool = nn.MaxPool2d(  # 풀링층 -> maxpooling
        kernel_size = 2,       
        stride = 2
    )
    self.fc1 = nn.Linear(8*8*16, 64)    # fully connected layer
    self.fc2 = nn.Linear(64, 32)
    self.fc3 = nn.Linear(32, 10)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(x)
    x = self.pool(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = self.pool(x)

    x = x.view(-1, 8*8*16)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.fc2(x)
    x = F.relu(x)
    x = self.fc3(x)
    x = F.log_softmax(x)
    return x

 

모델 학습과 평가 함수

def train(model, train_loader, optimizer, log_interval):
  model.train()  # 모델을 학습 모드로
  for batch_idx, (image, label) in enumerate(train_loader):
    image = image.to(device)
    label = label.to(device)
    optimizer.zero_grad() 
    output = model(image)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()

    if batch_idx % log_interval == 0:
      print("Train Epoch : {} [{}/{}{:.0f}%]\tTrain Loss: {:.6f}".format(epoch, batch_idx * len(image), len(train_loader.dataset), 100.*batch_idx/len(train_loader), loss.item()))


def evaluate(model, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for image, label in test_loader:
      image = image.to(device)
      label = label.to(device)
      output = model(image)
      test_loss += criterion(output, label).item()
      prediction = output.max(1, keepdim = True)[1]
      correct += prediction.eq(label.view_as(prediction)).sum().item()
  test_loss /= len(test_loader.dataset)
  test_accuracy = 100.*correct/len(test_loader.dataset)
  return test_loss, test_accuracy

 

모델 학습 및 평가

model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
criterion = nn.CrossEntropyLoss()

print(model)
for epoch in range(1, epochs+1):
  train(model, train_loader, optimizer, log_interval = 200)
  test_loss, test_accuracy = evaluate(model, test_loader)
  print("\n [EPOCH : {}], \tTest Loss: {:.4f}, \tTest Accuracy : {:.2f} % \n"
  .format(epoch, test_loss, test_accuracy))

 

결과는 Data Augmentation을 적용하기 전보다 Test Accuracy가 66.4% 소폭 상승하였습니다.

반응형