CIFAR-10 데이터셋을 ResNet 모델로 학습 및 평가해보자

김보현·2024년 8월 22일
0

DeepLearning

목록 보기
4/4
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# 데이터셋에 대한 전처리 과정을 정의하다
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 데이터 증강: 가로로 뒤집다
    transforms.RandomCrop(32, padding=4),  # 데이터 증강: 무작위로 자르다
    transforms.ToTensor(),  # 데이터를 텐서로 변환하다
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))  # 데이터를 정규화하다
])

# CIFAR-10 데이터셋을 불러오다
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# ResNet 모델을 불러오다 (사전 훈련된 모델 사용)
net = models.resnet18(pretrained=True)

# CIFAR-10에 맞게 마지막 출력층을 수정하다
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 10)  # CIFAR-10은 10개의 클래스가 있으므로 출력층을 수정하다

# 모델을 GPU로 이동시키다 (가능한 경우)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

# 손실 함수와 옵티마이저를 정의하다
criterion = nn.CrossEntropyLoss()  # 다중 클래스 분류에 적합한 손실 함수
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # SGD 옵티마이저 사용

# 모델을 학습시키다
for epoch in range(10):  # 10 에포크 동안 학습시키다
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()  # 옵티마이저의 기울기를 초기화하다

        outputs = net(inputs)  # 모델의 예측을 얻다
        loss = criterion(outputs, labels)  # 손실을 계산하다
        loss.backward()  # 역전파를 통해 기울기를 계산하다
        optimizer.step()  # 옵티마이저를 통해 모델의 매개변수를 업데이트하다

        running_loss += loss.item()
        if i % 100 == 99:  # 100 미니 배치마다 평균 손실을 출력하다
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}")
            running_loss = 0.0

print('Finished Training')  # 학습이 완료되다

# 모델을 테스트하다
correct = 0
total = 0
with torch.no_grad():  # 평가 시에는 기울기 계산을 하지 않다
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)  # 예측된 라벨을 얻다
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%")  # 정확도를 출력하다
  1. 데이터 전처리: CIFAR-10 데이터셋에 대해 다양한 데이터 증강 및 정규화를 적용하다. 이를 통해 모델의 일반화 성능을 높이다.
  2. 데이터셋 로드: torchvision을 사용해 CIFAR-10 데이터셋을 다운로드하고, 이를 PyTorch DataLoader에 로드하다.
  3. ResNet 모델 불러오기: torchvision.models에서 사전 훈련된 ResNet-18 모델을 불러오다. CIFAR-10에 맞게 출력층을 10개의 클래스로 수정하다.
  4. 모델 학습: 10 에포크 동안 SGD 옵티마이저와 교차 엔트로피 손실 함수를 사용해 모델을 학습시키다. 학습 중에는 매 100 미니 배치마다 손실을 출력하다.
  5. 모델 평가: 학습된 모델을 테스트 세트에 대해 평가하고, 전체 정확도를 출력하다.
profile
Fall in love with Computer Vision

0개의 댓글