본문 바로가기
DL/Code

MDGAN : Mask guided Generation Method for Industrial Defect Images with Non-uniform Structures 코드 구현 및 리뷰

by hits_gold 2024. 1. 24.
반응형

논문 : https://www.mdpi.com/2075-1702/10/12/1239

구현 코드 : https://github.com/hits-gold/MDGAN-pytorch/tree/main?tab=readme-ov-file

 

구현하게 된 배경

인턴으로 근무하던 당시 팀에서 반도체 불량 검출 과제를 하고 있었고, 이 후에도 같은 task를 다수 맡을 예정이었다.이에 '제조업계에서 사용되는 불량 검출 과제'에 선제적 대응을 할 요인을 찾아 대비하는 차원에서 1인 프로젝트를 시작하게 되었다.

 선행연구조사를 통해 결함 이미지에 대한 classification이나 segmentation과 같은 task를 진행할 때 일반적으로 결함 이미지가 적어 데이터 불균형이 있다는 것을 알게 되었다. (불량률을 줄이기 위해 검출을 필요로 하지만 정상적인 제조업 공장에서는 모델이 학습하기에 충분한 결함case가 적다.)

 이를 보완하기 위해 기존 데이터셋의 결함 이미지들을 활용해 새로운 결함 이미지를 만들어내는 생성 모델들에 대한 연구가 많이 진행되고 있었고, 처음에는 모델이 대부분 코드가 공개되어 있을 줄 알고 모델선정까지 진행하려했었다. 하지만 선정한 모델이 코드가 공개되어있지 않았고, 꼭 필요한 모델이란 생각이 들어 구현까지 하게 되었다!

 

모델 선정 배경

  • MDGAN은 기본적으로 binary mask를 통해 특정된 결함 영역에 대한 학습을 진행하고 이를 양품 이미지에 생성한다.
  • 따라서 task마다 다른 유형의 image data의 결함 형태에 맞추어 학습과 생성을 할 수 있는 모델이라 생각했다.
  • binary mask 생성을 위해서는 segmentation labeling이 필요했는데, CVAT가 당시 컴에 세팅되어 있었기 때문에 모델을 선정할 수 있었다.

Implementation 

 

1. Input

MDGAN의 기본적인 메커니즘은 다음과 같다.

  • 학습 : (결함 영역이 양품처럼 메꿔진 결함 이미지) -> (결함 영역에 원래의 결함 그대로 있는 원본 결함 이미지)
  • 생성 : (양품 이미지 + 결함을 생성할 영역을 표시할 binary mask) -> (양품 이미지의 binary mask로 지정된 영역에 결함 생성)

Pseudo-Normal Background(PNB)

 여기서 학습 시 Input은 다음과 같은 과정으로 Pseudo-Normal Background(PNB)를 만들어 사용한다.

  1. 양품 이미지에 Affine변환 적용
  2. (결함 이미지의 배경영역) + (아핀변환된 양품이미지의 결함 이미지에서의 결함영역)

2번 과정의 결과물이 PNB인데, 결국 결함 이미지의 결함 영역을 양품 이미지의 해당 영역으로 대체하여 마치 양품이미지 처럼 보이게 하는 것이다.

  위 사진을 보면 기존 결함 이미지의 결함 영역이 양품 이미지의 해당 영역으로 메꿔진 PNB를 확인할 수 있다. 따라서 input과 target의 paired-image를 얻기 위함이라고 볼 수 있는데, 논문의 저자는 이를 통해 CycleGAN의 의존성을 회피할 수 있다고 말한다.

## dataset/traindataset.py
import os
import glob
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
import cv2


