關(guān)鍵字(keywords):SVM 支持向量機 SMO算法 實現(xiàn) 機器學(xué)習(xí)
如果對SVM原理不是很懂的,可以先看一下入門的視頻,對幫助理解很有用的,然后再深入一點可以看看這幾篇入門文章,作者寫得挺詳細,看完以后SVM的基礎(chǔ)就了解得差不多了,再然后買本《支持向量機導(dǎo)論》作者是Nello Cristianini 和 John Shawe-Taylor,電子工業(yè)出版社的。然后把書本后面的那個SMO算法實現(xiàn)就基本上弄懂了SVM是怎么一回事,最后再編寫一個SVM庫出來,比如說像libsvm等工具使用,呵呵,差不多就這樣。這些是我學(xué)習(xí)SVM的整個過程,也算是經(jīng)驗吧。
下面是SVM的簡化版SMO算法,我將結(jié)合Java代碼來解釋一下整個SVM的學(xué)習(xí)訓(xùn)練過程,即所謂的train訓(xùn)練過程。那么什么是SMO算法呢?
SMO算法的目的無非是找出一個函數(shù)f(x),這個函數(shù)能讓我們把輸入的數(shù)據(jù)x進行分類。既然是分類肯定需要一個評判的標(biāo)準,比如分出來有兩種情況A和B,那么怎么樣才能說x是屬于A類的,或不是B類的呢?就是需要有個邊界,就好像兩個國家一樣有邊界,如果邊界越明顯,則就越容易區(qū)分,因此,我們的目標(biāo)是最大化邊界的寬度,使得非常容易的區(qū)分是A類還是B類。
在SVM中,要最大化邊界則需要最小化這個數(shù)值:

w:是參量,值越大邊界越明顯
C代表懲罰系數(shù),即如果某個x是屬于某一類,但是它偏離了該類,跑到邊界上后者其他類的地方去了,C越大表明越不想放棄這個點,邊界就會縮小
代表:松散變量
但問題似乎還不好解,又因為SVM是一個凸二次規(guī)劃問題,凸二次規(guī)劃問題有最優(yōu)解,于是問題轉(zhuǎn)換成下列形式(KKT條件):
…………(1)
這里的ai是拉格朗日乘子(問題通過拉格朗日乘法數(shù)來求解)
對于(a)的情況,表明ai是正常分類,在邊界內(nèi)部(我們知道正確分類的點yi*f(xi)>=0)
對于(b)的情況,表明了ai是支持向量,在邊界上
對于(c)的情況,表明了ai是在兩條邊界之間
而最優(yōu)解需要滿足KKT條件,即滿足(a)(b)(c)條件都滿足
以下幾種情況出現(xiàn)將會出現(xiàn)不滿足:
yiui<=1但是ai<C則是不滿足的,而原本ai=C
yiui>=1但是ai>0則是不滿足的而原本ai=0
yiui=1但是ai=0或者ai=C則表明不滿足的,而原本應(yīng)該是0<ai<C
所以要找出不滿足KKT的這些ai,并更新這些ai,但這些ai又受到另外一個約束,即

因此,我們通過另一個方法,即同時更新ai和aj,滿足以下等式

就能保證和為0的約束。
利用yiai+yjaj=常數(shù),消去ai,可得到一個關(guān)于單變量aj的一個凸二次規(guī)劃問題,不考慮其約束0<=aj<=C,可以得其解為:
………………………………………(2)
這里
………………(3)
表示舊值,然后考慮約束0<=aj<=C可得到a的解析解為:
…………(4)

對于
那么如何求得ai和aj呢?
對于ai,即第一個乘子,可以通過剛剛說的那幾種不滿足KKT的條件來找,第二個乘子aj可以找滿足條件
…………………………………………………………………………(5)
b的更新:
在滿足條件:
下更新b?!?)
最后更新所有ai,y和b,這樣模型就出來了,然后通過函數(shù):
……………………………………………………(7)
輸入是x,是一個數(shù)組,組中每一個值表示一個特征。
輸出是A類還是B類。(正類還是負類)
以下是主要的代碼段:
-
-
-
-
-
-
- double C = 1;
- double tol = 0.01;
- int maxPasses = 5;
-
-
-
-
-
- double a[] = new double[x.length];
- this.a = a;
-
-
- for (int i = 0; i < x.length; i++) {
- a[i] = 0;
- }
- int passes = 0;
-
-
- while (passes < maxPasses) {
-
- int num_changed_alphas = 0;
- for (int i = 0; i < x.length; i++) {
-
-
- double Ei = getE(i);
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- if ((y[i] * Ei < -tol && a[i] < C) ||
- (y[i] * Ei > tol && a[i] > 0))
- {
-
-
-
-
- int j;
-
-
-
-
- if (this.boundAlpha.size() > 0) {
-
- j = findMax(Ei, this.boundAlpha);
- } else
-
- j = RandomSelect(i);
-
- double Ej = getE(j);
-
-
- double oldAi = a[i];
- double oldAj = a[j];
-
-
-
-
-
- double L, H;
- if (y[i] != y[j]) {
- L = Math.max(0, a[j] - a[i]);
- H = Math.min(C, C - a[i] + a[j]);
- } else {
- L = Math.max(0, a[i] + a[j] - C);
- H = Math.min(0, a[i] + a[j]);
- }
-
-
-
-
-
- double eta = 2 * k(i, j) - k(i, i) - k(j, j);
-
- if (eta >= 0)
- continue;
-
- a[j] = a[j] - y[j] * (Ei - Ej)/ eta;
- if (0 < a[j] && a[j] < C)
- this.boundAlpha.add(j);
-
- if (a[j] < L)
- a[j] = L;
- else if (a[j] > H)
- a[j] = H;
-
- if (Math.abs(a[j] - oldAj) < 1e-5)
- continue;
- a[i] = a[i] + y[i] * y[j] * (oldAj - a[j]);
- if (0 < a[i] && a[i] < C)
- this.boundAlpha.add(i);
-
-
-
-
-
- double b1 = b - Ei - y[i] * (a[i] - oldAi) * k(i, i) - y[j] * (a[j] - oldAj) * k(i, j);
- double b2 = b - Ej - y[i] * (a[i] - oldAi) * k(i, j) - y[j] * (a[j] - oldAj) * k(j, j);
-
- if (0 < a[i] && a[i] < C)
- b = b1;
- else if (0 < a[j] && a[j] < C)
- b = b2;
- else
- b = (b1 + b2) / 2;
-
- num_changed_alphas = num_changed_alphas + 1;
- }
- }
- if (num_changed_alphas == 0) {
- passes++;
- } else
- passes = 0;
- }
-
- return new SVMModel(a, y, b);
運行后的結(jié)果還算可以吧,測試數(shù)據(jù)主要是用了libsvm的heart_scale的數(shù)據(jù)。
預(yù)測的正確率達到73%以上。
如果我把核函數(shù)從線性的改為基于RBF將會更好點。
最后,說到SVM算法實現(xiàn)包,應(yīng)該有很多,包括svm light,libsvm,有matlab本身自帶的svm工具包等。