遗传算法选取最优参数MATLAB程序

更新时间:2024-06-16 15:13:01 阅读量: 综合文库 文档下载

说明:文章内容仅供预览,部分内容可能不全。下载后的文档,内容与下面显示的完全一致。下载之前请确认下面内容是否您想要的,是否完整无缺。

复制代码

在这里使用启发式算法GA(遗传算法)来进行参数寻优,用网格划分(grid search)来寻找最佳的参数c和g,虽然采用网格搜索能够找到在CV意义下的最高的分类准确率,即全局最优解,但有时候如果想在更大的范围内寻找最佳的参数c和g会很费时,采用启发式算法就可以不必遍历网格内的所有的参数点,也能找到全局最优解。

关于遗传算法这里不打算过多介绍,想要学习的朋友可以自己查看相关资料。

使用GA来进行参数寻优在在libsvm-mat-2.89-3[FarutoUltimate3.0]工具箱中已经实现gaSVMcgForClass.m(分类问题参数寻优)、gaSVMcgForRegress.m(回归问题参数寻优)。

1.

2. 3. 4. 5. 6. 7.

利用GA参数寻优函数(分类问题):gaSVMcgForClass [bestCVaccuracy,bestc,bestg,ga_option]= gaSVMcgForClass(train_label,train,ga_option) 输入:

train_label:训练集的标签,格式要求与svmtrain相同。 train:训练集,格式要求与svmtrain相同。

ga_option:GA中的一些参数设置,可不输入,有默认值,详细请看代码的帮助说明。 8. 输出:

9. bestCVaccuracy:最终CV意义下的最佳分类准确率。 10. bestc:最佳的参数c。 11. bestg:最佳的参数g。

12. ga_option:记录GA中的一些参数。

13. ========================================================== 14. 利用GA参数寻优函数(回归问题):gaSVMcgForRegress 15. [bestCVmse,bestc,bestg,ga_option]=

16. gaSVMcgForRegress(train_label,train,ga_option) 17. 其输入输出与gaSVMcgForClass类似,这里不再赘述。

复制代码

gaSVMcgForClass.m源代码:

1. function [BestCVaccuracy,Bestc,Bestg,ga_option] =

gaSVMcgForClass(train_label,train_data,ga_option) 2. % gaSVMcgForClass 3. 4. %%

5. % by faruto

6. %Email:patrick.lee@foxmail.com QQ:516667408

http://blog.sina.com.cn/faruto BNU 7. %last modified 2010.01.17 8.

9. %% 若转载请注明:

10. % faruto and liyang , LIBSVM-farutoUltimateVersion

11. % a toolbox with implements for support vector machines based on libsvm,

2009. 12. %

13. % Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for 14. % support vector machines, 2001. Software available at 15. % http://www.csie.ntu.edu.tw/~cjlin/libsvm 16.

17. %% 参数初始化 18. if nargin == 2

19. ga_option = struct('maxgen',200,'sizepop',20,'ggap',0.9,... 20. 'cbound',[0,100],'gbound',[0,1000],'v',5); 21. end

22. % maxgen:最大的进化代数,默认为200,一般取值范围为[100,500] 23. % sizepop:种群最大数量,默认为20,一般取值范围为[20,100] 24. % cbound = [cmin,cmax],参数c的变化范围,默认为(0,100] 25. % gbound = [gmin,gmax],参数g的变化范围,默认为[0,1000] 26. % v:SVM Cross Validation参数,默认为5 27. 28. %%

29. MAXGEN = ga_option.maxgen; 30. NIND = ga_option.sizepop; 31. NVAR = 2; 32. PRECI = 20;

33. GGAP = ga_option.ggap; 34. trace = zeros(MAXGEN,2); 35.

36. FieldID = ...

37. [rep([PRECI],[1,NVAR]);[ga_option.cbound(1),ga_option.gbound(1);ga_

option.cbound(2),ga_option.gbound(2)]; ... 38. [1,1;0,0;0,1;1,1]]; 39.