class TrainDataset(Dataset):
    def __init__(self, args):
        self.args = args
        img_paths = os.path.join(self.args.root, f"images/{self.args.defect_type}")
        mask_paths = os.path.join(self.args.root, f"ground_truth/{self.args.defect_type}")

        self.normal = cv2.cvtColor(cv2.imread(self.args.normal_path), cv2.COLOR_BGR2RGB)  # normal image
        self.img_paths = sorted(glob.glob(img_paths + "/*.*"))  # defect images path
        self.mask_paths = sorted(glob.glob(mask_paths + "/*.*"))  # masks path corresponding defect images

    # ---------- Pseudo-Normal background ----------
    # image transforms (defect image transformation)
    def trans(self):
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(256, transforms.InterpolationMode("nearest")),
                # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]
        )
        return transform

    # image transforms (Normal image affine transformation)
    def normal_trans(self, center: tuple[int, int]):
        if self.args.affine_arg == 0:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(256, transforms.InterpolationMode("nearest")),
                ]
            )
        else:
            degrees = self.args.affine_arg[0]
            translate = self.args.affine_arg[1]
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.RandomAffine(
                        degrees=degrees, translate=translate, center=center
                    ),
                    transforms.Resize(256),
                ]
            )

        return transform

    # defect area center
    def center(self, mask: np.ndarray)-> tuple[int, int]:
        _, thresh = cv2.threshold(mask, 100, 255, 0)
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cnt = contours[0]
        box = cv2.boxPoints(cv2.minAreaRect(cnt))
        center = tuple(np.int0(np.mean(box, axis=0)))
        return center

    # Pseudo-Normal background(PNB) construction
    def pnb(self, img: torch.tensor, mask: torch.tensor, normal: torch.tensor)-> torch.tensor:
        reverse_mask = torch.sub(1, mask)
        pnb = torch.add(torch.mul(img, reverse_mask), torch.mul(normal, mask))
        return pnb

    # --------------------------------------------------

    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.img_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[idx]), cv2.COLOR_BGR2GRAY)
        file_name = self.img_paths[idx].split("/")[-1]

        affine_center = self.center(mask)
        normal_trans = self.normal_trans(center=affine_center)
        trans = self.trans()

        img = trans(img)
        mask = trans(mask)
        normal = normal_trans(self.normal)
        pnb_img = self.pnb(img, mask, normal)

        sample = {
            "image": img,
            "pnb_image": pnb_img,
            "mask": mask,
            "file_name": file_name,
        }

        return sample

    def __len__(self):
        return len(self.img_paths)

 

PNB를 만들어 입력하는 과정은 학습이 필요한 과정이 아니어서 CustomDataset에 Train 시 자동으로 수행될 수 있도록 만들었다.

  1. center 함수를 통해 binary mask(결함 영역) 기준 center 좌표를 찾는다.
  2. 찾은 center 좌표를 기준으로 normal_trans 함수로 양품 이미지를 affine변환시킨다.
  3. pnb 함수를 통해 배경 영역은 결함 이미지 그대로, 결함 영역은 affine변환된 양품 이미지가 위치하도록 한다.

2. Generator

 BRM

 

 BRM은 위 이미지와 같은 구조로, 모델이 input으로 받은 mask와 PNB를 활용해 feature map의 결함 영역의 Local feature를 중점적으로 가져가면서 결함과 상관없는 배경 영역을 합성시키는 역할을 한다.

 

 Generator

 

Generator의 핵심 구조는 U-net과 같은 형태로, 각 conv block의 뒤에 BRM이 위치하는 구조이다. 또한, Input으로 입력되는 binary mask와 PNB가 결합하기 전에 binary mask에 가우시안 분포에서 추출한 랜덤한 값이 학습 가능한 fc layer를 거쳐 일종의 scaling과정을 거친다.

## model/utils.py
import torch
import torch.nn as nn


# Generator Downsampling block
class Gdown(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True):
        super(Gdown, self).__init__()
        if normalize:
            layers = [
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                nn.AvgPool2d(2, 2),
                nn.InstanceNorm2d(in_channels),
                nn.LeakyReLU(inplace=True),
            ]
        else:
            layers = [
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                nn.AvgPool2d(2, 2),
                nn.LeakyReLU(inplace=True),
            ]

        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.tensor)-> torch.tensor:

        return self.block(x)


# Generator Upsampling block
class Gup(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True):
        super(Gup, self).__init__()
        if normalize:
            layers = [
                nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(inplace=True),
            ]
        else:
            layers = [
                nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
                nn.Conv2d(out_channels, out_channels, 3, 1, 1),
                nn.Tanh(),
            ]

        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.tensor, f: torch.tensor = 0, skip: bool = True)-> torch.tensor:
        if skip:
            x = x + f
        x = self.block(x)

        return x

