结合PageRank算法用Java实现文本相似度
jopen
10年前
目标
尝试了一下把PageRank算法结合了文本相似度计算。直觉上是想把一个list里,和大家都比较靠拢的文本可能最后的PageRank值会比较大。因为如 果最后计算的PageRank值大,说明有比较多的文本和他的相似度值比较高,或者有更多的文本向他靠拢。这样是不是就可以得到一些相对核心的文本,或者 相对代表性的文本?如果是要在整堆文本里切分一些关键的词做token,那么每个token在每份文本里的权重就可以不一样,那么是否就可以得到比较核心 的token,来给这些文本打标签?当然,分词切词的时候都是要用工具过滤掉stopword的。
我也只是想尝试一下这个想法,就简单实现了整个过程。可能实现上还有问题。我的结果是最后大家的PageRank值都非常接近。如:
5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638
代码实现
文本之间的相似度计算用的是余弦距离,先哈希过。下面是计算两个List<String>的余弦距离代码:
package dcd.academic.recommend; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import dcd.academic.util.StdOutUtil; public class CosineDis { public static double getSimilarity(ArrayList<String> doc1, ArrayList<String> doc2) { if (doc1 != null && doc1.size() > 0 && doc2 != null && doc2.size() > 0) { Map<Long, int[]> AlgorithmMap = new HashMap<Long, int[]>(); for (int i = 0; i < doc1.size(); i++) { String d1 = doc1.get(i); long sIndex = hashId(d1); int[] fq = AlgorithmMap.get(sIndex); if (fq != null) { fq[0]++; } else { fq = new int[2]; fq[0] = 1; fq[1] = 0; AlgorithmMap.put(sIndex, fq); } } for (int i = 0; i < doc2.size(); i++) { String d2 = doc2.get(i); long sIndex = hashId(d2); int[] fq = AlgorithmMap.get(sIndex); if (fq != null) { fq[1]++; } else { fq = new int[2]; fq[0] = 0; fq[1] = 1; AlgorithmMap.put(sIndex, fq); } } Iterator<Long> iterator = AlgorithmMap.keySet().iterator(); double sqdoc1 = 0; double sqdoc2 = 0; double denominator = 0; while (iterator.hasNext()) { int[] c = AlgorithmMap.get(iterator.next()); denominator += c[0] * c[1]; sqdoc1 += c[0] * c[0]; sqdoc2 += c[1] * c[1]; } return denominator / Math.sqrt(sqdoc1 * sqdoc2); } else { return 0; } } public static long hashId(String s) { long seed = 131; // 31 131 1313 13131 131313 etc.. BKDRHash long hash = 0; for (int i = 0; i < s.length(); i++) { hash = (hash * seed) + s.charAt(i); } return hash; } public static void main(String[] args) { ArrayList<String> t1 = new ArrayList<String>(); ArrayList<String> t2 = new ArrayList<String>(); t1.add("sa"); t1.add("dfg"); t1.add("df"); t2.add("gfd"); t2.add("sa"); StdOutUtil.out(getSimilarity(t1, t2)); } }
利用上面这个类,根据文本之间的相似度,为每份文本计算得到一个向量(最后要归一一下),用来初始化PageRank的起始矩阵。我用的数据是我 solr里的论文标题+摘要的文本,我是通过SolrjHelper这个类去取得了一个List<String>。你想替换的话把这部分换成 自己想测试的String列就可以了。下面是读取数据,生成向量给PageRank类的代码:
package dcd.academic.recommend; import java.io.IOException; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.List; import dcd.academic.mongodb.MyMongoClient; import dcd.academic.solrj.SolrjHelper; import dcd.academic.util.StdOutUtil; import dcd.academic.util.StringUtil; import com.mongodb.BasicDBList; import com.mongodb.BasicDBObject; import com.mongodb.DBCollection; import com.mongodb.DBCursor; import com.mongodb.DBObject; public class BtwPublication { public static final int NUM = 20; public static void main(String[] args) throws IOException{ BtwPublication bp = new BtwPublication(); //bp.updatePublicationForComma(); PageRank pageRank = new PageRank(bp.getPagerankS("random")); pageRank.doPagerank(); } public double getDist(String pub1, String pub2) throws IOException { if (pub1 != null && pub2 != null) { ArrayList<String> doc1 = StringUtil.getTokens(pub1); ArrayList<String> doc2 = StringUtil.getTokens(pub2); return CosineDis.getSimilarity(doc1, doc2); } else { return 0; } } // public List<Map<String, String>> getPubs(String name) { // // } public List<List<Double>> getPagerankS(String text) throws IOException { SolrjHelper helper = new SolrjHelper(1); List<String> pubs = helper.getPubsByTitle(text, 0, NUM); List<List<Double>> s = new ArrayList<List<Double>>(); for (String pub : pubs) { List<Double> tmp_row = new ArrayList<Double>(); double total = 0.0; for (String other : pubs) { if (!pub.equals(other)) { double tmp = getDist(pub, other); tmp_row.add(tmp); total += tmp; } else { tmp_row.add(0.0); } } s.add(getNormalizedRow(tmp_row, total)); } return s; } public List<Double> getNormalizedRow(List<Double> row, double d) { List<Double> res = new ArrayList<Double>(); for (int i = 0; i < row.size(); i ++) { res.add(row.get(i) / d); } StdOutUtil.out(res.toString()); return res; } }
package dcd.academic.recommend; import java.util.ArrayList; import java.util.List; import java.util.Random; import dcd.academic.util.StdOutUtil; public class PageRank { private static final double ALPHA = 0.85; private static final double DISTANCE = 0.0000001; private static final double MUL = 10; public static int SIZE; public static List<List<Double>> s; PageRank(List<List<Double>> s) { this.SIZE = s.get(0).size(); this.s = s; } public static void doPagerank() { List<Double> q = new ArrayList<Double>(); for (int i = 0; i < SIZE; i ++) { q.add(new Random().nextDouble()*MUL); } System.out.println("初始的向量q为:"); printVec(q); System.out.println("初始的矩阵G为:"); printMatrix(getG(ALPHA)); List<Double> pageRank = calPageRank(q, ALPHA); System.out.println("PageRank为:"); printVec(pageRank); System.out.println(); } /** * 打印输出一个矩阵 * * @param m */ public static void printMatrix(List<List<Double>> m) { for (int i = 0; i < m.size(); i++) { for (int j = 0; j < m.get(i).size(); j++) { System.out.print(m.get(i).get(j) + ", "); } System.out.println(); } } /** * 打印输出一个向量 * * @param v */ public static void printVec(List<Double> v) { for (int i = 0; i < v.size(); i++) { System.out.print(v.get(i) + ", "); } System.out.println(); } /** * 获得一个初始的随机向量q * * @param n * 向量q的维数 * @return 一个随机的向量q,每一维是0-5之间的随机数 */ public static List<Double> getInitQ(int n) { Random random = new Random(); List<Double> q = new ArrayList<Double>(); for (int i = 0; i < n; i++) { q.add(new Double(5 * random.nextDouble())); } return q; } /** * 计算两个向量的距离 * * @param q1 * 第一个向量 * @param q2 * 第二个向量 * @return 它们的距离 */ public static double calDistance(List<Double> q1, List<Double> q2) { double sum = 0; if (q1.size() != q2.size()) { return -1; } for (int i = 0; i < q1.size(); i++) { sum += Math.pow(q1.get(i).doubleValue() - q2.get(i).doubleValue(), 2); } return Math.sqrt(sum); } /** * 计算pagerank * * @param q1 * 初始向量 * @param a * alpha的值 * @return pagerank的结果 */ public static List<Double> calPageRank(List<Double> q1, double a) { List<List<Double>> g = getG(a); List<Double> q = null; while (true) { q = vectorMulMatrix(g, q1); double dis = calDistance(q, q1); System.out.println(dis); if (dis <= DISTANCE) { System.out.println("q1:"); printVec(q1); System.out.println("q:"); printVec(q); break; } q1 = q; } return q; } /** * 计算获得初始的G矩阵 * * @param a * 为alpha的值,0.85 * @return 初始矩阵G */ public static List<List<Double>> getG(double a) { List<List<Double>> aS = numberMulMatrix(s, a); List<List<Double>> nU = numberMulMatrix(getU(), (1 - a) / SIZE); List<List<Double>> g = addMatrix(aS, nU); return g; } /** * 计算一个矩阵乘以一个向量 * * @param m * 一个矩阵 * @param v * 一个向量 * @return 返回一个新的向量 */ public static List<Double> vectorMulMatrix(List<List<Double>> m, List<Double> v) { if (m == null || v == null || m.size() <= 0 || m.get(0).size() != v.size()) { return null; } List<Double> list = new ArrayList<Double>(); for (int i = 0; i < m.size(); i++) { double sum = 0; for (int j = 0; j < m.get(i).size(); j++) { double temp = m.get(i).get(j).doubleValue() * v.get(j).doubleValue(); sum += temp; } list.add(sum); } return list; } /** * 计算两个矩阵的和 * * @param list1 * 第一个矩阵 * @param list2 * 第二个矩阵 * @return 两个矩阵的和 */ public static List<List<Double>> addMatrix(List<List<Double>> list1, List<List<Double>> list2) { List<List<Double>> list = new ArrayList<List<Double>>(); if (list1.size() != list2.size() || list1.size() <= 0 || list2.size() <= 0) { return null; } for (int i = 0; i < list1.size(); i++) { list.add(new ArrayList<Double>()); for (int j = 0; j < list1.get(i).size(); j++) { double temp = list1.get(i).get(j).doubleValue() + list2.get(i).get(j).doubleValue(); list.get(i).add(new Double(temp)); } } return list; } /** * 计算一个数乘以矩阵 * * @param s * 矩阵s * @param a * double类型的数 * @return 一个新的矩阵 */ public static List<List<Double>> numberMulMatrix(List<List<Double>> s, double a) { List<List<Double>> list = new ArrayList<List<Double>>(); for (int i = 0; i < s.size(); i++) { list.add(new ArrayList<Double>()); for (int j = 0; j < s.get(i).size(); j++) { double temp = a * s.get(i).get(j).doubleValue(); list.get(i).add(new Double(temp)); } } return list; } /** * 初始化U矩阵,全1 * * @return U */ public static List<List<Double>> getU() { List<Double> row = new ArrayList<Double>(); for (int i = 0; i < SIZE; i ++) { row.add(new Double(1)); } List<List<Double>> s = new ArrayList<List<Double>>(); for (int j = 0; j < SIZE; j ++) { s.add(row); } return s; } }
下面是我一次实验结果的数据,我设置了五分文本,这样看起来比较短:
[0.0, 0.09968643574761415, 0.2601130421632277, 0.31094706119099713, 0.32925346089816093] [0.1315115598803241, 0.0, 0.23650307622882252, 0.2827229880685279, 0.34926237582232544] [0.13521235055030142, 0.09318868159350341, 0.0, 0.3996835314966943, 0.3719154363595009] [0.1389453620825689, 0.0957614822411479, 0.34357346750710194, 0.0, 0.4217196881691813] [0.14612484353723476, 0.11749453142051332, 0.31752920814285096, 0.4188514168994011, 0.0] 初始的向量q为: 8.007763265073303, 3.1232982446687387, 1.1722525763669134, 5.906625842576609, 9.019220483814852, 初始的矩阵G为: 0.030000000000000006, 0.11473347038547205, 0.2510960858387436, 0.2943050020123476, 0.30986544176343683, 0.14178482589827548, 0.030000000000000006, 0.23102761479449913, 0.2703145398582487, 0.3268730194489766, 0.1449304979677562, 0.10921037935447789, 0.030000000000000006, 0.36973100177219015, 0.3461281209055758, 0.14810355777018358, 0.11139725990497573, 0.3220374473810367, 0.030000000000000006, 0.38846173494380415, 0.15420611700664955, 0.12987035170743633, 0.29989982692142336, 0.38602370436449096, 0.030000000000000006, 8.215210604296416 2.1786836521210637 0.6343362349619535 0.19024536572818584 0.05836227176176904 0.018354791916908083 0.0059297512567364945 0.0019669982458251243 6.679891158687752E-4 2.312017647733628E-4 8.117199104238135E-5 2.8787511843006215E-5 1.0279598478348542E-5 3.6872987746593366E-6 1.3264993458811192E-6 4.780938295685138E-7 1.7251588746973008E-7 6.229666266632005E-8 q1: 5.62674207030434, 5.626742074589739, 5.626742063777632, 5.626742101012727, 5.626742037269133, q: 5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638, PageRank为: 5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638,