ML/모두의 딥러닝

4-1장) 오차 수정하기 : 경사 하강법

busy맨 2023. 7. 6. 16:34

기울기 a를 너무 크게 잡거나 작게 잡으면 오차가 커진다.

기울기 a와 오차의 관계

이 그래프 상에서 오차가 가장 작을 때는 그래프의 가장 아래쪽 볼록한 부분인 m의 위치에 있을 때이다.

평균 제곱 오차를 구하는 방법을 통해 기울기를 적절히 바꾸어 m의 위치에 이르게 하면 최적의 기울기를 찾는 것.

이 과정을 컴퓨터가 판단해야하는데 이런 판단을 하게 하는 방법이 미분 기울기를 이용하는 경사 하강법

 

경사 하강법은 오차의 변화에 따라 이차 함수 그래프를 만들고 적절한 학습률을 설정해 미분 값이 0인 지점을 구하는 것

 

1. 경사 하강법의 개요

  • 경사 하강법(Gradient descent)
    • 1차 근삿값 발견용 최적화 알고리즘
    • 함수의 기울기를 구하고 경사의 반대 방향으로 계속 이동시켜 극값에 이를 때까지 반복
      • 이때의 극값은 m이고, m에서의 기울기는 0이다.

과정

  • 과정
    1. a1에서 미분을 구함
    2. 구해진 기울기의 반대방향으로 얼마간 이동시킨 a2에서 미분을 구함
      • 기울기가 +면 음의 방향, -면 양의 방향
    3. 구한 미분 값이 0이 아니면 1,2 과정 반복
      기울기 a를 변화시켜 m의 값을 찾는다.

 

2. 학습률

  • 학습률(learning rate)
    • 이동 거리를 정해주는 것
    • 딥러닝에서 학습률의 값을 적절히 바꾸면서 최적의 학습률을 찾는 것은 중요한 최적화 과정 중 하나
    • 학습률을 너무 크게 잡으면 한 점으로 수렴하지 않고 발산

학습률을 크게 잡았을 때

3. 과정

  • 평균 제곱 오차의 식에 yi=a*xi+b대입
  • a와 b를 각각 편미분

 

 

code)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#공부시간 X와 성적 Y의 리스트를 만듭니다.
data = [[2, 81], [4, 93], [6, 91], [8, 97]]
x = [i[0] for i in data]
y = [i[1] for i in data]

#그래프로 나타내 봅니다.
plt.figure(figsize=(8,5))
plt.scatter(x, y)
plt.show()
#리스트로 되어 있는 x와 y값을 넘파이 배열로 바꾸어 줍니다.(인덱스를 주어 하나씩 불러와 계산이 가능해 지도록 하기 위함입니다.)
x_data = np.array(x)
y_data = np.array(y)

# 기울기 a와 절편 b의 값을 초기화 합니다.
a = 0
b = 0

#학습률을 정합니다.
lr = 0.03 

#몇 번 반복될지를 설정합니다.
epochs = 2001 

#경사 하강법을 시작합니다.
for i in range(epochs): # epoch 수 만큼 반복
    y_hat = a * x_data + b  #y를 구하는 식을 세웁니다
    error = y_data - y_hat  #오차를 구하는 식입니다.
    a_diff = -(2/len(x_data)) * sum(x_data * (error)) # 오차함수를 a로 미분한 값입니다. 
    b_diff = -(2/len(x_data)) * sum(error)  # 오차함수를 b로 미분한 값입니다. 
    a = a - lr * a_diff  # 학습률을 곱해 기존의 a값을 업데이트합니다.
    b = b - lr * b_diff  # 학습률을 곱해 기존의 b값을 업데이트합니다.
    if i % 100 == 0:    # 100번 반복될 때마다 현재의 a값, b값을 출력합니다.
        print("epoch=%.f, 기울기=%.04f, 절편=%.04f" % (i, a, b))
        # 앞서 구한 기울기와 절편을 이용해 그래프를 그려 봅니다.
y_pred = a * x_data + b
plt.scatter(x, y)
plt.plot([min(x_data), max(x_data)], [min(y_pred), max(y_pred)])
plt.show()

 

'ML > 모두의 딥러닝' 카테고리의 다른 글

5장) 참 거짓 판단 장치 : 로지스틱 회귀  (0) 2023.07.10
4-2장) 다중 선형 회귀  (0) 2023.07.07
3장) 선형 회귀  (0) 2023.07.06
2장) 딥러닝을 위한 기초 수학  (0) 2023.07.05
1장) 나의 첫 딥러닝  (0) 2023.07.05