# Background Replacement module
class BRM(nn.Module):
    def __init__(self, d: int, out_channels: int):
        super(BRM, self).__init__()
        self.k = 2**d
        self.mask_layer = nn.Sequential(
            nn.AvgPool2d(self.k, stride=self.k),
            nn.Conv2d(1, 1, 3, 1, 1),
            nn.Sigmoid(),
        )
        self.background_layer = nn.Sequential(
            nn.AvgPool2d(self.k, stride=self.k),
            nn.Conv2d(3, out_channels, 1, stride=1),
        )

    def forward(self, x: torch.tensor, mask: torch.tensor, pnb_img: torch.tensor)-> torch.tensor:
        f = self.mask_layer(mask)
        background = self.background_layer(pnb_img)

        f_ = (f * -1) + 1
        x = (background * f_) + (x * f)

        return x

# weight initailizing
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

 

 utils.py 파일을 따로 만들어 각 block을 만들고 해당 block들을 불러와 모델을 완성시켰다.

 

## model/generator.py
import torch
import torch.nn as nn
from .utils import Gup, Gdown, BRM


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Fully-connected layer-> z_randomvector(size=(batchsize, 8)) to image size
        self.noise_layer = nn.Linear(8, 3 * 256**2)

        # G_down(in_channels, out_channels, normalize=True)
        self.down1 = Gdown(3, 64, normalize=False)
        self.down2 = Gdown(64, 128)
        self.down3 = Gdown(128, 256)
        self.down4 = Gdown(256, 256)
        self.down5 = Gdown(256, 512)
        self.down6 = Gdown(512, 512, normalize=False)

        # G_up(in_channels, out_channels, normalize=True)
        self.up1 = Gup(512, 512)
        self.up2 = Gup(512, 256)
        self.up3 = Gup(256, 256)
        self.up4 = Gup(256, 128)
        self.up5 = Gup(128, 64)
        self.up6 = Gup(64, 3, normalize=False)

        # BRM(layer_number, out_channels)
        self.brm_down1 = BRM(1, 64)
        self.brm_down2 = BRM(2, 128)
        self.brm_down3 = BRM(3, 256)
        self.brm_down4 = BRM(4, 256)
        self.brm_down5 = BRM(5, 512)
        self.brm_up1 = BRM(4, 256)
        self.brm_up2 = BRM(3, 256)
        self.brm_up3 = BRM(2, 128)
        self.brm_up4 = BRM(1, 64)
        self.brm_up5 = BRM(0, 3)

    def forward(self, mask: torch.tensor, z: torch.tensor, pnb_img: torch.tensor)-> torch.tensor:
        # add noise on defect area of input image
        x_z = self.noise_layer(z)
        x_z = x_z.view(pnb_img.size(0), 3, 256, 256)
        noise_mask = mask.data
        noise_mask[noise_mask == -1] = 0
        x = (x_z * noise_mask) + pnb_img

        # downsampling
        d1 = self.brm_down1(self.down1(x), mask, pnb_img)
        d2 = self.brm_down2(self.down2(d1), mask, pnb_img)
        d3 = self.brm_down3(self.down3(d2), mask, pnb_img)
        d4 = self.brm_down4(self.down4(d3), mask, pnb_img)
        d5 = self.brm_down5(self.down5(d4), mask, pnb_img)
        d6 = self.down6(d5)

        # upsampling
        u1 = self.up1(d6, skip=False)
        u2 = self.brm_up1(self.up2(u1, d5), mask, pnb_img)
        u3 = self.brm_up2(self.up3(u2, d4), mask, pnb_img)
        u4 = self.brm_up3(self.up4(u3, d3), mask, pnb_img)
        u5 = self.brm_up4(self.up5(u4, d2), mask, pnb_img)
        u6 = self.up6(u5, d1)  # last input of BRM
        gen = self.brm_up5(u6, mask, pnb_img)

        return gen, u6

 

  downsampling에 필요한 block과 upsampling에 필요한 block을 하나씩 만들어 scale 별 block을 만들 때 반복문으로 간단하게 처리하려했으나, 사이에 BRM 모듈과 함께 input binary mask와 PNB가 같이 들어가는 구조라 block을 하나하나 다 지정해줬다.(쫌 아쉽;;)

  noise_layer는 input binary mask에 가우시안 분포에서 추출한 랜덤 값 z가 binary mask에 곱해지기 전에 통과하는 fc layer이다.

  여기서 generator의 최종 output인 gen 이외에 마지막 BRM의 input인 u6는 loss function에 들어가기 위해 같이 return된다.

 

