|
訓練集高分,測試集預測提交后發(fā)現(xiàn)分數(shù)很低,為什么?有可能是訓練集和測試集分布不一致,導致模型過擬合訓練集,個人很不喜歡碰到這種線下不錯但線上抖動過大的比賽,有種讓你感覺好像在“碰運氣”,看誰“碰”對了測試集的分布。但實際是有方法可循的,而不是說純碰運氣。本文我將從“訓練/測試集分布不一致問題”的發(fā)生原因講起,然后羅列判斷該問題的方法和可能的解決手段。 ![]() 一、發(fā)生原因 訓練集和測試集分布不一致也被稱作數(shù)據(jù)集偏移(Dataset Shift)。西班牙格拉納達大學Francisco Herrera教授在他PPT[1]里提到數(shù)據(jù)集偏移有三種類型:
![]() 圖1:協(xié)變量偏移 最常見的有兩種原因[1]:
在分類任務上,有時候官方隨機劃分數(shù)據(jù)集,沒有考慮類別平衡問題,例如: 訓練集類別A數(shù)據(jù)量遠多于類別B,而測試集相反,這類樣本選擇偏差問題會導致訓練好的模型在測試集上魯棒性很差,因為訓練集沒有很好覆蓋整個樣本空間。此外,除了目標變量,輸入特征也可能出現(xiàn)樣本選擇偏差問題,比如要預測泰坦尼克號乘客存活率,而訓練集輸入特征里“性別”下更多是男性,而測試集里“性別”更多是女性,這樣也會導致模型在測試集上表現(xiàn)差。 樣本選擇偏差也有些特殊的例子,之前我參加阿里天池2021“AI Earth”人工智能創(chuàng)新挑戰(zhàn)賽[2],官方提供兩類數(shù)據(jù)集作為訓練集,分別是CMIP模擬數(shù)據(jù)和SODA真實數(shù)據(jù),然后測試集又是SODA真實數(shù)據(jù),CMIP模擬數(shù)據(jù)是通過系列氣象模型仿真模擬得到的,即有偏方法,但選手都會選擇將模擬數(shù)據(jù)加入訓練,因為訓練集真實數(shù)據(jù)太少了,可模擬數(shù)據(jù)的加入也無可避免的引入了樣本選擇偏差。 聊完樣本選擇偏移,我們聊下環(huán)境不平穩(wěn)帶來的數(shù)據(jù)偏移,我想最常見是在時序比賽里了吧,用歷史時序數(shù)據(jù)預測未來時序,未來突發(fā)事件很可能帶來時序的不穩(wěn)定表現(xiàn),這便帶來了分布差異。環(huán)境因素不僅限于時間和空間,還有數(shù)據(jù)采集設備、標注人員等。 二、判斷方法 ![]() 1. KDE (核密度估計)分布圖 當我們一想到要對比訓練集和測試集的分布,便是畫概率密度函數(shù)直方圖,但直方圖看分布有兩點缺陷: 受bin寬度影響大和不平滑,因此多數(shù)人會偏向于使用核密度估計圖(Kernel Density Estimation, KDE),KDE是非參數(shù)檢驗,用于估計分布未知的密度函數(shù),相比于直方圖,它受bin影響更小,繪圖呈現(xiàn)更平滑,易于對比數(shù)據(jù)分布。我研究生的有一門課的小作業(yè)有要去對比直方圖和KDE圖,相信這個能幫助大家更直觀了解到它們的差異: ![]() 圖2:心臟疾病患者最大心率的概率密度函數(shù)分布圖,數(shù)據(jù)源自UCI ML開放數(shù)據(jù)集 這里在略微細講下KDE,我們先看KDE函數(shù): ![]() 圖3:生成KDE的過程呈現(xiàn)[3] 言歸正傳,對比訓練集和測試集特征分布時,我們可以用seaborn.kdeplot()[4]進行繪圖可視化,樣例圖和代碼如下: ![]() 圖4:不同數(shù)據(jù)集下的KDE對比 import numpy as npimport seaborn as snsimport matplotlib.pyplot as plt# 創(chuàng)建樣例特征train_mean, train_cov = [0, 2], [(1, .5), (.5, 1)]test_mean, test_cov = [0, .5], [(1, 1), (.6, 1)]train_feat, _ = np.random.multivariate_normal(train_mean, train_cov, size=50).Ttest_feat, _ = np.random.multivariate_normal(test_mean, test_cov, size=50).T# 繪KDE對比分布sns.kdeplot(train_feat, shade = True, color='r', label = 'train')sns.kdeplot(test_feat, shade = True, color='b', label = 'test')plt.xlabel('Feature')plt.legend()plt.show()2.KS檢驗 KDE是PDF來對比,而KS檢驗是基于CDF(累計分布函數(shù)Cumulative Distribution Function)來檢驗兩個數(shù)據(jù)分布是否一致,它也是非參數(shù)檢驗方法(即不知道數(shù)據(jù)分布情況)。兩條不同數(shù)據(jù)集下的CDF曲線,它們最大垂直差值可用作描述分布差異(見下圖5中的D)。 ![]() 圖5:不同數(shù)據(jù)集下的CDF對比[5] 調(diào)用scipy.stats.ks_2samp()[6]可輕松得到KS的統(tǒng)計值(最大垂直差)和假設檢驗下的p值:
若KS統(tǒng)計值小且p值大,則我們可以接受KS檢驗的原假設H0,即兩個數(shù)據(jù)分布一致。上面樣例數(shù)據(jù)的統(tǒng)計值較低,p值大于10%但不是很高,因此反映分布略微不一致。注意: p值<0.01,強烈建議拒絕原假設H0,p值越大,越傾向于原假設H0成立。 3. 對抗驗證 對抗驗證是個很有趣的方法,它的思路是:我們構建一個分類器去分類訓練集和測試集,如果模型能清楚分類,說明訓練集和測試集存在明顯區(qū)別(即分布不一致),否則反之。具體步驟如下:
相關代碼可參考Qiuyan918在Kaggle的Microsoft Malware Prediction比賽中使用實例代碼[7]。 ![]() 圖6:對抗驗證示意圖 三、解決方法 ![]() 1. 構造合適的驗證集 當出現(xiàn)訓練集和測試集分布不一致的,我們可以試圖去構建跟測試集分布近似相同的驗證集,保證線下驗證跟線上測試分數(shù)不會抖動,這樣我們就能得到穩(wěn)定的benchmark。Qiuyan918在基于對抗驗證的基礎上,提出了三種構造合適的驗證集的辦法:
接下來,我將依次細講上述方法。 (1) 人工劃分驗證集 以時間序列舉例,因為一般測試集也會是未來數(shù)據(jù),所以我們也要保證訓練集是歷史數(shù)據(jù),而劃分出的驗證集是未來數(shù)據(jù),不然會發(fā)生“時間穿越”的數(shù)據(jù)泄露問題,導致模型過擬合(例如用未來預測歷史數(shù)據(jù)),這個時候就有兩種驗證劃分方式可參考使用:
![]() 圖7:劃分時序數(shù)據(jù)的兩種方法 除了時間序列數(shù)據(jù),其它數(shù)據(jù)集的驗證集劃分都要遵循一個原則,即盡可能符合測試集的數(shù)據(jù)模式。像前面提到的2021“AI Earth”人工智能創(chuàng)新挑戰(zhàn)賽中氣象數(shù)據(jù),由于測試集是真實氣象數(shù)據(jù),那么我們劃分驗證集時,更傾向于使用真實氣象數(shù)據(jù)去評估線下模型的表現(xiàn),而不是使用模擬氣象數(shù)據(jù)作為驗證集。 (2) 選擇和測試集最相似的樣本作為驗證集 前面在講對抗驗證時,我們有訓練出一個分類器去分類訓練集和測試集,那么自然我們也能預測出訓練集屬于測試集的概率(即訓練集在'Is_Test’標簽下預測概率),我們對訓練集的預測概率進行降序排列,選擇概率最大的前20%樣本劃分作為驗證集,這樣我們就能從原始數(shù)據(jù)集中,得到分布跟測試集接近的一個驗證集了,具體樣例代碼詳見[7]。之后,我們還可以評估劃分好的驗證集跟測試集的分布狀況,評估方法:將驗證集和測試集做對抗驗證,若AUC越小,說明劃分出的驗證集和測試集分布越接近(即分類器越分不清驗證集和測試集)。 ![]() 圖8:選擇和測試集最相似的樣本作為驗證集 (3) 有權重的交叉驗證 如果我們對訓練集里分布更偏向于測試集分布的樣本更大的樣本權重,給與測試集分布不太一致的訓練集樣本更小權重,也能一定程度上,幫助我們線下得到不易抖動的評估分數(shù)。在lightgbm庫的Dataset初始化參數(shù)中,便提供了樣本加權的參數(shù)weight,詳見文檔[8]。圖7中,對抗驗證的分類器預測訓練集的Is_Test概率作為權重即可。 2. 刪除分布不一致特征 如果我們遇到分布不一致且不太重要的特征,我們可以選擇直接刪去這種特征。該方法在各大比賽中十分常見。例如: 在2018年螞蟻金服風險大腦-支付風險識別比賽中,亞軍團隊根據(jù)特征在訓練集和測試集上的表現(xiàn),去除分布差異較大的特征,如圖9[9]。 圖9:螞蟻金服支付風險識別比賽中刪除分布不一致特征[9] 雖然個人建議的是刪除分布不一致但不太重要的特征,但有時避免不了碰到分布不一致但又很重要的特征,這時候其實就需要自行trade off特征分布和特征重要性的關系了,比如在第四屆工業(yè)大數(shù)據(jù)創(chuàng)新競賽-注塑成型工藝的虛擬量測中,第5名團隊保留了sensor1_mean特征而刪除了pack_press_2特征,盡管他們發(fā)現(xiàn)pack_press_2從實際生產(chǎn)角度和相關性角度都非常重要,可為了提升模型在測試集的泛化能力和分數(shù),他們沒用pack_press_2特征,如圖10[10]。 ![]() 圖10:注塑成型工藝的虛擬量測比賽中刪除分布不一致特征[10] 3. 修正分布不一致的特征輸入 當我們對比觀察訓練集和測試集的KDE時,若發(fā)現(xiàn)對數(shù)據(jù)做數(shù)學運算(例如加減乘除)或?qū)?strong>增刪樣本就能修正分布,使得分布接近一致,那么我們可以試試。比如,螞蟻金服比賽里,亞軍團隊發(fā)現(xiàn)'用戶交易請求'特征在訓練集中包含0、1和-1,而測試集只有1和0樣本,因此他們對訓練集刪去了特征值為-1的樣本,減少該特征在訓練集和測試集的差異[9]。 4. 修正分布不一致的預測輸出 除了對輸入特征進行分布檢查,我們也可以檢查目標特征的分布,看是否存在可修正的空間。這種案例很少見,因為正常情況下,你看不到測試集的目標特征值。在“AI Earth”人工智能創(chuàng)新挑戰(zhàn)賽里,我們有提到官方提供兩類數(shù)據(jù)集作為訓練集,分別是CMIP模擬數(shù)據(jù)和SODA真實數(shù)據(jù),然后測試集又是SODA真實數(shù)據(jù),其中前排參賽者YueTan就將CMIP和SODA的目標特征分布畫在一起,然后發(fā)現(xiàn)SODA的值更集中,且整體分布偏右一些,所以對用CMIP訓練得到的預測值加了一個小的常數(shù),修正CMIP下模型的預測輸出,使得它分布更偏向于SODA分布[11]。 ![]() 圖11:氣象數(shù)據(jù)SODA真實值和CMIP模擬值分布對比[11] 5. 偽標簽 偽標簽是半監(jiān)督方法,利用未標注數(shù)據(jù)加入訓練,我們先看看偽標簽的思路,再討論為什么它可能在一定程度上對分布不一致的數(shù)據(jù)集有幫助。偽標簽最常見的方法是:
TripleLift知乎主提供的入門版?zhèn)螛撕炈悸穲D如下所示,建議有興趣的朋友閱讀他原文[12],他還提供了進階版和創(chuàng)新版的偽標簽技術,值得借鑒學習。 圖12:入門版?zhèn)螛撕炈悸穲D 由上圖我們可以看到,模型的訓練引入了部分測試集的樣本,這樣相當于引入了部分測試集的分布。但需要注意: (1) 相比于前面的方法,偽標簽通常沒有表現(xiàn)的很好,因為它引入的是置信度高的測試集樣本,這些樣本很可能跟訓練集分布接近一致,所以才會預測概率高。因此引入的測試集分布也沒有很不同,所以使用時常發(fā)生過擬合的情況。 (2) 注意引入的是高置信度樣本,如果引入低置信度樣本,會帶來很大的噪聲。另外,高置信度樣本也不建議選取過多加入訓練集,這也是為了避免模型過擬合。 (3) 偽標簽適用于圖像領域更多些,表格型比賽建議最后沒辦法再考慮該方法,因為本人使用過該方法,漲分的可能性都不是很高(也可能是我沒用好)。 6. 其它 在寫文章的時候,我查知乎發(fā)現(xiàn)有個問答《訓練集和測試集的分布差距太大有好的處理方法嗎?》下,知乎主納米醬提到:'特征數(shù)值差距不大,特征相關性差距也不大,但是目標數(shù)值差距過大,這個好辦,改變?nèi)蝿赵O置共同的中間目標,比如你說的目標值是否可以采取相對值,增長率,夏普等,而非絕對值'[13]。這種更改預測目標的方法,可能是發(fā)現(xiàn)更改預測目標后,新的預測目標值分布會變得相對一致,所以才考慮該方法的。但實際中,我沒碰過這種情境,但還是提出來讓大家參考學習下。 四、總結 通過這次整理,我對“訓練集和測試集分布不一致”問題有了一個大致的知識框架,也學到了不少,特別是對抗驗證這塊,希望大家也有所獲,歡迎交流討論。 參考資料 [1] Dataset Shift in Classification: Approaches and Problems - Francisco Herrera, PPT: http://iwann./2011/pdf/InvitedTalk-FHerrera-IWANN11.pdf [2] 2021“AI Earth”人工智能創(chuàng)新挑戰(zhàn)賽 - 阿里天池, 比賽: https://tianchi.aliyun.com/competition/entrance/531871/introduction [3] Kernel Distribution - MathWorks, 文檔: https://www./help/stats/kernel-distribution.html [4] seaborn.kdeplot(), 文檔: http://seaborn./generated/seaborn.kdeplot.html [5] KS-檢驗(Kolmogorov-Smirnov test)-- 檢驗數(shù)據(jù)是否符合某種分布 - Arkenstone, 博客: https://www.cnblogs.com/arkenstone/p/5496761.html [6] scipy.stats.ks_2samp(), 文檔: https://docs./doc/scipy/reference/generated/scipy.stats.ks_2samp.html [7] Adversarial_Validation - Qiuyuan918, 代碼: https://github.com/Qiuyan918/Adversarial_Validation_Case_Study/blob/master/Adversarial_Validation.ipynb [8] lightgbm.Dataset(), 文檔: https://lightgbm./en/latest/pythonapi/lightgbm.Dataset.html#lightgbm.Dataset [9] 螞蟻金服ATEC風險大腦-支付風險識別--TOP2方案 - 吊車尾學院-E哥, 文章: https://zhuanlan.zhihu.com/p/57347243?from_voters_page=true [10] 工業(yè)大數(shù)據(jù)之注塑成型虛擬量測Top5分享 - 公眾號: Coggle數(shù)據(jù)科學 [11] 數(shù)據(jù)敏感度:以AI earth為栗子 - 公眾號: YueTan [12] 偽標簽(Pseudo-Labelling)——鋒利的匕首 - TripleLift, 文章: https://zhuanlan.zhihu.com/p/157325083 [13] 訓練集和測試集的分布差距太大有好的處理方法嗎?- 知乎, 文章: https://www.zhihu.com/question/265829982/answer/1770310534 干貨學習,點贊三連↓ |
|
|