[딥러닝 기초] 선형회귀 모델의 개선
다음의 책을 공부하며 정리한 내용입니다
nn.Module 을 사용한 선형회귀 모델의 개선
: 파이토치에서 일부 모델(ex. 선형회귀모델)들은 이미 nn.Module의 형태로 편리하게 쓸 수 있도록 구현되어있다.
즉, 우리가 기존에 구현했던 선형회귀 수식
- y_hat = (w * x_train) + b
- y_hat = x_train.matmul(w) + b
이는 아래와 같이 변경할 수 있다.
- model = nn.Linear(1, 1)
- model = nn.Linear(3, 1)
: 이때 nn.Linear()은 선형회귀 모델을 의미하며 왼쪽부터 순서대로 input dimension
, output dimension
이다.
input_dim : 가중치 w의 개수
output_dim : 가중치 w의 길이
nn.Linear 모델 사용해보기
- nn.Linear( )에는 가중치w와 편향b가 저장되어있다.
- 이는 model.parameters( )로 불러올 수 있다.
1 | # 모델을 선언 및 초기화. 단순 선형 회귀이므로 input_dim=1, output_dim=1. |
1 | [결과] |
기존 선형회귀 코드의 개선
: 기존 선형회귀 코드를 다음과 같이 개선할 수 있다.
개선 1
1 | # 최종코드 |
torch.nn 까지 import
1 | # 최종코드 |
개선 2
1 | for epoch in range(1000): |
nn.Linear( )로 선언한 model로 y_hat 계산
F.mse_loss(prediction, y_train) 파이토치에서 제공하는 평균제곱함수로 cost 계산
1 | model = nn.Linear(1, 1) |
1 | [결과] |
- 위의 코드로 학습한 모델 model은 x_train, y_train에 대해 학습된 값 w,b 를 저장하고 있다.
- 학습된 모델 model을 활용해 새로운 값 x_new 에 대한 예측값 y_pred를 얻을 수 있다.