⊢ AI 모델 활용

FastAI: 사전 학습된 모델을 활용한 이미지 분류

최 수빈 2025. 3. 27. 03:44

 

FastAI

 

딥러닝을 빠르고 쉽게 할 수 있게 만들어진 고수준의 Python 라이브러리

기본적으로 PyTorch 위에 만들어졌고, 복잡한 코드 없이도 빠르게 모델을 만들고 학습, 평가, 예측할 수 있게 도와줌

 

  • 간결한 코드: 몇 줄만으로도 데이터 전처리, 모델 학습, 평가 가능
  • 전이학습 기본 내장: resnet, vgg 같은 사전 학습 모델 바로 사용
  • 다양한 모듈 지원: vision, text, tabular, collaborative filtering 등
  • 자동 최적화: 학습률 찾기, 데이터 증강, 조기 종료 등 자동 적용
  • PyTorch 기반: PyTorch의 유연성과 강력함을 그대로 활용 가능

 

주요 모듈

 

fastai.vision.all
이미지 분류, 객체 탐지 등 비전 관련

fastai.text.all
텍스트 분류, 언어 모델 등

fastai.tabular.all

범주형/연속형 데이터 처리

fastai.collab

추천 시스템

Learner 객체

학습을 총괄하는 핵심 객체

DataBlock

데이터셋 구성 정의를 편리하게 함

 

 

FastAI: 사전 학습된 모델을 활용한 이미지 분류 실습

 

FastAI 라이브러리를 활용하여 사전 학습된 CNN 모델로 이미지 분류 수행

PETS 데이터셋을 이용한 고양이 vs. 강아지 분류

전이 학습(Transfer Learning), 파인튜닝(Fine-tuning)

혼동 행렬, 예측 결과 시각화 등 성능 평가 진행

  1. 설치: fastai, torch 설치 및 라이브러리 임포트
  2. 데이터 준비: PETS 데이터셋 다운로드 및 라벨링
  3. 모델 학습: ResNet34 기반 모델 생성 및 학습률 찾기 후 파인튜닝
  4. 성능 평가: 예측 결과 시각화 + 혼동 행렬
  5. 예측: 새로운 이미지 예측
  6. 확장: 다양한 사전 학습 모델로 실험

 

FastAI 설치 및 기본 설정

 

설치

pip install fastai

 

라이브러리 불러오기

from fastai.vision.all import *

 

데이터 다운로드 및 구성

path = untar_data(URLs.PETS)          # PETS 데이터셋 다운로드 및 압축 해제
path_imgs = path/'images'

 

라벨링 함수 정의

def label_pet(x): 
    return "Cat" if x[0].isupper() else "Dog"  # 파일명이 대문자로 시작하면 고양이 아니면 개

 

데이터블록 생성

dls = ImageDataLoaders.from_name_func(
    path_imgs, get_image_files(path_imgs), 
    valid_pct=0.2, seed=42,
    label_func=label_pet, 
    item_tfms=Resize(224)
)

 

데이터 확인

dls.show_batch(max_n=9, figsize=(7, 6))
plt.show()

 

 

사전 학습된 모델 로드 및 학습

 

학습기 생성

# ResNet34 사전 학습된 모델을 사용해 학습기 생성, 학습 중 사용할 성능 평가 지표=오류율
learn = cnn_learner(dls, resnet34, metrics=error_rate)

 

최적 학습률 찾기

learn.lr_find() # 최적 학습률 찾기 (자동으로 찾아줌)

 

파인튜닝

learn.fine_tune(3) # 파인튜닝 (모델 학습) - 3epoch동안 파인튜닝 진행

 

 

모델 평가

 

예측 결과 시각화

learn.show_results()
plt.show()

Cat vs. Dog

 

혼동 행렬

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
plt.show()

혼동행렬

 

새로운 이미지 예측

img = PILImage.create('path_to_your_image.jpg')  # 예측할 이미지 경로
pred, _, probs = learn.predict(img)

print(f"Prediction: {pred}, Probability: {probs.max():.4f}")
img.show()

 

 

다른 사전 학습된 모델 사용

