반응형
DATA Augmentation이란?
복잡한 모델을 만들기 위해서는 다량의 데이터가 필요하지만 우리가 갖고 있는 데이터는 한정적입니다. DATA Augmentation은 이를 보완하기 위해 데이터를 임의로 변형해 데이터의 수를 늘려 다양한 feature를 뽑는 방법입니다.
위 사진과 같이 여러 기법을 사용하는데 대표적인 기법은 다음과 같습니다.
- Flip : 이미지를 랜덤하게 좌우 또는 상하 반전시킵니다.
- Scaling : 이미지를 확대 또는 축소시킵니다.
- Rotation : 이미지를 회전시킵니다.
- Crop : 이미지의 일정 부분을 잘라 사용합니다.
- Cutout : 이미지의 일부를 사각형 모양으로 검은색을 칠합니다.
- Cutmix : 두 이미지를 합쳐 놓고 이미지의 Label을 학습시킬 때 각각의 이미지가 차지하는 비율만큼 학습시키는 방 법입니다.
DATA Augmentation in pytorch
이미지 데이터에 DATA Augmentation과 CNN모델을 적용시키기 위해 CIFAR10 데이터를 사용했습니다.
## 패키지 설치 및 gpu 확인
!pip3 install torch
!pip3 install torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torchvision import transforms, datasets
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
print(torch.cuda.is_available())
device = torch.device("cuda") # gpu device 정의
batch_size = 32
epochs = 10
trans = transforms.Compose([ # 전처리 및 Augmentation 적용 메소드
transforms.RandomHorizontalFlip(), # 50%확률로 랜덤하게 좌우 반전
transforms.ToTensor(), # 0과 1사이의 값으로 정규화하고 Tensor형태로 변환
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) # red, green, blue 순으로 평균과 표준편차가 0.5가 되게 정규화
train_dataset = datasets.CIFAR10(root = "../data/CIFAR_10",
train = True,
download = True,
transform = trans)
test_dataset = datasets.CIFAR10(root = "../data/CIFAR_10",
train = False,
download = True,
transform = trans)
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
batch_size = batch_size,
shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
batch_size = batch_size,
shuffle = True)
간단한 CNN모델 구성
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(
in_channels = 3, # input의 채널 수 설정 -> r, g, b
out_channels = 8, # output의 채널 수, 즉 필터의 개수
kernel_size = 3, # 필터의 사이즈 -> 3*3
padding = 1 # padding
)
self.conv2 = nn.Conv2d(
in_channels = 8, # conv1에서 output의 채널 수가 8 -> conv2 input 채널 수 8
out_channels = 16,
kernel_size = 3,
padding = 1
)
self.pool = nn.MaxPool2d( # 풀링층 -> maxpooling
kernel_size = 2,
stride = 2
)
self.fc1 = nn.Linear(8*8*16, 64) # fully connected layer
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = F.relu(x)
x = self.pool(x)
x = x.view(-1, 8*8*16)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
x = F.log_softmax(x)
return x
모델 학습과 평가 함수
def train(model, train_loader, optimizer, log_interval):
model.train() # 모델을 학습 모드로
for batch_idx, (image, label) in enumerate(train_loader):
image = image.to(device)
label = label.to(device)
optimizer.zero_grad()
output = model(image)
loss = criterion(output, label)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
print("Train Epoch : {} [{}/{}{:.0f}%]\tTrain Loss: {:.6f}".format(epoch, batch_idx * len(image), len(train_loader.dataset), 100.*batch_idx/len(train_loader), loss.item()))
def evaluate(model, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for image, label in test_loader:
image = image.to(device)
label = label.to(device)
output = model(image)
test_loss += criterion(output, label).item()
prediction = output.max(1, keepdim = True)[1]
correct += prediction.eq(label.view_as(prediction)).sum().item()
test_loss /= len(test_loader.dataset)
test_accuracy = 100.*correct/len(test_loader.dataset)
return test_loss, test_accuracy
모델 학습 및 평가
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
criterion = nn.CrossEntropyLoss()
print(model)
for epoch in range(1, epochs+1):
train(model, train_loader, optimizer, log_interval = 200)
test_loss, test_accuracy = evaluate(model, test_loader)
print("\n [EPOCH : {}], \tTest Loss: {:.4f}, \tTest Accuracy : {:.2f} % \n"
.format(epoch, test_loss, test_accuracy))
결과는 Data Augmentation을 적용하기 전보다 Test Accuracy가 66.4% 소폭 상승하였습니다.
반응형
'AI_basic > Pytorch' 카테고리의 다른 글
[Pytorch] 작물 잎 분류 Pre-trained model(resnet50) (1) | 2022.02.07 |
---|---|
[Pytorch] 작물 잎 분류 non Pre_trained model (0) | 2022.02.06 |
[Pytorch Part.4] AutoEncoder (0) | 2022.01.14 |
[Pytorch Part.2] AI Background (0) | 2022.01.06 |
[Pytorch Part.1] Basic Skill (0) | 2022.01.05 |