⊢ DeepLearning

PyTorch 문법 정리

최 수빈 2025. 3. 22. 23:13

 

모델 구축 및 학습 (Model Building & Training)

 

신경망 기본 구조

torch.nn.Module: 모든 신경망 모델의 기본 클래스

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 20)

    def forward(self, x):
        return self.layer(x)

 

손실 함수 (Loss Function)

분류 CrossEntropyLoss loss = nn.CrossEntropyLoss()
회귀 MSELoss loss = nn.MSELoss()

 

 

최적화 알고리즘 (Optimizers)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 확률적 경사 하강법 최적화 알고리즘
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam 최적화 알고리즘

 

 

 

데이터 로딩 및 전처리

 

커스텀 데이터셋

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

 

데이터로더

from torch.utils.data import DataLoader

dataset = MyDataset(X, y)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

 

이미지 전처리 (torchvision)

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

 

 

MPS 사용

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
inputs, labels = inputs.to(device), labels.to(device)

 

 

 

모델 아키텍처별 레이어

 

합성곱 신경망 (CNN)

nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)

 

순환 신경망 (RNN 계열)

nn.RNN(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
nn.LSTM(...)
nn.GRU(...)

 

트랜스포머 (Transformer)

nn.Transformer(nhead=8, num_encoder_layers=6)
nn.TransformerEncoderLayer(d_model=512, nhead=8)

 

 

 

타 유틸리티

 

저장 & 로드

torch.save(model.state_dict(), "model.pth")
model.load_state_dict(torch.load("model.pth"))
model.eval()

 

학습 / 평가 모드 전환

model.train()  # 학습 시
model.eval()   # 평가 시

 

 

'⊢ DeepLearning' 카테고리의 다른 글

모델평가와 검증  (0) 2025.03.22
하이퍼파라미터 튜닝  (0) 2025.03.22
과적합(Overfitting) 방지 기법  (0) 2025.03.22
전이학습(Transfer Learning)  (0) 2025.03.22
생성형 모델(Generative Models)  (0) 2025.03.22