tft每日頭條

 > 生活

 > 線性回歸模型的求解原理與方法

線性回歸模型的求解原理與方法

生活 更新时间:2024-12-18 16:53:41

在機器學習領域内,經常會聽到回歸問題、線性回歸等專業術語,那什麼是回歸問題,線性回歸又是什麼意思以及怎麼訓練一個模型、怎麼調整評估模型的預測性能呢,本文将結合網上資料及TensorFlow來訓練一個最簡單的線性模型并評估其性能優劣。

01

線性回歸

什麼是回歸問題

回歸分析是确定多個相互依賴的變量之間的定量關系的方法。回歸分析通常用于預測,發現變量之間的因果關系。

線性回歸

線性回歸是回歸分析的一種,它研究的是多個變量之間的線性組合,線性回歸假設的是預測的目标和特征之間是線性相關,通常可以使用以下公式來表達:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)1

其中w表示的是特征x的權重,b表示的是偏差。在幾何中,w又表示直線的斜率,w表示截距。

線性回歸的目的在于,通過大量的訓練數據來找到最合适的w,b的值。機器學習中常見的專業術語“模型”,就是上述類似的一組描述特征和目标之間關系的公式。

機器學習訓練出來的模型,最終需要應用的到驗證數據上進行預測,而模型的優劣好壞如何度量,如何才能找到合适的w,b的值呢?這就需要定量分析,需要引入“損失”、“損失函數”等概念了。

02

損失函數

我們先看下面的一個例子,紅色的直線表示當前的學習模型,藍色的點表示測試數據,通過下圖我們可以看到隻有一個點跟模型拟合,其他的都欠拟合。從直覺上我們知道,當前的模型可能沒那麼好,那麼如何定量的去衡量這個模型的好壞呢?

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)2

我們仍然以上圖為例,取x = 12這個樣本值,那麼相應藍色點的y值即為真實值,而相應紅線所在的prediction(x)即表示預測值。那麼y - prediction(x)即為損失。那麼衡量這種損失的函數,我們稱之為損失函數。

目前有這麼幾種損失函數:

  • 平方損失,又稱L2損失
  • 殘差平方和(RSS, Residual of Sum of Square)
  • 均方誤差(MSE, Mean Square Error)
  • 均方根誤差(RMSE, Root Mean Square Error)

其實都很好理解,我們一個一個來看。

平方損失

平方損失最好理解,單個樣本的平方損失,可以有如下公式給出:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)3

其實即為真實值與預測值之間差值的平方。

殘差平方和

殘差平方和指的是将單個樣本的真實值與預測值之差的平方相加求和:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)4

也即所有樣本的平方損失之和。

均方誤差

均方誤差指的是平均平方損失,也即需要計算所有樣本的平方損失之和,然後再除以樣本之和。可以由如下公式給出:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)5

其中(x,y)表示樣本,x指的是樣本的特征值,y指的是樣本的标簽(或者說目标值)。

D表示數據集,prediction(x)指的是模型預測值,N表示樣本個數。

均方根誤差

均方根誤差與均方誤差,從字面意思上來講其實有很大的聯系,均方根誤差即為均方誤差的開平方根。

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)6

在實際編程應用中,經常會先求出MSE,然後對其開平方即可得到RMSE。

了解完所有的損失函數之後,我們不禁有個疑問,這些誤差到底在機器學習中起到了什麼作用?

文章的開頭我們指出,我們需要利用損失函數來評估模型的優劣,或者說我們通過損失函數的取值來調整模型的參數,以期獲得最合适的模型。既然損失函數描述的是預測值與真實值的偏離情況,那麼在實際訓練過程中,我們應該找到使損失函數取最小值時取對應模型的參數。

我們假設目标和特征之間存在線性相關性,于是就有公式 y' = wx b (假設隻有一個特征)。以MSE(均方誤差)及線性回歸為例,我們将公式帶入MSE中替換預測值,即可得如下公式:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)7

該公式是以w,b為參數的函數,我們的目标是求取w,b的值使得該函數取最小值。由于N是一個常數,我們可以簡化成下面的公式:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)8

機器學習的模型訓練問題到此就被我們簡化成了求解上面公式(2)的最小值,求解函數最小值的方法有很多,比較常見的有:

  • 最小二乘法
  • 梯度下降法

最小二乘法我們這裡省略不講解,本文隻講解梯度下降法。

03

叠代法與梯度下降

講解梯度下降法之前,我們先了解一下機器學習訓練過程中的叠代法。

在機器學習過程中,我們的會選取一個初始化參數(w,b),然後通過損失函數來不斷地調整參數值(w,b)直到訓練處一個符合預期的模型。下圖是一個簡化版本的機器學習叠代過程:

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)9

圖例來自Google機器學習文檔

上圖中,“計算參數更新”就是梯度下降法的本質。在回歸問題中,權重值w的取值和損失之間的關系如下圖,始終是一個凸函數(凸函數,用幾何的方式來理解就是圖形内任意兩點構成的線段仍然被圖形包圍)。也就是說找到了一個極小值點,那麼這個極小值點就是全局最小值點。

線性回歸模型的求解原理與方法(如何快速理解線性回歸及模型訓練)10

圖例來自Google機器學習文檔

上圖中,任意選取一個起點。由于我們能夠看到圖形全貌,我們很快就能找到損失下降的方向,并最終定位到極小值點。但是在實際應用中,我們任意選取的起點後,是不能立刻确定哪個方向是損失下降的方向。舉個例子,常識告訴我們地球是圓的,但站着地球上,我們始終感受不到弧度,也就無從找到某段圓弧的最低點。同樣,我們就需要通過某種方式判斷哪個方向是損失下降的方向。

在熟知的高數中,我們知道導數可以理解成斜率,或者曲線某個點的導數可以理解成改點切線的斜率。導數為正,則曲線遞增;反之,曲線遞減。梯度就可以理解成導數(偏導數)。

梯度下降就是按照負梯度的方向,逐漸的去找到最小值點。

04

模型訓練

了解了上面的線性回歸的基本原理以及模型訓練的目的之後,我們總結一下線性回歸模型訓練的套路:

  • 分析訓練數據,找到數據的合适的分布
  • 确定問題的類型,回歸還是分類問題
  • 加載數據集,輸入訓練數據
  • 開始訓練模型,叠代過程
  • 使用測試數據來分析當前模型的優劣
  • 其他(後面再講)

可以參照google機器學習文檔來詳細了解如何使用tensorflow訓練模型。接下來會寫一篇TensorFlow模型訓練初體驗。

結語

線性回歸是機器學習中常見的一種類型,通過損失函數來反複叠代機器學習模型。在叠代過程中采用了梯度下降法,尋找最優解,整個流程比較容易理解。

,

更多精彩资讯请关注tft每日頭條,我们将持续为您更新最新资讯!

查看全部

相关生活资讯推荐

热门生活资讯推荐

网友关注

Copyright 2023-2024 - www.tftnews.com All Rights Reserved