카테고리 없음

위성 사진 건물 탐지 과제 EfficientNet기반 UNet 구현 -2

yummpu 2023. 8. 3. 18:25

Model 구현에 대해 알아보자

 

 

 

 

 

 

 

 

 

 

 

 


https://github.com/qubvel/segmentation_models.pytorchhttps://solaris.readthedocs.io/en/latest/pretrained_models.html

 

Pretrained models available in solaris — solaris 0.4.0 documentation

Pretrained models available in solaris solaris provides access to a number of pre-trained models from the SpaceNet challenges. See the table below for a summary. Note that the model name in the first column should be used as the "model_name" argument in th

solaris.readthedocs.io

 

 

원래 밑의 spacenet 위성 데이터를 활용한 대회 우승자 모델을 사용하려했다.

근데 나는 colab환경을 쓰는데 이 solaris를 다운받으려고 하니까.. 안됨..ㅋㅋ 

 

그래서 버리고 있다가 나중에 

https://github.com/CosmiQ/solaris

 

GitHub - CosmiQ/solaris: CosmiQ Works Geospatial Machine Learning Analysis Toolkit

CosmiQ Works Geospatial Machine Learning Analysis Toolkit - GitHub - CosmiQ/solaris: CosmiQ Works Geospatial Machine Learning Analysis Toolkit

github.com

여기 들어가서 직접 코드 긁어서 사용했다.

solaris/solaris/nets/zoo/xdxd_sn4

solaris/solaris/nets/zoo/selim_sef_sn4

안을 긁어서 코드로 쓰고

weight는 따로 다운받아서 load하면 된다.

model.load_state_dict(torch.load('xdxd_spacenet4_solaris_weights.pth'))

 

 

 

근데 이거 어차피 vgg_unet, resnet_unet 들이라 굳이 다운받아서 안쓰고 직접 구현 해도 되고 성능도 그닥 그렇게 좋지 않다.

가중치를 얼려서 돌려도 봤는데 그것도 그닥.. 그닥이다

 

참고로 xdxd 수상자 말고 다른건 이상하게 입력이 3개 채널이면 오류가 나고 4개채널이면 오류가 안나더라 

시간이 없어서 그냥 밑으로 구현하고 solaris는 그만 썼다.


 

https://github.com/qubvel/segmentation_models.pytorch

 

GitHub - qubvel/segmentation_models.pytorch: Segmentation models with pretrained backbones. PyTorch.

Segmentation models with pretrained backbones. PyTorch. - GitHub - qubvel/segmentation_models.pytorch: Segmentation models with pretrained backbones. PyTorch.

github.com

그래서 결론은 이걸 사용했다

 

 

%pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp

다운 받아주고

 

모델을 정의해준다.

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

ENCODER = 'timm-efficientnet-b7'
ENCODER_WEIGHTS = 'advprop'
ACTIVATION = 'sigmoid'
DEVICE = 'cuda'

model = smp.Unet(
    encoder_name = ENCODER,
    encoder_weights = ENCODER_WEIGHTS,
    in_channels = 3,
    classes = 1,
    activation = ACTIVATION,
)
model.to(device)

우리팀은 backbone으로 efficientnetb7의 adbporp 사전훈련모델을 사용하였다. 

사이트에 들어가보면 지원해주는 backbone과 decoder목록이 있다. 

확인하고 적용해주면 모델 구현이 끝난다.

 

 

 

 

 

 

모델 확인해보고 싶으면 summary를 사용한다.

from torchsummary import summary
summary(model, (3,224,224))

summary에 정의한 모델, (입력class수,이미지size) 입력해주면 어떤채널인지, 출력이 어떻게 나오는지 확인할 수 있다.

만약 train에서 모델 구조 관련 오류가 뜨면 summary에서 어디서 오류인지 확인해 볼 것

 

 

 

 

 

 

이제 train을 할 것이다.

import segmentation_models_pytorch.utils

# loss function과 optimizer 정의
criterion = smp.utils.losses.DiceLoss()

optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.0001),])

# training loop
for epoch in range(20):  # 20 에폭 동안 학습합니다.
    model.train()
    model.to('cuda')
    epoch_loss = 0
    for images, masks in tqdm(dataloader):
        images = images.float().to(device)
        masks = masks.float().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks.unsqueeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader)}')

dice loss를 segmentation의 util에서 가져왔기 때문에 import해주고 사용했다.

 

 

 

 

 

근데 colab은 자주 끊기니까 고혈압 방지를 위해 print위에 모델 저장을 주기적으로 하자

    %cd /content/drive/MyDrive
    torch.save(model.state_dict(), f'single_model{epoch+1}.pth')
    %cd /content

content에 그대로 저장하면 연결이 끊기면 날라가니까 저장하기전에 경로를 drive로 바꿔주자

이렇게 저장하면 epoch별로 따로 저장되고 epoch마다 갱신하고 싶으면 

 

 torch.save(model.state_dict(), 'single_model.pth')

이런식으로 저장하면 된다.

 

 

 

 

 

 

 

적합한 loss를 찾기위해 dice+focal을 사용하고자 코드를 긁어와서 사용하기도 하였다.

import re

class BaseObject(nn.Module):
    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        if self._name is None:
            name = self.__class__.__name__
            s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
            return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
        else:
            return self._name


class Metric(BaseObject):
    pass


class Loss(BaseObject):
    def __add__(self, other):
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        else:
            raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        return self.__add__(other)

    def __mul__(self, value):
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        else:
            raise ValueError("Loss should be inherited from `BaseLoss` class")

    def __rmul__(self, other):
        return self.__mul__(other)


