`
czhsuccess
  • 浏览: 41330 次
社区版块
存档分类
最新评论

java实现adaboost算法

阅读更多

adaboost算法的主要原理是训练若干个弱分类器,根据训练结果赋予它们不同的权值,最后再将这些弱分类器组合起来,形成一个强分类器,adaboost的基本原理在http://wenku.baidu.com/view/49478920aaea998fcc220e98.html###中已经有很详细的描述

这里使用上一篇博客中的感知器算法作为弱分类器,代码如下:

首先是adaboost算法的结果类

/**
 * 
 * @author zhenhua.chen
 * @Description: adboost算法的结果类,包括弱分类器的集合和每个弱分类器的权重
 * @date 2013-3-8 下午3:14:58 
 *
 */
public class AdboostResult {
	private ArrayList<ArrayList<Double>> weakClassifierSet;
	private ArrayList<Double> classifierWeightSet;
	
	public ArrayList<ArrayList<Double>> getWeakClassifierSet() {
		return weakClassifierSet;
	}
	public void setWeakClassifierSet(ArrayList<ArrayList<Double>> weakClassifierSet) {
		this.weakClassifierSet = weakClassifierSet;
	}
	public ArrayList<Double> getClassifierWeightSet() {
		return classifierWeightSet;
	}
	public void setClassifierWeightSet(ArrayList<Double> classifierWeightSet) {
		this.classifierWeightSet = classifierWeightSet;
	}
}

 adaboost算法:

/**
 * http://wenku.baidu.com/view/49478920aaea998fcc220e98.html
 * @author zhenhua.chen
 * @Description: TODO
 * @date 2013-3-8 下午3:09:36 
 *
 */
public class AdaboostAlgorithm {
	private static final int T = 30; // 迭代次数
	PerceptronApproach pa = new PerceptronApproach(); // 弱分类器
	
	/**
	 * 
	* @Title: adaboostClassify 
	* @Description: 通过训练集计算出组合分类器
	* @return AdboostResult
	* @throws
	 */
	public AdboostResult adaboostClassify(ArrayList<ArrayList<Double>> dataSet) {
		AdboostResult res = new AdboostResult();
		
		int dataDimension;
		if(null != dataSet && dataSet.size() > 0) {
			dataDimension = dataSet.get(0).size();
		} else {
			return null;
		}
		
		// 为每条数据的权重赋初值
		ArrayList<Double> dataWeightSet = new ArrayList<Double>();
		for(int i = 0; i < dataSet.size(); i ++) {
			dataWeightSet.add(1.0 / (double)dataSet.size());
		}
		
		// 存储每个弱分类器的权重
		ArrayList<Double> classifierWeightSet = new ArrayList<Double>();
		
		// 存储每个弱分类器
		ArrayList<ArrayList<Double>> weakClassifierSet = new ArrayList<ArrayList<Double>>();
		
		for(int i = 0; i < T; i++) {
			// 计算弱分类器
			ArrayList<Double> sensorWeightVector = pa.getWeightVector(dataSet, dataWeightSet);
			weakClassifierSet.add(sensorWeightVector);
			
			// 计算弱分类器误差
			double error = 0; //分类数
			int rightClassifyNum = 0;
			ArrayList<Double> cllassifyResult = new ArrayList<Double>();
			for(int j = 0; j < dataSet.size(); j++) { 
				double result = 0;
				for(int k = 0; k < dataDimension - 1; k++) {
					result += dataSet.get(j).get(k) * sensorWeightVector.get(k);
					
				}
				result += sensorWeightVector.get(dataDimension - 1);
				if(result < 0) { // 说明预测错误
					error += dataWeightSet.get(j);
					cllassifyResult.add(-1d);
				} else{ 
					cllassifyResult.add(1d);
					rightClassifyNum++;
				}
			}
			System.out.println("总数:" + dataSet.size() + "正确预测数" + rightClassifyNum);
			if(dataSet.size() == rightClassifyNum) {
				classifierWeightSet.clear();
				weakClassifierSet.clear();
				classifierWeightSet.add(1.0);
				weakClassifierSet.add(sensorWeightVector);
				break;
			}
			
			// 更新数据集中每条数据的权重并归一化
			double dataWeightSum = 0;
			for(int j = 0; j < dataSet.size(); j++) {
				dataWeightSet.set(j, dataWeightSet.get(j) * Math.pow(Math.E, (-1) * 0.5 * Math.log((1 - error) / error) * cllassifyResult.get(j))); // 按照http://wenku.baidu.com/view/49478920aaea998fcc220e98.html,更新的权重少除一个常数
				dataWeightSum += dataWeightSet.get(j);
			}
			for(int j = 0; j < dataSet.size(); j++) {
				dataWeightSet.set(j, dataWeightSet.get(j) / dataWeightSum);
			}
			
			
			// 计算次弱分类器的权重
			double currentWeight = (0.5 * Math.log((1 - error) / error));
			classifierWeightSet.add(currentWeight);
			System.out.println("classifier weight: " + currentWeight);
		}
		
		res.setClassifierWeightSet(classifierWeightSet);
		res.setWeakClassifierSet(weakClassifierSet);
		return res;
	}
	
	/**
	 * 
	* @Title: computeResult 
	* @Description: 计算输入数据的类别
	* @return double
	* @throws
	 */
	public int computeResult(ArrayList<Double> data, AdboostResult classifier) {
		double result = 0;
		int dataSize = data.size();
		ArrayList<ArrayList<Double>> weakClassifierSet = classifier.getWeakClassifierSet();
		ArrayList<Double> classifierWeightSet = classifier.getClassifierWeightSet();
		for(int i = 0; i < weakClassifierSet.size(); i++) {
			for(int j = 0; j < dataSize; j++) {
				result += weakClassifierSet.get(i).get(j) * data.get(j) * classifierWeightSet.get(i);
			}
			result += weakClassifierSet.get(i).get(dataSize);
		}
		if(result > 0) {
			return 1;
		} else {
			return -1;
		}
		
	}

 测试类:

public static void main(String[] args) {
		/**
		 * 测试数据,产生两类随机数据一类位于圆内,另一类位于包含小圆的大圆内,成环状
		 * 小圆半径为1,大圆半径为2,公共圆心位于(2, 2)内
		 */
		final int SMALL_CIRCLE_NUM = 24;
		final int RING_NUM = 34;
		
		ArrayList<ArrayList<Double>> dataSet = new ArrayList<ArrayList<Double>>();
		
		// 产生小圆数据
		for(int i = 0; i < SMALL_CIRCLE_NUM; i++) {
			double x = 1 + Math.random() * 2; // 1到3的随机数
			double y = 1 + Math.random() * 2; // 1到3的随机数
			if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内
				ArrayList<Double> smallCircle = new ArrayList<Double>();
				smallCircle.add(x);
				smallCircle.add(y);
				smallCircle.add(1d); // 列别1
				dataSet.add(smallCircle);
			}
		}
		
		// 产生外围环形数据
		for(int i = 0; i < RING_NUM; i++) {
			double x1 = Math.random() * 4;
			double y1 = Math.random() * 4;
			if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) { //说明位于环形区域内
				ArrayList<Double> ring = new ArrayList<Double>();
				ring.add(-x1);
				ring.add(-y1);
				ring.add(-1d); // 列别2
				dataSet.add(ring);
			}
		}
		
		AdaboostAlgorithm algo = new AdaboostAlgorithm();
		AdboostResult result = algo.adaboostClassify(dataSet);
		
		// 产生测试数据
		for(int i = 0; i < 10; i++) {
		
		ArrayList<Double> testData = new ArrayList<Double>();
		
		double x1 = Math.random() * 4;
		double y1 = Math.random() * 4;
		if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) {
			testData.add(x1);
			testData.add(y1);
		}
		
//		double x = 1 + Math.random() * 2; // 1到3的随机数
//		double y = 1 + Math.random() * 2; // 1到3的随机数
//		if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内
//			testData.add(x);
//			testData.add(y);
//		}
		
		algo.computeResult(testData, result);
		System.out.println(algo.computeResult(testData, result));
		}
		
	}

 

分享到:
评论
3 楼 微笑春天 2014-09-23  
楼主 if(result < 0) { // 说明预测错误 
                    error += dataWeightSet.get(j); 
                    cllassifyResult.add(-1d); 
                } else{  
                    cllassifyResult.add(1d); 
                    rightClassifyNum++; 
                } 
这句话中-ld和ld表示啥子意思 我看不懂啊 谢谢指点哈 其他地方我都觉得没有问题 看得懂 我是刚学这个算法的哈
2 楼 czhsuccess 2014-05-08  
reacherxu 写道
楼主  PerceptronApproach 这个类没有提供啊

这个类在http://czhsuccess.iteye.com/blog/1897914中。
1 楼 reacherxu 2013-10-13  
楼主  PerceptronApproach 这个类没有提供啊

相关推荐

Global site tag (gtag.js) - Google Analytics