40. Chrom = crtbp(NIND,NVAR*PRECI); 41.

42. gen = 1;

43. v = ga_option.v; 44. BestCVaccuracy = 0; 45. Bestc = 0; 46. Bestg = 0; 47. %%

48. cg = bs2rv(Chrom,FieldID); 49.

50. for nind = 1:NIND

51. cmd = ['-v ',num2str(v),' -c ',num2str(cg(nind,1)),' -g

',num2str(cg(nind,2))];

52. ObjV(nind,1) = svmtrain(train_label,train_data,cmd); 53. end

54. [BestCVaccuracy,I] = max(ObjV); 55. Bestc = cg(I,1); 56. Bestg = cg(I,2); 57. 58. %%

59. while 1

60. % for gen = 1:MAXGEN

61. FitnV = ranking(-ObjV); 62.

63. SelCh = select('sus',Chrom,FitnV,GGAP); 64. SelCh = recombin('xovsp',SelCh,0.7); 65. SelCh = mut(SelCh); 66.

67. cg = bs2rv(SelCh,FieldID); 68. for nind = 1:size(SelCh,1)

69. cmd = ['-v ',num2str(v),' -c ',num2str(cg(nind,1)),' -g

',num2str(cg(nind,2))];

70. ObjVSel(nind,1) = svmtrain(train_label,train_data,cmd); 71. end 72.

73. [Chrom,ObjV] = reins(Chrom,SelCh,1,1,ObjV,ObjVSel); 74.

75. if max(ObjV) <= 50 76. continue; 77. end 78.

79. [NewBestCVaccuracy,I] = max(ObjV); 80. cg_temp = bs2rv(Chrom,FieldID);

81. temp_NewBestCVaccuracy = NewBestCVaccuracy; 82.

83. if NewBestCVaccuracy > BestCVaccuracy 84. BestCVaccuracy = NewBestCVaccuracy; 85. Bestc = cg_temp(I,1); 86. Bestg = cg_temp(I,2); 87. end 88.

89. if abs( NewBestCVaccuracy-BestCVaccuracy ) <= 10^(-2) && ... 90. cg_temp(I,1) < Bestc

91. BestCVaccuracy = NewBestCVaccuracy; 92. Bestc = cg_temp(I,1); 93. Bestg = cg_temp(I,2); 94. end

95.

96. trace(gen,1) = max(ObjV);

97. trace(gen,2) = sum(ObjV)/length(ObjV); 98.

99. gen = gen+1; 100. 101. if gen <= MAXGEN/2 102. continue; 103. end 104. if BestCVaccuracy >=80 && ... 105. ( temp_NewBestCVaccuracy-BestCVaccuracy ) <= 10^(-2) 106. break; 107. end 108. if gen == MAXGEN 109. break; 110. end 111. 112. end 113. gen = gen -1; 114. %% 115. figure; 116. hold on; 117. trace = round(trace*10000)/10000; 118. plot(trace(1:gen,1),'r*-','LineWidth',1.5); 119. plot(trace(1:gen,2),'o-','LineWidth',1.5); 120. legend('最佳适应度','平均适应度',3); 121. xlabel('进化代数','FontSize',12); 122. ylabel('适应度','FontSize',12); 123. axis([0 gen 0 100]); 124. grid on; 125. axis auto; 126. 127. line1 = '适应度曲线Accuracy[GAmethod]'; 128. line2 = ['(终止代数=', ... 129. num2str(gen),',种群数量pop=', ... 130. num2str(NIND),')']; 131. line3 = ['Best c=',num2str(Bestc),' g=',num2str(Bestg), ... 132. ' CVAccuracy=',num2str(BestCVaccuracy),'%']; 133. title({line1;line2;line3},'FontSize',12);

本文来源:https://www.bwwdw.com/article/ai03.html

Top