class SumOfLosses(Loss):
    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def __call__(self, *inputs):
        return self.l1.forward(*inputs) + self.l2.forward(*inputs)


class MultipliedLoss(Loss):
    def __init__(self, loss, multiplier):

        # resolve name
        if len(loss.__name__.split("+")) > 1:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def __call__(self, *inputs):
        return self.multiplier * self.loss.forward(*inputs)
def _take_channels(*xs, ignore_channels=None):
    if ignore_channels is None:
        return xs
    else:
        channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels]
        xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs]
        return xs


def _threshold(x, threshold=None):
    if threshold is not None:
        return (x > threshold).type(x.dtype)
    else:
        return x


def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate Intersection over Union between ground truth and prediction
    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor):  ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: IoU (Jaccard) score
    """

    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    intersection = torch.sum(gt * pr)
    union = torch.sum(gt) + torch.sum(pr) - intersection + eps
    return (intersection + eps) / union


jaccard = iou


def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate F-score between ground truth and prediction
    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor):  ground truth tensor
        beta (float): positive constant
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: F score
    """

    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    score = ((1 + beta**2) * tp + eps) / ((1 + beta**2) * tp + beta**2 * fn + fp + eps)

    return score


def accuracy(pr, gt, threshold=0.5, ignore_channels=None):
    """Calculate accuracy score between ground truth and prediction
    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor):  ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: precision score
    """
    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt == pr, dtype=pr.dtype)
    score = tp / gt.view(-1).shape[0]
    return score


def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate precision score between ground truth and prediction
    Args:
        pr (torch.Tensor): predicted tensor
        gt (torch.Tensor):  ground truth tensor
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: precision score
    """

    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp

    score = (tp + eps) / (tp + fp + eps)

    return score


def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
    """Calculate Recall between ground truth and prediction
    Args:
        pr (torch.Tensor): A list of predicted elements
        gt (torch.Tensor):  A list of elements that are to be predicted
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: recall score
    """

    pr = _threshold(pr, threshold=threshold)
    pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

    tp = torch.sum(gt * pr)
    fn = torch.sum(gt) - tp

    score = (tp + eps) / (tp + fn + eps)

    return score
class DiceLoss(Loss):
    def __init__(self, eps=1.0, beta=1.0, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.beta = beta
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        return 1 - f_score(
            y_pr,
            y_gt,
            beta=self.beta,
            eps=self.eps,
            threshold=None,
            ignore_channels=self.ignore_channels,
        )
class FocalLoss(Loss):
    def __init__(self, alpha=1, gamma=2, class_weights=None, logits=False, reduction='mean'):
        super().__init__()
        assert reduction in ['mean', None]
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction
        self.class_weights = class_weights if class_weights is not None else 1.

    def forward(self, y_pr, y_gt):
        bce_loss = nn.functional.binary_cross_entropy(y_pr, y_gt)

        pt = torch.exp(- bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        focal_loss = focal_loss * torch.tensor(self.class_weights).to(focal_loss.device)

        if self.reduction == 'mean':
            focal_loss = focal_loss.mean()

        return focal_loss
dice = DiceLoss().to(device)
focal = FocalLoss().to(device)
criterion = dice + focal

이렇게 사용하였다.

긁은거라 할말이 없다.

 

 

 

 

 

test를 돌리자

학습시킬 필요가 없으니 grad업데이트를 안해도된다.

with torch.no_grad():
    model.eval()
    result = []
    for images in tqdm(test_dataloader):
        images = images.float().to(device)

        outputs = model(images)
        masks = outputs.cpu().numpy()
        masks = np.squeeze(masks, axis=1)
        masks = (masks > 0.35).astype(np.uint8) # Threshold = 0.35

        for i in range(len(images)):
            mask_rle = rle_encode(masks[i])
            if mask_rle == '': # 예측된 건물 픽셀이 아예 없는 경우 -1
                result.append(-1)
            else:
                result.append(mask_rle)

tqdm으로 원하는 dataloader를 가져온다. 

학습은 image와 라벨을 둘다 가져와서 loss를 계산하며 학습하는데 test는 이미지만 가져와서 예측 결과를 내는거니 images만 가져온다.

예측한 결과를 output에 저장하고 확률값에 따라 threshold를 설정하여 0과 1를 구분해준다.

그 후 다시 인코딩 하여 result에 저장한다.

 

 

 

submit = pd.read_csv('./sample_submission.csv')
submit['mask_rle'] = result
submit.to_csv('./submit.csv', index=False)

주어진 sample_sumission.csv파일을 가져와서 mask_rle자리에 결과물을 갱신한다.

그리고 그 결과를 다른 csv파일로 저장한다.

그럼 이 csv파일은 image path와 그에 맞는 예측 라벨로 구성되어 결과를 확인할 수 있다.

 

 

 

 

test_data = pd.read_csv('test.csv')
submit = pd.read_csv('submit.csv')

import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow

for i in range(5):
  test_image_path = test_data['img_path'][i]
  test_image = cv2.imread(test_image_path)
  test_mask = rle_decode(submit['mask_rle'][i], (224,224))


  plt.figure(figsize=(10,10))
  plt.subplot(131)
  plt.imshow(test_image)
  plt.axis("off")
  plt.subplot(132)
  plt.imshow(test_mask)
  plt.axis("off")

 

이런식으로 !

 

 

 

 

참고로 efficientnet은 깊을 수록 좋다고 하는데 8은 무거워서 batch를 8까지 줄여야해서 안하고 7로 돌렸다.

다른 조합도 많이 써봤는데 이 조합이 젤 괜찮았다.

 

이모델로 고정하고 transform을 바꾸거나 기법을 추가하였다.

다른기법은 다음 포스팅에..