Apriori算法实现

jopen 9年前

Apriori算法原理:http://blog.csdn.net/kingzone_2008/article/details/8183768


import java.util.HashMap;  import java.util.HashSet;  import java.util.Iterator;  import java.util.Map;  import java.util.Set;  import java.util.TreeMap;  /**  * <B>关联规则挖掘:Apriori算法</B>  *   * <P>按照Apriori算法的基本思想来实现  *   * @author king  * @since 2013/06/27  *   */  public class Apriori {   private Map<Integer, Set<String>> txDatabase; // 事务数据库   private Float minSup; // 最小支持度   private Float minConf; // 最小置信度   private Integer txDatabaseCount; // 事务数据库中的事务数      private Map<Integer, Set<Set<String>>> freqItemSet; // 频繁项集集合   private Map<Set<String>, Set<Set<String>>> assiciationRules; // 频繁关联规则集合      public Apriori(       Map<Integer, Set<String>> txDatabase,        Float minSup,        Float minConf) {      this.txDatabase = txDatabase;      this.minSup = minSup;      this.minConf = minConf;      this.txDatabaseCount = this.txDatabase.size();      freqItemSet = new TreeMap<Integer, Set<Set<String>>>();      assiciationRules = new HashMap<Set<String>, Set<Set<String>>>();   }      /**   * 扫描事务数据库,计算频繁1-项集   * @return   */   public Map<Set<String>, Float> getFreq1ItemSet() {      Map<Set<String>, Float> freq1ItemSetMap = new HashMap<Set<String>, Float>();      Map<Set<String>, Integer> candFreq1ItemSet = this.getCandFreq1ItemSet();      Iterator<Map.Entry<Set<String>, Integer>> it = candFreq1ItemSet.entrySet().iterator();      while(it.hasNext()) {       Map.Entry<Set<String>, Integer> entry = it.next();       // 计算支持度       Float supported = new Float(entry.getValue().toString())/new Float(txDatabaseCount);       if(supported>=minSup) {        freq1ItemSetMap.put(entry.getKey(), supported);       }      }      return freq1ItemSetMap;   }      /**   * 计算候选频繁1-项集   * @return   */   public Map<Set<String>, Integer> getCandFreq1ItemSet() {      Map<Set<String>, Integer> candFreq1ItemSetMap = new HashMap<Set<String>, Integer>();      Iterator<Map.Entry<Integer, Set<String>>> it = txDatabase.entrySet().iterator();      // 统计支持数,生成候选频繁1-项集      while(it.hasNext()) {       Map.Entry<Integer, Set<String>> entry = it.next();       Set<String> itemSet = entry.getValue();       for(String item : itemSet) {        Set<String> key = new HashSet<String>();        key.add(item.trim());        if(!candFreq1ItemSetMap.containsKey(key)) {         Integer value = 1;         candFreq1ItemSetMap.put(key, value);        }        else {         Integer value = 1+candFreq1ItemSetMap.get(key);         candFreq1ItemSetMap.put(key, value);        }       }      }      return candFreq1ItemSetMap;   }      /**   * 根据频繁(k-1)-项集计算候选频繁k-项集   *    * @param m 其中m=k-1   * @param freqMItemSet 频繁(k-1)-项集   * @return   */   public Set<Set<String>> aprioriGen(int m, Set<Set<String>> freqMItemSet) {      Set<Set<String>> candFreqKItemSet = new HashSet<Set<String>>();      Iterator<Set<String>> it = freqMItemSet.iterator();      Set<String> originalItemSet = null;      while(it.hasNext()) {       originalItemSet = it.next();       Iterator<Set<String>> itr = this.getIterator(originalItemSet, freqMItemSet);       while(itr.hasNext()) {        Set<String> identicalSet = new HashSet<String>(); // 两个项集相同元素的集合(集合的交运算)            identicalSet.addAll(originalItemSet);         Set<String> set = itr.next();         identicalSet.retainAll(set); // identicalSet中剩下的元素是identicalSet与set集合中公有的元素        if(identicalSet.size() == m-1) { // (k-1)-项集中k-2个相同         Set<String> differentSet = new HashSet<String>(); // 两个项集不同元素的集合(集合的差运算)         differentSet.addAll(originalItemSet);         differentSet.removeAll(set); // 因为有k-2个相同,则differentSet中一定剩下一个元素,即differentSet大小为1         differentSet.addAll(set); // 构造候选k-项集的一个元素(set大小为k-1,differentSet大小为k)         if(!this.has_infrequent_subset(differentSet, freqMItemSet))             candFreqKItemSet.add(differentSet); // 加入候选k-项集集合        }       }      }      return candFreqKItemSet;   }      /**    * 使用先验知识,剪枝。若候选k项集中存在k-1项子集不是频繁k-1项集,则删除该候选k项集    * @param candKItemSet    * @param freqMItemSet    * @return    */   private boolean has_infrequent_subset(Set<String> candKItemSet, Set<Set<String>> freqMItemSet) {    Set<String> tempSet = new HashSet<String>();    tempSet.addAll(candKItemSet);    Iterator<String> itItem = candKItemSet.iterator();    while(itItem.hasNext()) {     String item = itItem.next();     tempSet.remove(item);// 该候选去掉一项后变为k-1项集     if(!freqMItemSet.contains(tempSet))// 判断k-1项集是否是频繁项集      return true;     tempSet.add(item);// 恢复    }    return false;   }      /**   * 根据一个频繁k-项集的元素(集合),获取到频繁k-项集的从该元素开始的迭代器实例   * @param itemSet   * @param freqKItemSet 频繁k-项集   * @return   */   private Iterator<Set<String>> getIterator(Set<String> itemSet, Set<Set<String>> freqKItemSet) {      Iterator<Set<String>> it = freqKItemSet.iterator();      while(it.hasNext()) {       if(itemSet.equals(it.next())) {        break;       }      }      return it;   }      /**   * 根据频繁(k-1)-项集,调用aprioriGen方法,计算频繁k-项集   *    * @param k    * @param freqMItemSet 频繁(k-1)-项集   * @return   */   public Map<Set<String>, Float> getFreqKItemSet(int k, Set<Set<String>> freqMItemSet) {      Map<Set<String>, Integer> candFreqKItemSetMap = new HashMap<Set<String>, Integer>();      // 调用aprioriGen方法,得到候选频繁k-项集      Set<Set<String>> candFreqKItemSet = this.aprioriGen(k-1, freqMItemSet);           // 扫描事务数据库      Iterator<Map.Entry<Integer, Set<String>>> it = txDatabase.entrySet().iterator();      // 统计支持数      while(it.hasNext()) {       Map.Entry<Integer, Set<String>> entry = it.next();       Iterator<Set<String>> kit = candFreqKItemSet.iterator();       while(kit.hasNext()) {        Set<String> kSet = kit.next();        Set<String> set = new HashSet<String>();        set.addAll(kSet);        set.removeAll(entry.getValue()); // 候选频繁k-项集与事务数据库中元素做差运算        if(set.isEmpty()) { // 如果拷贝set为空,支持数加1         if(candFreqKItemSetMap.get(kSet) == null) {          Integer value = 1;          candFreqKItemSetMap.put(kSet, value);         }         else {          Integer value = 1+candFreqKItemSetMap.get(kSet);          candFreqKItemSetMap.put(kSet, value);         }        }       }      }        // 计算支持度,生成频繁k-项集,并返回      return support(candFreqKItemSetMap);   }      /**   * 根据候选频繁k-项集,得到频繁k-项集   *    * @param candFreqKItemSetMap 候选k项集(包含支持计数)   * @return freqKItemSetMap 频繁k项集及其支持度(比例)   */   public Map<Set<String>, Float> support(Map<Set<String>, Integer> candFreqKItemSetMap) {      Map<Set<String>, Float> freqKItemSetMap = new HashMap<Set<String>, Float>();      Iterator<Map.Entry<Set<String>, Integer>> it = candFreqKItemSetMap.entrySet().iterator();      while(it.hasNext()) {       Map.Entry<Set<String>, Integer> entry = it.next();       // 计算支持度       Float supportRate = new Float(entry.getValue().toString())/new Float(txDatabaseCount);       if(supportRate<minSup) { // 如果不满足最小支持度,删除        it.remove();       }       else {        freqKItemSetMap.put(entry.getKey(), supportRate);       }      }      return freqKItemSetMap;   }      /**   * 挖掘全部频繁项集   */   public void mineFreqItemSet() {      // 计算频繁1-项集      Set<Set<String>> freqKItemSet = this.getFreq1ItemSet().keySet();      freqItemSet.put(1, freqKItemSet);      // 计算频繁k-项集(k>1)      int k = 2;      while(true) {       Map<Set<String>, Float> freqKItemSetMap = this.getFreqKItemSet(k, freqKItemSet);       if(!freqKItemSetMap.isEmpty()) {        this.freqItemSet.put(k, freqKItemSetMap.keySet());        freqKItemSet = freqKItemSetMap.keySet();       }       else {        break;       }       k++;      }   }      /**   * <P>挖掘频繁关联规则   * <P>首先挖掘出全部的频繁项集,在此基础上挖掘频繁关联规则   */   public void mineAssociationRules() {      freqItemSet.remove(1); // 删除频繁1-项集      Iterator<Map.Entry<Integer, Set<Set<String>>>> it = freqItemSet.entrySet().iterator();      while(it.hasNext()) {       Map.Entry<Integer, Set<Set<String>>> entry = it.next();       for(Set<String> itemSet : entry.getValue()) {        // 对每个频繁项集进行关联规则的挖掘        mine(itemSet);       }      }   }      /**   * 对从频繁项集集合freqItemSet中每迭代出一个频繁项集元素,执行一次关联规则的挖掘   * @param itemSet 频繁项集集合freqItemSet中的一个频繁项集元素   */   public void mine(Set<String> itemSet) {        int n = itemSet.size()/2; // 根据集合的对称性,只需要得到一半的真子集      for(int i=1; i<=n; i++) {       // 得到频繁项集元素itemSet的作为条件的真子集集合       Set<Set<String>> properSubset = ProperSubsetCombination.getProperSubset(i, itemSet);       // 对条件的真子集集合中的每个条件项集,获取到对应的结论项集,从而进一步挖掘频繁关联规则       for(Set<String> conditionSet : properSubset) {        Set<String> conclusionSet = new HashSet<String>();        conclusionSet.addAll(itemSet);        conclusionSet.removeAll(conditionSet); // 删除条件中存在的频繁项        confide(conditionSet, conclusionSet); // 调用计算置信度的方法,并且挖掘出频繁关联规则       }      }   }      /**   * 对得到的一个条件项集和对应的结论项集,计算该关联规则的支持计数,从而根据置信度判断是否是频繁关联规则   * @param conditionSet 条件频繁项集   * @param conclusionSet 结论频繁项集   */   public void confide(Set<String> conditionSet, Set<String> conclusionSet) {      // 扫描事务数据库      Iterator<Map.Entry<Integer, Set<String>>> it = txDatabase.entrySet().iterator();      // 统计关联规则支持计数      int conditionToConclusionCnt = 0; // 关联规则(条件项集推出结论项集)计数      int conclusionToConditionCnt = 0; // 关联规则(结论项集推出条件项集)计数      int supCnt = 0; // 关联规则支持计数      while(it.hasNext()) {       Map.Entry<Integer, Set<String>> entry = it.next();       Set<String> txSet = entry.getValue();       Set<String> set1 = new HashSet<String>();       Set<String> set2 = new HashSet<String>();       set1.addAll(conditionSet);             set1.removeAll(txSet); // 集合差运算:set-txSet       if(set1.isEmpty()) { // 如果set为空,说明事务数据库中包含条件频繁项conditionSet        // 计数        conditionToConclusionCnt++;       }       set2.addAll(conclusionSet);       set2.removeAll(txSet); // 集合差运算:set-txSet       if(set2.isEmpty()) { // 如果set为空,说明事务数据库中包含结论频繁项conclusionSet        // 计数        conclusionToConditionCnt++;              }       if(set1.isEmpty() && set2.isEmpty()) {        supCnt++;       }      }      // 计算置信度      Float conditionToConclusionConf = new Float(supCnt)/new Float(conditionToConclusionCnt);      if(conditionToConclusionConf>=minConf) {       if(assiciationRules.get(conditionSet) == null) { // 如果不存在以该条件频繁项集为条件的关联规则        Set<Set<String>> conclusionSetSet = new HashSet<Set<String>>();        conclusionSetSet.add(conclusionSet);        assiciationRules.put(conditionSet, conclusionSetSet);       }       else {        assiciationRules.get(conditionSet).add(conclusionSet);       }      }      Float conclusionToConditionConf = new Float(supCnt)/new Float(conclusionToConditionCnt);      if(conclusionToConditionConf>=minConf) {       if(assiciationRules.get(conclusionSet) == null) { // 如果不存在以该结论频繁项集为条件的关联规则        Set<Set<String>> conclusionSetSet = new HashSet<Set<String>>();        conclusionSetSet.add(conditionSet);        assiciationRules.put(conclusionSet, conclusionSetSet);       }       else {        assiciationRules.get(conclusionSet).add(conditionSet);       }      }   }   /**   * 经过挖掘得到的频繁项集Map   *    * @return 挖掘得到的频繁项集集合   */   public Map<Integer, Set<Set<String>>> getFreqItemSet() {      return freqItemSet;   }   /**   * 获取挖掘到的全部的频繁关联规则的集合   * @return 频繁关联规则集合   */   public Map<Set<String>, Set<Set<String>>> getAssiciationRules() {      return assiciationRules;   }  }

其中ProperSubsetCombination类,是用于生成真子集的辅助类:

import java.util.BitSet;  import java.util.HashSet;  import java.util.Set;  /**  * <B>求频繁项集元素(集合)的非空真子集集合</B>  * <P>从一个集合(大小为n)中取出m(m属于2~n/2的闭区间)个元素的组合实现类,获取非空真子集的集合  *   * @author king  * @date 2013/06/27   *   */  public class ProperSubsetCombination {   private static String[] array;   private static BitSet startBitSet; // 比特集合起始状态   private static BitSet endBitSet; // 比特集合终止状态,用来控制循环   private static Set<Set<String>> properSubset; // 真子集集合   /**   * 计算得到一个集合的非空真子集集合   *    * @param n 真子集的大小   * @param itemSet 一个频繁项集元素   * @return 非空真子集集合   */   public static Set<Set<String>> getProperSubset(int n, Set<String> itemSet) {      String[] array = new String[itemSet.size()];      ProperSubsetCombination.array = itemSet.toArray(array);      properSubset = new HashSet<Set<String>>();      startBitSet = new BitSet();      endBitSet = new BitSet();      // 初始化startBitSet,左侧占满1      for (int i=0; i<n; i++) {       startBitSet.set(i, true);      }      // 初始化endBit,右侧占满1      for (int i=array.length-1; i>=array.length-n; i--) {       endBitSet.set(i, true);      }           // 根据起始startBitSet,将一个组合加入到真子集集合中      get(startBitSet);              while(!startBitSet.equals(endBitSet)) {        int zeroCount = 0; // 统计遇到10后,左边0的个数        int oneCount = 0; // 统计遇到10后,左边1的个数        int pos = 0; // 记录当前遇到10的索引位置               // 遍历startBitSet来确定10出现的位置        for (int i=0; i<array.length; i++) {          if (!startBitSet.get(i)) {           zeroCount++;          }          if (startBitSet.get(i) && !startBitSet.get(i+1)) {           pos = i;           oneCount = i - zeroCount;           // 将10变为01           startBitSet.set(i, false);           startBitSet.set(i+1, true);           break;          }        }        // 将遇到10后,左侧的1全部移动到最左侧        int counter = Math.min(zeroCount, oneCount);        int startIndex = 0;        int endIndex = 0;        if(pos>1 && counter>0) {          pos--;          endIndex = pos;          for (int i=0; i<counter; i++) {           startBitSet.set(startIndex, true);           startBitSet.set(endIndex, false);           startIndex = i+1;           pos--;           if(pos>0) {            endIndex = pos;           }          }        }        get(startBitSet);      }        return properSubset;   }      /**   * 根据一次移位操作得到的startBitSet,得到一个真子集   * @param bitSet   */   private static void get(BitSet bitSet) {      Set<String> set = new HashSet<String>();      for(int i=0; i<array.length; i++) {       if(bitSet.get(i)) {        set.add(array[i]);       }      }      properSubset.add(set);   }  }



测试类如下:

import java.io.BufferedReader;  import java.io.File;  import java.io.FileNotFoundException;  import java.io.FileReader;  import java.io.IOException;  import java.util.HashMap;  import java.util.HashSet;  import java.util.Map;  import java.util.Set;  import java.util.TreeSet;    import junit.framework.TestCase;  /**  * <B>Apriori算法测试类</B>  *   * @author king  * @date 2013/07/28   */  public class AprioriTest extends TestCase {     private Apriori apriori;   private Map<Integer, Set<String>> txDatabase;   private Float minSup = new Float("0.50");   private Float minConf = new Float("0.70");      public static void main(String []args) throws Exception {    AprioriTest at = new AprioriTest();    at.setUp();        long from = System.currentTimeMillis();    at.testGetFreqItemSet();    long to = System.currentTimeMillis();    System.out.println("耗时:" + (to-from));      }      @Override   protected void setUp() throws Exception {  //     create(); // 构造事务数据库    this.buildData(Integer.MAX_VALUE, "f_faqk_.dat");       apriori = new Apriori(txDatabase, minSup, minConf);   }      /**   * 构造模拟事务数据库txDatabase   */   public void create() {      txDatabase = new HashMap<Integer, Set<String>>();      Set<String> set1 = new TreeSet<String>();      set1.add("A");      set1.add("B");      set1.add("C");      set1.add("E");      txDatabase.put(1, set1);      Set<String> set2 = new TreeSet<String>();      set2.add("A");      set2.add("B");      set2.add("C");      txDatabase.put(2, set2);      Set<String> set3 = new TreeSet<String>();      set3.add("C");      set3.add("D");      txDatabase.put(3, set3);      Set<String> set4 = new TreeSet<String>();      set4.add("A");      set4.add("B");      set4.add("E");      txDatabase.put(4, set4);   }      /**    * 构造数据集    * @param fileName 存储事务数据的文件名    * @param totalcount 获取的事务数    */   public void buildData(int totalCount, String...fileName) {    txDatabase = new HashMap<Integer, Set<String>>();    if(fileName.length !=0){     File file = new File(fileName[0]);     int count = 0;     try {      BufferedReader reader = new BufferedReader(new FileReader(file));      String line;      while( (line = reader.readLine()) != null){       String []arr = line.split(" ");       Set<String> set = new HashSet<String>();       for(String s : arr)        set.add(s);       count++;       this.txDatabase.put(count, set);              if(count >= totalCount) return;      }     } catch (FileNotFoundException e) {      e.printStackTrace();     } catch (IOException e) {      e.printStackTrace();     }    }else{    }   }      /**   * 测试挖掘频繁1-项集   */   public void testFreq1ItemSet() {      System.out.println("挖掘频繁1-项集 : " + apriori.getFreq1ItemSet());   }      /**   * 测试aprioriGen方法,生成候选频繁项集   */   public void testAprioriGen() {      System.out.println(        "候选频繁2-项集 : " +        this.apriori.aprioriGen(1, this.apriori.getFreq1ItemSet().keySet())        );   }      /**   * 测试挖掘频繁2-项集   */   public void testGetFreq2ItemSet() {      System.out.println(        "挖掘频繁2-项集 :" +        this.apriori.getFreqKItemSet(2, this.apriori.getFreq1ItemSet().keySet())        );   }      /**   * 测试挖掘频繁3-项集   */   public void testGetFreq3ItemSet() {      System.out.println(        "挖掘频繁3-项集 :" +        this.apriori.getFreqKItemSet(          3,           this.apriori.getFreqKItemSet(2, this.apriori.getFreq1ItemSet().keySet()).keySet()          )        );   }      /**   * 测试挖掘全部频繁项集   */   public void testGetFreqItemSet() {      this.apriori.mineFreqItemSet(); // 挖掘频繁项集      System.out.println("挖掘频繁项集 :" + this.apriori.getFreqItemSet());   }      /**   * 测试挖掘全部频繁关联规则   */   public void testMineAssociationRules() {      this.apriori.mineFreqItemSet(); // 挖掘频繁项集      this.apriori.mineAssociationRules();      System.out.println("挖掘频繁关联规则 :" + this.apriori.getAssiciationRules());   }  }

参考:http://hi.baidu.com/shirdrn/item/5b74a313d55256711009b5d8

在此基础上添加了has_infrequent_subset方法,此方法使用先验知识进行剪枝,是典型Apriori算法必备的。

来自: http://blog.csdn.net//kingzone_2008/article/details/17127567