3. Discriminator

 Discriminator의 기본 구조는 PatchGAN을 따르며 각 conv block 사이에 DDM이라는 모듈이 적용되어 있다.

 

 DDM

 

 DDM(double discrimination module)은 discriminator의 feature map을 background와 결함 영역으로 나눈다. 이후 결함 영역에 대한 feature를 추출한 이후 background와 다시 결합시킨다. 이 과정을 DDM없이는 두 개의 discriminator를 사용해야한다.

 이를 통해 local feature와 global feature를 동시에 판별할 수 있다.

 

## model/utils.py
import torch
import torch.nn as nn

# Discriminator Downsampling block
class Ddown(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, normalize: bool = True):
        super(Ddown, self).__init__()
        if normalize:
            layers = [
                nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(),
            ]
        else:
            layers = [
                nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                nn.ReLU(),
            ]

        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.tensor)-> torch.tensor:

        return self.block(x)

# Double Discrimination module
class DDM(nn.Module):
    def __init__(self, d: int):
        super(DDM, self).__init__()
        self.k = 2**d
        self.mask_layer = nn.Sequential(
            nn.AvgPool2d(self.k, self.k),
            nn.Conv2d(1, 1, 3, 1, 1),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x: torch.tensor, mask: torch.tensor)-> torch.tensor:
        mask = self.mask_layer(mask)
        local = x * mask
        x = torch.cat((x, local), dim=1)

        return x


# weight initailizing
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
## model/discriminator.py
import torch
import torch.nn as nn
from .utils import Ddown, DDM


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # D_down(in_channels, out_channels, normalize=True)
        self.down1 = Ddown(4, 32, normalize=False)
        self.down2 = Ddown(64, 128)
        self.down3 = Ddown(256, 256)
        self.down4 = Ddown(512, 256)
        self.down5 = Ddown(512, 512)

        # DDM(layer_number)
        self.ddm1 = DDM(1)
        self.ddm2 = DDM(2)
        self.ddm3 = DDM(3)
        self.ddm4 = DDM(4)

        # 1-channel feature map (PatchGAN output)
        self.patch = nn.Conv2d(512, 1, 3, 1, 1, bias=False)

    def forward(self, x: torch.tensor, mask: torch.tensor)-> torch.tensor:
        # PatchGAN discriminator input
        x = torch.cat((x, mask), dim=1)

        d1 = self.ddm1(self.down1(x), mask)
        d2 = self.ddm2(self.down2(d1), mask)
        d3 = self.ddm3(self.down3(d2), mask)
        d4 = self.ddm4(self.down4(d3), mask)
        d5 = self.down5(d4)
        out = self.patch(d5)

        return out

 

 4. Train

Loss function

M = 마스크, D = target, D' = 생성이미지1, D^r = 생성이미지2 d = 마지막 BRM의 input

1. reconstruction loss / Normal background loss

 

 

 기본적으로 MDGAN의 Loss는 Generator로 두 개의 이미지를 생성해 계산한다. reconstruction loss와 adversarial loss는 기본적인 GAN 모델에서 사용되는 loss이다(adversarial loss 설명은 생략한다). 여기서 reconstruction loss는 결함영역과 배경영역(Normal background loss,r2와 r3에 해당 배경영역은 사실 경계면의 자연스러운 합성을 제외하고는 그대로 나와도 될 것 같은데, 모델 구조상 어느 정도의 생성이 들어가 Loss가 들어감)의 Loss를 따로 계산한다.

 Reconstruction Loss의 Normal background loss를 계산할 때, 마지막 BRM의 의존성을 피하고 결함 영역과 final output의 normal background와의 일치성을 위해 마지막 BRM의 input인 dl을 사용한 loss를 추가한다.

 

