GMM算发步骤:
1. 初始化参数,包括Gauss分布个数、均值、协方差;
2. 计算每个节点属于每个分布的概率;
3. 计算每个分布产生每个节点的概率;
4. 更新每个分布的权值,均值和它们的协方差。
基本参数类:
public class Parameter { private ArrayList<ArrayList<Double>> pMiu; // 均值参数k个分布的中心点,每个中心点d维 private ArrayList<Double> pPi; // k个GMM的权值 private ArrayList<ArrayList<ArrayList<Double>>> pSigma; // k类GMM的协方差矩阵,d*d*k public ArrayList<ArrayList<Double>> getpMiu() { return pMiu; } public void setpMiu(ArrayList<ArrayList<Double>> pMiu) { this.pMiu = pMiu; } public ArrayList<Double> getpPi() { return pPi; } public void setpPi(ArrayList<Double> pPi) { this.pPi = pPi; } public ArrayList<ArrayList<ArrayList<Double>>> getpSigma() { return pSigma; } public void setpSigma(ArrayList<ArrayList<ArrayList<Double>>> pSigma) { this.pSigma = pSigma; } }
核心代码如下:
public class GMMAlgorithm { /** * * @Title: GMMCluster * @Description: GMM聚类算法的实现类,返回每条数据的类别(0~k-1) * @return int[] * @throws */ public int[] GMMCluster(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, int dataDimen) { Parameter parameter = iniParameters(dataSet, dataNum, k, dataDimen); double Lpre = -1000000; // 上一次聚类的误差 double threshold = 0.0001; while(true) { ArrayList<ArrayList<Double>> px = computeProbablity(dataSet, pMiu, dataNum, k, dataDimen); double[][] pGama = new double[dataNum][k]; for(int i = 0; i < dataNum; i++) { for(int j = 0; j < k; j++) { pGama[i][j] = px.get(i).get(j) * parameter.getpPi().get(j); } } double[] sumPGama = GMMUtil.matrixSum(pGama, 2); for(int i = 0; i < dataNum; i++) { for(int j = 0; j < k; j++) { pGama[i][j] = pGama[i][j] / sumPGama[i]; } } double[] NK = GMMUtil.matrixSum(pGama, 1); // 第k个高斯生成每个样本的概率的和,所有Nk的总和为N // 更新pMiu double[] NKReciprocal = new double[NK.length]; for(int i = 0; i < NK.length; i++) { NKReciprocal[i] = 1 / NK[i]; } double[][] pMiuTmp = GMMUtil.matrixMultiply(GMMUtil.matrixMultiply(GMMUtil.diag(NKReciprocal), GMMUtil.matrixReverse(pGama)), GMMUtil.toArray(dataSet)); // 更新pPie double[][] pPie = new double[k][1]; for(int i = 0; i < NK.length; i++) { pPie[i][1] = NK[i] / dataNum; } // 更新k个pSigma double[][][] pSigmaTmp = new double[dataDimen][dataDimen][k]; for(int i = 0; i < k; i++) { double[][] xShift = new double[dataNum][dataDimen]; for(int j = 0; j < dataNum; j++) { for(int l = 0; l < dataDimen; l++) { xShift[j][l] = pMiuTmp[i][l]; } } double[] pGamaK = new double[dataNum]; // 第k条pGama值 for(int j = 0; j < dataNum; j++) { pGamaK[j] = pGama[j][i]; } double[][] diagPGamaK = GMMUtil.diag(pGamaK); double[][] pSigmaK = GMMUtil.matrixMultiply(GMMUtil.matrixReverse(xShift), (GMMUtil.matrixMultiply(diagPGamaK, xShift))); for(int j = 0; j < dataDimen; j++) { for(int l = 0; l < dataDimen; l++) { pSigmaTmp[j][l][k] = pSigmaK[j][l] / NK[i]; } } } // 判断是否迭代结束 double[][] a = GMMUtil.matrixMultiply(GMMUtil.toArray(px), pPie); for(int i = 0; i < dataNum; i++) { a[i][0] = Math.log(a[i][0]); } double L = GMMUtil.matrixSum(a, 1)[0]; if(L - Lpre < threshold) { break; } Lpre = L; } return null; } /** * * @Title: computeProbablity * @Description: 计算每个节点(共n个)属于每个分布(k个)的概率 * @return ArrayList<ArrayList<Double>> * @throws */ public ArrayList<ArrayList<Double>> computeProbablity(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, int dataDimen) { double[][] px = new double[dataNum][k]; // 每条数据属于每个分布的概率 int[] type = getTypes(dataSet, pMiu, k, dataNum); // 计算k个分布的协方差矩阵 ArrayList<ArrayList<ArrayList<Double>>> covList = new ArrayList<ArrayList<ArrayList<Double>>>(); for(int i = 0; i < k; i++) { ArrayList<ArrayList<Double>> dataSetK = new ArrayList<ArrayList<Double>>(); for(int j = 0; j < dataNum; j++) { if(type[k] == i) { dataSetK.add(dataSet.get(i)); } } covList.set(i, GMMUtil.computeCov(dataSetK, dataDimen, dataSetK.size())); } // 计算每条数据属于每个分布的概率 for(int i = 0; i < dataNum; i++) { for(int j = 0; j < k; j++) { ArrayList<Double> offset = GMMUtil.matrixMinus(dataSet.get(i), pMiu.get(j)); ArrayList<ArrayList<Double>> invSigma = covList.get(k); double[] tmp = GMMUtil.matrixSum(GMMUtil.matrixMultiply(GMMUtil.toArray1(offset), GMMUtil.toArray(invSigma)), 2); double coef = Math.pow((2 * Math.PI), -(double)dataDimen / 2d) * Math.sqrt(GMMUtil.computeDet(invSigma, invSigma.size())); px[i][j] = coef * Math.pow(Math.E, -0.5 * tmp[0]); } } return GMMUtil.toList(px); } /** * * @Title: iniParameters * @Description: 初始化参数Parameter * @return Parameter * @throws */ public Parameter iniParameters(ArrayList<ArrayList<Double>> dataSet, int dataNum, int k, int dataDimen) { Parameter res = new Parameter(); ArrayList<ArrayList<Double>> pMiu = generateCentroids(dataSet, dataNum, k); res.setpMiu(pMiu); // 计算每个样本节点与每个中心节点的距离,以此为据对样本节点进行分类计数,进而初始化k个分布的权值 ArrayList<Double> pPi = new ArrayList<Double>(); int[] type = getTypes(dataSet, pMiu, k, dataNum); int[] typeNum = new int[k]; for(int i = 0; i < dataNum; i++) { typeNum[type[i]]++; } for(int i = 0; i < k; i++) { pPi.add((double)(typeNum[i]) / (double)(dataNum)); } res.setpPi(pPi); // 计算k个分布的k个协方差 ArrayList<ArrayList<ArrayList<Double>>> pSigma = new ArrayList<ArrayList<ArrayList<Double>>>(); for(int i = 0; i < k; i++) { ArrayList<ArrayList<Double>> tmp = new ArrayList<ArrayList<Double>>(); for(int j = 0; j < dataNum; j++) { if(type[j] == i) { tmp.add(dataSet.get(i)); } } pSigma.add(GMMUtil.computeCov(tmp, dataDimen, dataNum)); } res.setpSigma(pSigma); return res; } /** * * @Title: generateCentroids * @Description: 获取随机的k个中心点 * @return ArrayList<ArrayList<Double>> * @throws */ public ArrayList<ArrayList<Double>> generateCentroids(ArrayList<ArrayList<Double>> data, int dataNum, int k) { ArrayList<ArrayList<Double>> res = null; if(dataNum < k) { return res; } else { res = new ArrayList<ArrayList<Double>>(); List<Integer> random = new ArrayList<Integer>(); // 随机产生不重复的k个数 while(k > 0) { int index = (int)(Math.random() * dataNum); if(!random.contains(index)) { random.add(index); k--; res.add(data.get(index)); } } } return res; } /** * * @Title: getTypes * @Description: 返回每条数据的类别 * @return int[] * @throws */ public int[] getTypes(ArrayList<ArrayList<Double>> dataSet, ArrayList<ArrayList<Double>> pMiu, int k, int dataNum) { int[] type = new int[dataNum]; for(int j = 0; j < dataNum; j++) { double minDistance = GMMUtil.computeDistance(dataSet.get(j), pMiu.get(0)); type[j] = 0; // 0作为该条数据的类别 for(int i = 1; i < k; i++) { if(GMMUtil.computeDistance(dataSet.get(j), pMiu.get(0)) < minDistance) { minDistance = GMMUtil.computeDistance(dataSet.get(j), pMiu.get(0)); type[j] = k; } } } return type; } public static void main(String[] args) { ArrayList<Double> pPi = new ArrayList<Double>(); System.out.println(pPi.get(0)); } }
一些工具类:
public class GMMUtil { /** * * @Title: computeDistance * @Description: 计算任意两个节点间的距离 * @return double * @throws */ public static double computeDistance(ArrayList<Double> d1, ArrayList<Double> d2) { double squareSum = 0; for(int i = 0; i < d1.size() - 1; i++) { squareSum += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return Math.sqrt(squareSum); } /** * * @Title: computeCov * @Description: 计算协方差矩阵 * @return ArrayList<ArrayList<Double>> * @throws */ public static ArrayList<ArrayList<Double>> computeCov(ArrayList<ArrayList<Double>> dataSet, int dataDimen, int dataNum) { ArrayList<ArrayList<Double>> res = new ArrayList<ArrayList<Double>>(); // 计算每一维数据的均值 double[] sum = new double[dataDimen]; for(ArrayList<Double> data : dataSet) { for(int i = 0; i < dataDimen; i++) { sum[i] += data.get(i); } } for(int i = 0; i < dataDimen; i++) { sum[i] = sum[i] / dataNum; } // 计算任意两组数据的协方差 for(int i = 0; i < dataDimen; i++) { ArrayList<Double> tmp = new ArrayList<Double>(); for(int j = 0; j < dataDimen; j++) { double cov = 0; for(ArrayList<Double> data : dataSet) { cov += (data.get(i) - sum[i]) * (data.get(j) - sum[j]); } tmp.add(cov); } res.add(tmp); } return res; } /** * * @Title: computeInv * @Description: 计算矩阵的逆矩阵 * @return ArrayList<ArrayList<Double>> * @throws */ public static double[][] computeInv(ArrayList<ArrayList<Double>> dataSet) { int dataDimen = dataSet.size(); double[][] res = new double[dataDimen][dataDimen]; // 将list转化为array double[][] a = toArray(dataSet); // 计算伴随矩阵 double detA = computeDet(dataSet, dataDimen); // 整个矩阵的行列式 for (int i = 0; i < dataDimen; i++) { for (int j = 0; j < dataDimen; j++) { double num; if ((i + j) % 2 == 0) { num = computeDet(toList(computeAC(a, i + 1, j + 1)), dataDimen - 1); } else { num = -computeDet(toList(computeAC(a, i + 1, j + 1)), dataDimen - 1); } res[j][i] = num / detA; } } return res; } /** * * @Title: computeAC * @Description: 求指定行、列的代数余子式(algebraic complement) * @return double[][] * @throws */ public static double[][] computeAC(double[][] dataSet, int r, int c) { int H = dataSet.length; int V = dataSet[0].length; double[][] newData = new double[H - 1][V - 1]; for (int i = 0; i < newData.length; i++) { if (i < r - 1) { for (int j = 0; j < newData[i].length; j++) { if (j < c - 1) { newData[i][j] = dataSet[i][j]; } else { newData[i][j] = dataSet[i][j + 1]; } } } else { for (int j = 0; j < newData[i].length; j++) { if (j < c - 1) { newData[i][j] = dataSet[i + 1][j]; } else { newData[i][j] = dataSet[i + 1][j + 1]; } } } } return newData; } /** * * @Title: computeDet * @Description: 计算行列式 * @return double * @throws */ public static double computeDet(ArrayList<ArrayList<Double>> dataSet, int dataDimen) { // 将list转化为array double[][] a = toArray(dataSet); if(dataDimen == 2) { return a[0][0] * a[1][1] - a[0][1] * a[1][0]; } double res = 0; for(int i = 0; i < dataDimen; i++) { if(i % 2 == 0) { res += a[0][i] * computeDet(toList(computeAC(toArray(dataSet), 1, i + 1)), dataDimen - 1); } else { res += -a[0][i] * computeDet(toList(computeAC(toArray(dataSet), 1, i + 1)), dataDimen - 1); } } return res; } /** * * @Title: toList * @Description: 将array转化成list * @return ArrayList<ArrayList<Double>> * @throws */ public static ArrayList<ArrayList<Double>> toList(double[][] a) { ArrayList<ArrayList<Double>> res = new ArrayList<ArrayList<Double>>(); for(int i = 0; i < a.length; i++) { ArrayList<Double> tmp = new ArrayList<Double>(); for(int j = 0; j < a[i].length; j++) { tmp.add(a[i][j]); } res.add(tmp); } return res; } public static double[][] matrixMultiply(double[][] a, double[][] b) { double[][] res = new double[a.length][b[0].length]; for(int i = 0; i < a.length; i++) { for(int j = 0; j < b[0].length; j++) { for(int k = 0; k < a[0].length; k++) { res[i][j] += a[i][k] * b[k][j]; } } } return res; } /** * * @Title: dotMatrixMultiply * @Description: 矩阵的点乘,即对应元素相乘 * @return double[][] * @throws */ public static double[][] dotMatrixMultiply (double[][] a, double[][] b) { double[][] res = new double[a.length][a[0].length]; for(int i = 0; i < a.length; i++) { for(int j = 0; j < a[0].length; j++) { res[i][j] = a[i][j] * b[i][j]; } } return res; } /** * * @Title: dotMatrixMultiply * @Description: 矩阵的点除,即对应元素相除 * @return double[][] * @throws */ public static double[][] dotMatrixDivide(double[][] a, double[][] b) { double[][] res = new double[a.length][a[0].length]; for(int i = 0; i < a.length; i++) { for(int j = 0; j < a[0].length; j++) { res[i][j] = a[i][j] / b[i][j]; } } return res; } /** * * @Title: repmat * @Description: 对应matlab的repmat的函数,对矩阵进行横向或纵向的平铺 * @return double[][] * @throws */ public static double[][] repmat(double[][] a, int row, int clo) { double[][] res = new double[a.length * row][a[0].length * clo]; return null; } /** * * @Title: matrixMinux * @Description: 计算集合只差 * @return ArrayList<ArrayList<Double>> * @throws */ public static ArrayList<Double> matrixMinus(ArrayList<Double> a1, ArrayList<Double> a2) { ArrayList<Double> res = new ArrayList<Double>(); for(int i = 0; i < a1.size(); i++) { res.add(a1.get(i) - a2.get(i)); } return res; } /** * * @Title: matrixSum * @Description: 返回矩阵每行之和(mark==2)或每列之和(mark==1) * @return ArrayList<Double> * @throws */ public static double[] matrixSum(double[][] a, int mark) { double res[] = new double[a.length]; if(mark == 1) { // 计算每列之和,返回行向量 res = new double[a[0].length]; for(int i = 0; i < a[0].length; i++) { for(int j = 0; j < a.length; j++) { res[i] += a[j][i]; } } } else if (mark == 2) { // 计算每行之和, 返回列向量 for(int i = 0; i < a.length; i++) { for(int j = 0; j < a[0].length; j++) { res[i] += a[i][j]; } } } return res; } public static double[][] toArray(ArrayList<ArrayList<Double>> a) { int dataDimen = a.size(); double[][] res = new double[dataDimen][dataDimen]; for(int i = 0; i < dataDimen; i++) { for(int j = 0; j < dataDimen; j++) { res[i][j] = a.get(i).get(j); } } return res; } public static double[][] toArray1(ArrayList<Double> a) { int dataDimen = a.size(); double[][] res = new double[1][dataDimen]; for(int i = 0; i < dataDimen; i++) { res[1][i] = a.get(i); } return res; } /** * * @Title: matrixReverse * @Description: 矩阵专制 * @return double[][] * @throws */ public static double[][] matrixReverse(double[][] a) { double[][] res = new double[a[0].length][a.length]; for(int i = 0; i < a.length; i++) { for(int j = 0; j < a[0].length; j++) { res[j][i] = a[i][j]; } } return res; } /** * * @Title: diag * @Description: 向量对角化 * @return double[][] * @throws */ public static double[][] diag(double[] a) { double[][] res = new double[a.length][a.length]; for(int i = 0; i < a.length; i++) { for(int j = 0; j < a.length; j++) { if(i == j) { res[i][j] = a[i]; } } } return res; } }
相关推荐
有关GMM算法的EM实现,里面都是本人在学习GMM算法时候的资料,非常有用
所谓混合高斯模型(GMM)就是指对样本的概率密度分布进行估计,而估计采用的模型(训练模型)是几个高斯模型的加权和(具体是几个要在模型训练前建立好)。每个高斯模型就代表了一个类(一个Cluster)。对样本中的...
利用EM算法实现GMM算法,文件包含GMM模型以及一个简单的2分类问题的实现,课程作业绝对可用。
机器学习C++源码解析-GMM高斯混合模型算法-源码+数据
改文件包中包含EM算法,已经使用GMM算法进行参数估计,并同时示例进行分类训练和预测
matlab 实现GMM——EM算法,自动生产混合高斯分布,GMM算法的示例demo
实现了gmm分类算法,并对参数进行优化,实现了算法的优化
使用EM算法实现GMM,使用了C++进行编程。
提出了一种基于GMM算法的视网膜血管分割新方法
C++实现GMM分类模型的源码,高斯参数自己可以调整
EM GMM 算法
利用GMM算法实现Voice conversion,文件中有样例,实现的大体方法见博客,资源来自CodeOcean
2.内容:基于GMM的图像分割算法matlab仿真+代码操作视频 3.用处:用于GMM图像分割算法学习 4.指向人群:本硕博等教研学习使用 5.运行注意事项: 使用matlab2021a或者更高版本测试,运行里面的Runme.m文件,不要...
图像处理gmm算法的c++实现。希望对大家有帮助。
这个文档详细介绍了GMM算法,包括公式的推倒,非常适合初学者
高斯混合模型,GMM,分类器,MLE算法公式推导
GMM及Kmeans算法实现,包含简单的测试程序,可直接在Linux下编译
高斯混合模型GMM与EM算法的matlab实现,用户可直接运行代码,观看结果,欢迎下载,进行进一步讨论
基于python的高斯混合模型(GMM 聚类)的 EM 算法实现
GMM算法是混合高斯模型,其求解过程需要不断迭代,本程序利用EM算法进行了仿真实现,可以加深对GMM的理解。