在内存中执行k-means聚类算法
DeliaPitt
9年前
来自: http://blog.csdn.net/u012965373/article/details/50754381
package unitNine;

import java.util.ArrayList;
import java.util.List;

import org.apache.mahout.clustering.UncommonDistributions;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.clustering.kmeans.KMeansClusterer;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Exercises in-memory k-means clustering on a synthetic two-dimensional point set.
 *
 * <p>Three batches of random points are generated, {@value #NUM_CLUSTERS} of them are
 * picked as initial centroids, and the clusters produced by the last iteration are
 * printed to standard output.
 *
 * <p>NOTE(review): {@code Cluster} and {@code KMeansClusterer} are taken from
 * {@code org.apache.mahout.clustering.kmeans}, because the top-level
 * {@code org.apache.mahout.clustering.Cluster} is an interface and cannot be
 * instantiated with {@code new}. Confirm these names against the Mahout version on the
 * classpath.
 *
 * @author YangXin
 */
public class KMeansExample {

    /** Number of clusters to look for, and therefore the number of initial centroids. */
    private static final int NUM_CLUSTERS = 3;

    /** Fourth argument of {@code clusterPoints}: upper bound on the number of iterations. */
    private static final int MAX_ITERATIONS = 3;

    /** Fifth argument of {@code clusterPoints}: distance threshold used to stop iterating. */
    private static final double CONVERGENCE_THRESHOLD = 0.01;

    /**
     * Generates the sample data, runs the clusterer and prints each final cluster.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Three batches of points: (count, x parameter, y parameter, spread).
        List<Vector> sampleData = new ArrayList<Vector>();
        RandomPointsUtil.generateSamples(sampleData, 400, 1, 1, 3);
        RandomPointsUtil.generateSamples(sampleData, 300, 1, 0, 0.5);
        RandomPointsUtil.generateSamples(sampleData, 300, 0, 2, 0.1);

        // Seed one cluster per randomly chosen point; ids are assigned 0, 1, 2, ...
        List<Vector> randomPoints = RandomPointsUtil.chooseRandomPoints(sampleData, NUM_CLUSTERS);
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        List<Cluster> clusters = new ArrayList<Cluster>();
        int clusterId = 0;
        for (Vector center : randomPoints) {
            clusters.add(new Cluster(center, clusterId++, measure));
        }

        // clusterPoints returns one list of clusters per iteration; only the last is reported.
        List<List<Cluster>> finalClusters = KMeansClusterer.clusterPoints(
                sampleData, clusters, measure, MAX_ITERATIONS, CONVERGENCE_THRESHOLD);
        for (Cluster cluster : finalClusters.get(finalClusters.size() - 1)) {
            System.out.println("Cluster id: " + cluster.getId()
                    + " center: " + cluster.getCenter().asFormatString());
        }
    }
}
package unitNine;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.mahout.clustering.UncommonDistributions;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Helpers for generating random two-dimensional points and for sampling from a point set.
 *
 * @author YangXin
 */
public class RandomPointsUtil {

    /**
     * Appends {@code num} random two-dimensional points to {@code vectors}.
     *
     * <p>Each coordinate is drawn with {@code UncommonDistributions.rNorm}: the x value
     * from {@code rNorm(mx, sd)} and the y value from {@code rNorm(my, sd)}.
     *
     * @param vectors list that receives the generated points; it is modified in place
     * @param num     number of points to append; nothing is added when it is not positive
     * @param mx      first argument passed to {@code rNorm} for the x coordinate
     *                (presumably the mean; verify against the Mahout documentation)
     * @param my      first argument passed to {@code rNorm} for the y coordinate
     * @param sd      second argument passed to {@code rNorm} for both coordinates
     *                (presumably the standard deviation)
     */
    public static void generateSamples(List<Vector> vectors, int num,
                                       double mx, double my, double sd) {
        for (int i = 0; i < num; i++) {
            // x is drawn before y, matching the order of the original array initializer.
            double x = UncommonDistributions.rNorm(mx, sd);
            double y = UncommonDistributions.rNorm(my, sd);
            vectors.add(new DenseVector(new double[] {x, y}));
        }
    }

    /**
     * Picks up to {@code k} points uniformly at random from {@code vectors} in one pass.
     *
     * <p>This is reservoir sampling (Algorithm R): the first {@code k} elements fill the
     * reservoir, and the element at zero-based position {@code i >= k} then replaces a
     * random slot with probability {@code k / (i + 1)}. Every input element therefore has
     * the same chance of being in the result, regardless of its position. Replacing with a
     * fixed probability instead would bias the sample heavily towards the end of the input.
     *
     * @param vectors points to sample from; iterated exactly once
     * @param k       number of points wanted; {@code 0} yields an empty list
     * @return a new list holding {@code min(k, number of input points)} points
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static List<Vector> chooseRandomPoints(Iterable<Vector> vectors, int k) {
        // ArrayList rejects a negative capacity with IllegalArgumentException.
        List<Vector> chosenPoints = new ArrayList<Vector>(k);
        Random random = RandomUtils.getRandom();
        int seen = 0; // number of elements consumed so far
        for (Vector value : vectors) {
            if (seen < k) {
                chosenPoints.add(value);
            } else {
                // slot is uniform over [0, seen]; it lands in the reservoir with chance k/(seen+1).
                int slot = random.nextInt(seen + 1);
                if (slot < k) {
                    chosenPoints.set(slot, value);
                }
            }
            seen++;
        }
        return chosenPoints;
    }
}