2. Gradient Loss

 gradient loss는 replacement 이 후 결함 영역과 background의 경계선의 자연스러움을 위해 계산한다. 여기서의 gradient는 기울기가 아니라, 이미지의 픽셀 값들이 인근에 위치한 픽셀 값들과의 차이를 의미한다. torchmetrics.functional import의 image_gradients 모듈을 사용해 구현했는데, 해당 모듈에서는 오른쪽 픽셀과 아래쪽 픽셀과의 차이를 반환한다. 따라서 해당 모듈의 input과 같은 사이즈의 dx와 dy tensor를 반환한다.

 위에서 gradient loss의 수식을 살펴보면, binary mask의 기울기를 계산하고 0인 값을 1로 바꾸어주는데, 0과 1로 이루어진 binary mask의 gradient를 계산하면 배경과 결함 영역의 경계가 1로 계산되고, 나머지 모든 영역이 0이 된다(이후 Non-zero to one 연산이 있긴하다). 이렇게 변형된 mask를 통해 target image와 generated image에서 경계에 해당하는 부분의 gradient를 계산해낼 수 있다.

 

3. WGAN-gp

 WGAN-gp는 학습 프로세스를 안정시키기 위해 적용하는데, Discriminator 내에서만 back-propagation을 진행한다.

 

4. weight parameter

 

위 수식은 전체 Loss function인데, 각 세부 function들의 중요도를 조절하기 위한 gamma parameter가 적용되어 있다.

import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch import optim
import torch
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.functional import image_gradients
from torch.autograd import Variable
import torch.autograd as autograd

import os
import time
import numpy as np
from datetime import datetime

from model import Generator, Discriminator, weights_init_normal
from dataset import get_loader