FastAI는 다양한 사전 학습된 모델을 지원함 (resnet18, resnet50, densenet121, efficientnet_b0 등)

모델만 교체하면 손쉽게 실험 가능

learn = cnn_learner(dls, vgg16_bn, metrics=error_rate)
learn.fine_tune(3)

 

 

사용자 입력 기반 이미지 예측,분류

 

Utils.py

def label_pet(x):
    return "Cat" if x[0].isupper() else "Dog"

 

main.py

from fastai.vision.all import *
import random
from utils import label_pet

# 저장된 모델 불러오기
learn = load_learner("catdog_model.pkl")

# 검증용 데이터셋 (label 정보 확인을 위해 필요)
path = untar_data(URLs.PETS)
path_imgs = path / "images"


# 데이터블록 재생성 (검증용 이미지 접근용)
dls = ImageDataLoaders.from_name_func(
    path_imgs,
    get_image_files(path_imgs),
    valid_pct=0.2,
    seed=42,
    label_func=label_pet,
    item_tfms=Resize(224),
)


def predict_from_input(learn, dls):
    """
    사용자 입력을 기반으로 랜덤 이미지 예측 또는 경로 기반 이미지 예측 수행 함수
    """
    user_input = input(
        "예측할 이미지 경로 입력 (random 입력 시 랜덤 이미지 사용): "
    ).strip()

    if user_input.lower() == "random":
        # 랜덤 예측
        idx = random.randint(0, len(dls.valid_ds) - 1)
        img, label = dls.valid_ds[idx]
        pred, _, probs = learn.predict(img)

        print("랜덤 이미지 예측 결과")
        print(f"실제 라벨: {label}")
        print(f"예측 결과: {pred}")
        print(f"확률: {probs.max():.4f}")

        img.show()

    else:
        # 이미지 경로 기반 예측
        try:
            img = PILImage.create(user_input)
            pred, _, probs = learn.predict(img)

            print("\n [사용자 이미지 예측]")
            print(f"예측 결과: {pred}")
            print(f"확률: {probs.max():.4f}")
            img.show()

        except Exception as e:
            print(f"이미지 로딩 실패: {e}")


if __name__ == "__main__":
    predict_from_input(learn, dls)

 

predict.py

from fastai.vision.all import *
import random
from utils import label_pet
import matplotlib.pyplot as plt

# 저장된 모델 불러오기
learn = load_learner("catdog_model.pkl")

# 검증용 데이터셋 (label 정보 확인을 위해 필요)
path = untar_data(URLs.PETS)
path_imgs = path / "images"


# 데이터블록 재생성 (검증용 이미지 접근용)
dls = ImageDataLoaders.from_name_func(
    path_imgs,
    get_image_files(path_imgs),
    valid_pct=0.2,
    seed=42,
    label_func=label_pet,
    item_tfms=Resize(224),
)


def predict_from_input(learn, dls):
    """
    사용자 입력을 기반으로 랜덤 이미지 예측 또는 경로 기반 이미지 예측 수행 함수
    """
    user_input = input(
        "예측할 이미지 경로 입력 (random 입력 시 랜덤 이미지 사용): "
    ).strip()

    if user_input.lower() == "random":
        # 랜덤 예측
        idx = random.randint(0, len(dls.valid_ds) - 1)
        img, label = dls.valid_ds[idx]
        pred, _, probs = learn.predict(img)

        print("랜덤 이미지 예측 결과")
        print(f"실제 라벨: {label}")
        print(f"예측 결과: {pred}")
        print(f"확률: {probs.max():.4f}")

        img.show()
        plt.show()

    else:
        # 이미지 경로 기반 예측
        try:
            img = PILImage.create(user_input)
            pred, _, probs = learn.predict(img)

            print("\n [사용자 이미지 예측]")
            print(f"예측 결과: {pred}")
            print(f"확률: {probs.max():.4f}")
            img.show()
            plt.show()

        except Exception as e:
            print(f"이미지 로딩 실패: {e}")


if __name__ == "__main__":
    predict_from_input(learn, dls)

random
Dog