|
關(guān)于決策樹的介紹很多,我用sas宏寫了一個決策樹算法(ID3),主要是印證一下想法。決策樹作為一種非參方法,比較符合人的推理,因此很容易被理解,借助信息熵這個評判標(biāo)準(zhǔn)就可以進行迭代。
決策樹可以用來判斷變量的重要性,完成后可以導(dǎo)出規(guī)則。
從機器學(xué)習(xí)的角度來說,歸納的粒度是重要的,既要有很好的識別度(分類),又要有一定的泛化能力,因此適合的擬合水平是生長或剪枝要考慮的主要問題。
下個版本計劃做的改進:
1、對連續(xù)變量進行切分點尋找
2、決策樹的預(yù)剪枝規(guī)則改進(信息增益率:c4.5、記錄數(shù):過小的記錄可以考慮不生長、純度:不必非得是純的)
3、算法運算時間的檢查(本來覺得EM跑的太慢,但自己寫了也發(fā)現(xiàn)的確比較耗時)
例子可以見youtube的這個視頻(要翻墻) https://www./watch?v=eKD5gxPPeY0
先給兩張圖:我覺得圖1是決策樹比較清晰的一個結(jié)構(gòu),而圖2對于葉子節(jié)點分類的表示也很醒目
圖1、決策樹的結(jié)構(gòu)示例
圖2、決策樹葉子節(jié)點的分類表示(血槽)
以下是根據(jù)Youtube視頻中提的例子進行代碼實現(xiàn):
圖3、數(shù)據(jù)集
圖4、決策樹的生長
輸入數(shù)據(jù)集后,使用sas宏進行實現(xiàn):
主要使用兩個邏輯庫
1、tree用于保存迭代表(iter)以及stack表來控制迭代過程
圖5.tree庫
圖6.iter表
圖7.stack表
2、tree_tmp保存每個數(shù)據(jù)集的 _vlist表(變量列表)和_vlev表(變量水平列表)。
圖8.tree_tmp庫
**********************以下是主程序*****************************
第一部分:數(shù)據(jù)預(yù)處理(set_des這個宏在前面的文章里)
/*1數(shù)據(jù)集描述*/
%set_des(work,dt,10)
/*2描述修改*/
data set_des;
set set_des;
if name eq 'play' then cls_indi=.;
run;
/*3將分類變量和目標(biāo)變量分別存入宏*/
%let tar=play;
proc sql noprint;
select name into :cv_list separated by ' '
from set_des
where cls_indi=1;
select distinct(nobs) into :vobs
from set_des;
quit;
/*4將數(shù)據(jù)集重新命名,以便迭代*/
data dt_1;
set dt;
keep &cv_list &tar;
run;
第二部分:初始化部分數(shù)據(jù)
/*tree庫放迭代的結(jié)果*/
/*tree_tmp庫放迭代的中間結(jié)果*/
/*1初始化iter表*/
%var_ent(dt_1,&tar);
data tree.iter;
length set_from set_name 100; iter=1; set_from="dt_1"; set_name="dt_1"; set_recs=symget('vobs')+0; set_ent=symget("&tar._ent")+0; run; /*2初始化stack表*/ data tree.stack; length set_from set_name var_name 100;
if 0 then do;
set_from='a';
set_name='a';
var_name='a';
iter=0;
set_recs=0;
p=0;
set_ent=0;
weight_ent=0;
end;
if set_from eq '' then delete;
run;
/*3創(chuàng)建競爭變量列表*/
proc sql noprint;
create table tree_tmp.dt_1_vlist as
select lower(name) as var_cand from sashelp.vcolumn
where libname='WORK' and memname='DT_1'
and lower(name) ne "&tar";
quit;
data tree_tmp.dt_1_vlist;
length var_cand 100; set tree_tmp.dt_1_vlist; run; 第三部分:迭代 /*options mprint mlogic symbolgen;*/ /*限定迭代次數(shù),進行迭代*/ %macro split(times); /*根據(jù)iter表進行迭代*/ %do iter=1 %to × proc sql noprint; select count(*) into :_sets_tem from tree.iter where iter=&iter; quit; %let sets=%sysfunc(left(&_sets_tem)); %if &sets > 0 %then %do; proc sql noprint; select set_name into :set1-:set&sets from tree.iter where iter=&iter; quit; /*對本輪迭代中的數(shù)據(jù)集進行遍歷*/ %do xi=1 %to &sets; /*獲取現(xiàn)有變量的水平- 拆分后有些變量水平可能消失了*/ /*查詢當(dāng)前數(shù)據(jù)集的變量水平,使用當(dāng)前循環(huán)的數(shù)據(jù)集作為參數(shù) 并且查詢了以數(shù)據(jù)集名字命名的競爭列表set_vlist */ %cur_lev(&&set&xi); /*cur_lev生成了以數(shù)據(jù)集命名的水平列表 set_lev*/ /*對數(shù)據(jù)集的所有變量及水平遍歷,放入stack表*/ %vvar(&&set&xi,&iter); %cp(&&set&xi,&iter); %end; %end; %end; %mend; %split(5) **********************相關(guān)的宏及解釋************************* 宏1:vvar 對所有變量遍歷 /*數(shù)據(jù)集變量_水平 vvar*/ /*1遍歷所有變量*/ /*參數(shù):迭代次數(shù) 數(shù)據(jù)集*/ %macro vvar(set,iter,vlib=tree_tmp); /*先查詢該數(shù)據(jù)集的競爭變量列表*/ proc sql noprint; select count(*) into :_vars_tem from &vlib..&set._vlist; quit; %let vars=%sysfunc(left(&_vars_tem)); /*如果競爭變量表不為空,則執(zhí)行變量的遍歷*/ %if &vars > 0 %then %do; proc sql noprint; select var_cand into :var1-:var&vars from &vlib..&set._vlist; quit; %do j=1 %to &vars; /*對變量的每個水平進行遍歷*/ /*參數(shù)傳遞:迭代次數(shù) 數(shù)據(jù)集名 變量名*/ %vlev(&set,&&var&j,&iter); %end; %end; %mend; 宏2:vlev 對變量的所有水平遍歷 /*遍歷變量的水平*/ %macro vlev(set,var,iter,vlib=tree_tmp); /*遍歷當(dāng)前變量的所有水平*/ proc sql noprint; select count(*) into :_lev_tem from &vlib..&set._lev where var_cand="&var" ; quit; %let lev=%sysfunc(left(&_lev_tem)); proc sql noprint; select level into :lev1-:lev&lev from &vlib..&set._lev where var_cand="&var" ; quit; /*生成niter=iter+1*/ %let niter_tem=%eval(&iter+1); %let niter=%sysfunc(left(&niter_tem)); /*根據(jù)變量水平生成一些數(shù)據(jù)集*/ %do k=1 %to &lev; /*命名分裂數(shù)據(jù)集*/ data &var.&&lev&k.._&niter.; set &set nobs=mobs; /*獲取分裂前數(shù)據(jù)集的記錄數(shù)*/ if _n_=1 then call symput('mobs',mobs); &var._s='_'||left(&var); if &var._s="&&lev&k"; run; /*同步生成數(shù)據(jù)集的變量列表:從分裂前數(shù)據(jù)集繼承*/ data &vlib..&var.&&lev&k.._&niter._vlist; set &vlib..&set._vlist; if var_cand="&var" then delete; run; /*獲取分裂數(shù)據(jù)集的記錄數(shù)*/ proc sql noprint; select count(*) into :_setobs from &var.&&lev&k.._&niter.; quit; /*獲取分裂數(shù)據(jù)集的熵*/ /*tar是全局宏變量,且不會和局部沖突*/ %var_ent(&var.&&lev&k.._&niter.,&tar); data stack_tem; length set_from set_name var_name 100;
iter=&iter;
set_from=symget('set');
set_name="&var.&&lev&k.._&niter.";
set_recs=symget('_setobs')+0;
p=&_setobs/&mobs;
set_ent=symget("&tar._ent")+0;
weight_ent=p*set_ent;
var_name=symget("var");
run;
/*將分裂數(shù)據(jù)集的信息寫入stack表中*/
proc append base=tree.stack data=stack_tem;
run;
%end;
%mend;
宏3:cur_lev 生成當(dāng)前數(shù)據(jù)集的水平變量(隨著裂變有些變量水平可能會消失)
/*cur_lev*/
%macro cur_lev(set,vlib=tree_tmp);
/*先查詢該數(shù)據(jù)集的競爭變量列表*/
proc sql noprint;
select count(*) into :_vars_tem
from &vlib..&set._vlist;
quit;
%let vars=%sysfunc(left(&_vars_tem));
/*如果競爭變量表不為空,則執(zhí)行變量的遍歷*/
%if &vars > 0 %then %do;
proc sql noprint;
select var_cand into :var1-:var&vars
from &vlib..&set._vlist;
quit;
%do i=1 %to &vars;
data _tmp&i;
length level 100; set &set(keep=&&var&i); level='_'||&&var&i; var_cand="&&var&i"; keep var_cand level; run; %end; data &vlib..&set._lev; length var_cand 100;
set _tmp1-_tmp&vars;
run;
%end;
%dsort(&vlib..&set._lev, var_cand level);
%mend;
宏4:cp 判斷繼續(xù)迭代或終止
/*判別宏*/
/*對于每個數(shù)據(jù)集,在循環(huán)*/
%macro cp(set,iter);
/*取出數(shù)據(jù)集根據(jù)所有可能拆分的進行分析*/
data comp1;
set tree.stack;
/*取出本輪的數(shù)據(jù)集*/
if iter=&iter and set_from="&set";
run;
/*sort是對sort過程簡寫的宏*/
%sort(comp1,var_name);
/*根據(jù)變量進行熵的匯總*/
data comp2;
set comp1;
by var_name;
if first.var_name then ent_sum=0;
ent_sum+weight_ent;
if last.var_name then output;
keep set_from var_name iter ent_sum;
run;
/*取出熵最小的變量(則熵增最大)*/
%sort(comp2,iter ent_sum)
data tem;
length set_from var_name $ 100;
set comp2;
by iter;
if first.iter;
run;
/*將本次循環(huán)勝利的變量輸入win_var保存*/
proc append base=tree.win_var data=tem;
run;
/*將本次所有循環(huán)的變量都放到detail中*/
data tree.win_vars_detail;
set comp2;
run;
/*將數(shù)據(jù)集的熵和分裂后的熵比較,如果符合條件則寫入iter表,進入下一次迭代*/
proc sql noprint;
create table tem1 as
select a.var_name,b.set_ent-a.ent_sum as ent_gap
from tem as a left join
tree.iter as b on a.set_from=b.set_name;
quit;
/*寫入宏變量*/
data _null_;
set tem1;
call symput('ent_gap',ent_gap);
call symput('win_var',var_name);
run;
/*純子集的熵增為0*/
%if &ent_gap > 0.001 %then %do;
data tem2;
set tree.stack(where=(set_from="&set" and var_name="&win_var"));
iter=iter+1;
keep set_from set_name iter set_recs set_ent;
run;
proc append base=tree.iter data=tem2;
run;
%end;
%mend;
************************以下是結(jié)尾及一些解釋***************************
解釋:iter表中的結(jié)果可以構(gòu)成一個樹狀結(jié)構(gòu),每個數(shù)據(jù)集都有其父節(jié)點名稱,所以可以根據(jù)迭代層級從下往上構(gòu)建數(shù)。
另外這個例子中可以劃分成純子集,所以信息增益的閾值隨便設(shè)一個大于0的數(shù)就可以,但是實際上這樣容易過分生長。
結(jié)果:根據(jù)例子,D1-D14是測試集,最后有D15 Outlook(Rain) Humidity(High) Wind(Weak) Play(?),預(yù)測play就是一個決策過程。
根據(jù)前面得到了決策路徑,D15的outlook變量是第一重要的,所以先根據(jù)它轉(zhuǎn)到了rain這個分支,接下來wind變量是最重要的,根據(jù)wind-weak直接可以得到類標(biāo)簽(yes),所以猜測John會去玩。
這個例子還有一些簡化的地方,outlook,humidity,wind三個變量基本上認為是可以完全描述事件的三個正交變量,事實上獲得這個條件本身就是不容易的。
圖9.John的決策路徑
|