def compute_gradient_penalty(D, real_img, fake_img, mask):
    """Calculates the gradient penalty loss for WGAN GP"""
    cuda = True if torch.cuda.is_available() else False
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(real_img.shape[0], 1, 1, 1)
    alpha = alpha.expand_as(real_img).cuda()
    # Get random interpolation between real and fake samples
    interpolates = (
        (alpha * real_img + ((1 - alpha) * fake_img)).requires_grad_(True).cuda()
    )
    d_interpolates = D(interpolates, mask)
    fake = Variable(Tensor(real_img.shape[0], 1, 8, 8).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def trainer(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    save_path = f"./result/{args.exp}"
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(save_path + "/img", exist_ok=True)
    os.makedirs(save_path + "/model", exist_ok=True)

    writer = SummaryWriter(f"./logs/{args.exp}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    netG = Generator().to(device)
    netD = Discriminator().to(device)

    # Loss
    L1_loss = nn.L1Loss()
    mse_loss = nn.MSELoss()
    criterion_D = nn.MSELoss()

    # Optimizer
    optimizer_G = optim.Adam(netG.parameters(), lr=args.lr, betas=args.betas)
    optimizer_D = optim.Adam(netD.parameters(), lr=args.lr, betas=args.betas)

    # loss weight
    gamma_r = args.gammas[0]
    gamma_d = args.gammas[1]
    gamma_g = args.gammas[2]
    gamma_gp = args.gammas[3]

    # Initialize weights
    netG.apply(weights_init_normal)
    netD.apply(weights_init_normal)

    netG.train()
    netD.train()

    loader = get_loader(args)

    for epoch in range(args.epochs):
        epoch_start = time.time()
        for idx, sample in enumerate(loader):
            iter_start = time.time()

            pnb_img = sample["pnb_image"].to(device)
            img = sample["image"].to(device)
            mask = sample["mask"].to(device)  # mask for loss [0, 1]
            input_mask = mask.data
            input_mask[input_mask == 0] = -1
            input_mask.requires_grad_(True).to(device)  # mask for input [-1, 1]
            reverse_mask = 1 - mask

            z1 = torch.normal(0, torch.var(pnb_img), (img.shape[0], 8)).to(
                device
            )  # random vector(latent dimension : 8)
            z2 = torch.normal(0, torch.var(pnb_img), (img.shape[0], 8)).to(device)

            # -------------------------------- Train --------------------------------
            # PatchGAN Discriminator Label
            real_label = torch.ones(pnb_img.shape[0], 1, 8, 8, requires_grad=False).to(
                device
            )

            # -----------------
            #  Forward pass
            # -----------------
            # Generate
            gen1, dl1 = netG(input_mask, z1, pnb_img)  # gen -> generated image, dl -> input of last BRM
            gen2, dl2 = netG(input_mask, z2, pnb_img)  # gen2 -> for diversity

            # Discriminate
            out_dis1 = netD(gen1, input_mask)
            out_dis2 = netD(gen2, input_mask)
            out_real = netD(img, input_mask)
            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Loss
            # reconstruction loss
            recon_loss = 5 * L1_loss(gen1 * mask, img * mask)  # reconstruction loss
            nb_loss1 = L1_loss(gen1 * reverse_mask, img * reverse_mask) + \
                       L1_loss(gen2 * reverse_mask, img * reverse_mask)  # normal background loss
            nb_loss2 = L1_loss(dl1 * reverse_mask, img * reverse_mask) + \
                       L1_loss(dl2 * reverse_mask, img * reverse_mask)  # normal background loss (using input of last BRM)
            loss_r = recon_loss + nb_loss1 + nb_loss2

            # diversity loss
            loss_div = -L1_loss(gen1 * mask, gen2 * mask)

            # gradient loss
            img_dy, img_dx = image_gradients(img)
            gen1_dy, gen1_dx = image_gradients(gen1)
            gen2_dy, gen2_dx = image_gradients(gen2)
            mask_dy, mask_dx = image_gradients(mask)
            mask_dy[mask_dy != 0] = 1  # non-zero elements to 1
            mask_dx[mask_dx != 0] = 1

            loss_grad1 = (mse_loss(img_dy * mask_dy, gen1_dy * mask_dy)
                          + mse_loss(img_dx * mask_dx, gen1_dx * mask_dx)) * 0.5
            loss_grad2 = (mse_loss(img_dy * mask_dy, gen2_dy * mask_dy)
                          + mse_loss(img_dx * mask_dx, gen2_dx * mask_dx)) * 0.5
            loss_grad = loss_grad1 + loss_grad2

            # adversarial loss
            loss_adv1 = criterion_D(out_dis1, real_label)
            loss_adv2 = criterion_D(out_dis2, real_label)
            loss_adv = loss_adv1 + loss_adv2

            # Total loss
            loss_G = loss_r * gamma_r + loss_div * gamma_d + loss_adv + loss_grad * gamma_g

            loss_G.backward()
            optimizer_G.step()

            # -----------------
            #  Train Discriminator
            # -----------------
            optimizer_D.zero_grad()

            # Loss
            # gradient-penalty loss
            loss_gp1 = compute_gradient_penalty(netD, img.data, gen1.data, mask.data)
            loss_gp2 = compute_gradient_penalty(netD, img.data, gen2.data, mask.data)
            loss_gp = loss_gp1 + loss_gp2

            out_real = netD(img, input_mask)
            loss_adv3 = 2 * criterion_D(out_real, real_label)

            # Total loss
            loss_D = loss_gp * gamma_gp + loss_adv3
            loss_adv = loss_adv + loss_adv3
            loss_D.backward()
            optimizer_D.step()

            # -------------------------------- save --------------------------------
            # -----------------
            #  Log print & save
            # -----------------
            # print log
            iter_end = time.time() - iter_start
            if idx + 1 == len(loader):
                epoch_end = time.time() - epoch_start
                loss_log = (
                    "[Epoch %d/%d] [D loss: %f] [G loss: %f, reconstruction: %f, adv: %f, diversity: %f, gradient: %f] Time: %s"
                    % (
                        epoch + 1,
                        args.epochs,
                        loss_G.item(),
                        loss_D.item(),
                        loss_r.item(),
                        loss_adv.item(),
                        loss_div.item(),
                        loss_grad.item(),
                        epoch_end,
                    )
                )
                # tensorboard
                writer.add_scalar("loss_G/train", loss_G.item(), epoch + 1)
                writer.add_scalar("loss_D/train", loss_D.item(), epoch + 1)
                writer.add_scalar("reconstruction_loss", loss_r.item(), epoch + 1)
                writer.add_scalar("adversarial_loss", loss_adv.item(), epoch + 1)
                writer.add_scalar("diversity_loss", loss_div.item(), epoch + 1)
                writer.add_scalar("gradient_loss", loss_grad.item(), epoch + 1)

            else:
                loss_log = (
                    "[Iter %d/%d] [D loss: %f] [G loss: %f, reconstruction: %f, adv: %f, diversity: %f, gradient: %f] Time: %s"
                    % (
                        idx + 1,
                        len(loader),
                        loss_G.item(),
                        loss_D.item(),
                        loss_r.item(),
                        loss_adv.item(),
                        loss_div.item(),
                        loss_grad.item(),
                        iter_end,
                    )
                )
            print(loss_log)

            # save log
            f = open(save_path + "/log.txt", "a")
            if idx == epoch == 0:
                f.write(f"{args.exp}  |  " + str(datetime.now()) + "\n")
                f.write(
                    f"image source : {args.root}, defect type : {args.defect_type}, num_epochs : {args.epochs}, batch_size : {args.batch_size}"
                )
                f.write("\n")
            f.write(loss_log + "\n")
            f.close()

            # -----------------
            #  Model & Image save
            # -----------------

            if (epoch + 1) % args.save_epoch == 0:
                torch.save(
                    {
                        "optimizer_G_state_dict": optimizer_G.state_dict(),
                        "model_G_state_dict": netG.state_dict(),
                        "optimizer_D_state_dict": optimizer_D.state_dict(),
                        "model_D_state_dict": netD.state_dict(),
                    },
                    save_path + f"/model/{args.exp}_{epoch+1}.pth",
                )

                save_image(
                    [img.data[0], pnb_img.data[0], gen1.data[0]],
                    save_path + f"/img/epoch_{epoch+1}.png",
                    nrows=3,
                )

 

 trainer 함수 내에 학습 과정이 전부 들어가 있는데, input에 사용되는 z인자도 이 과정에서 추출하는 것으로 정했다. 

 

 

Conclusion

 이 글에 가져온 코드 외에 학습 및 생성을 위한 모든 코드는 상단에 링크를 첨부한 github에서 확인할 수 있다. 당시 다른 프로젝트들과 병행하느라 짬짬이 시간을 내서 했던 기억이 있는데, 모델선정+코드구현이 총 일주일 정도 걸렸던 것 같다.

 

 실제 이미지를 학습하고 생성한 결과 샘플은 github README에서 확인할 수 있다. 이 외에 구현을 하면서 고려한 요소나 구현 환경 등은 다음과 같다.

  • Method에서 언급한 pseudo-normal background를 형성하는 과정에서 OpenCV의 getRotationMatrix2D()와 warpAffine() 사용
  • training image는 rotation, flipping, random cropping을 통해 augment
  • binary mask는 [-1, 1]로 정규화 후 모델에 input으로 적용, loss 계산 시 [0, 1]로 변환
  • 12개의 convolution layers의 output은 64, 128, 256, 256, 512, 512, 512, 256, 256, 128, 64로 구성.
  • discriminator의 경우 32, 128, 256, 256, 512, 1로 구성.
  • latent space dimension은 8
  • 전체 loss에 쓰인 가중치 파라미터는 r=10, d=15, g=10, gp=10
  • Adam optimizer는 B1=0.5, B2=0.999이고 batch size=20, lr=0.0004, iterations=500
  • GPU = RTX 3090, CPU = Intel(R) Xeon(R) Gold 622306R

 

반응형

'DL > Code' 카테고리의 다른 글

YOLOv3 Pytorch 코드 리뷰  (0) 2024.01.26
YOLOv1 Pytorch 코드 리뷰  (1) 2024.01.25
Transformer Pytorch 코드 리뷰  (1) 2024.01.24