推荐算法之Slope One Java 及 PHP实现

来源:互联网 发布:视频剪辑特效软件 编辑:程序博客网 时间:2024/05/16 00:33

下面这两份实现(Java 版与 PHP/SQL 版)貌似都是原作者 Daniel Lemire 自己写的。

import java.util.*;

/**
 * Daniel Lemire A simple implementation of the weighted slope one algorithm in
 * Java for item-based collaborative filtering. Assumes Java 1.5.
 *
 * See main function for example.
 *
 * June 1st 2006. Revised by Marco Ponzi on March 29th 2007
 */
public class SlopeOne {

  /** Small demo: builds a four-user data set, prints it, then predicts twice. */
  public static void main(String args[]) {
    // this is my data base
    Map<UserId, Map<ItemId, Float>> data = new HashMap<UserId, Map<ItemId, Float>>();
    // items
    ItemId item1 = new ItemId("       candy");
    ItemId item2 = new ItemId("         dog");
    ItemId item3 = new ItemId("         cat");
    ItemId item4 = new ItemId("         war");
    ItemId item5 = new ItemId("strange food");
    mAllItems = new ItemId[] { item1, item2, item3, item4, item5 };
    // I'm going to fill it in
    HashMap<ItemId, Float> user1 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user2 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user3 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user4 = new HashMap<ItemId, Float>();
    user1.put(item1, 1.0f);
    user1.put(item2, 0.5f);
    user1.put(item4, 0.1f);
    data.put(new UserId("Bob"), user1);
    user2.put(item1, 1.0f);
    user2.put(item3, 0.5f);
    user2.put(item4, 0.2f);
    data.put(new UserId("Jane"), user2);
    user3.put(item1, 0.9f);
    user3.put(item2, 0.4f);
    user3.put(item3, 0.5f);
    user3.put(item4, 0.1f);
    data.put(new UserId("Jo"), user3);
    user4.put(item1, 0.1f);
    // user4.put(item2,0.4f);
    // user4.put(item3,0.5f);
    user4.put(item4, 1.0f);
    user4.put(item5, 0.4f);
    data.put(new UserId("StrangeJo"), user4);
    // next, I create my predictor engine
    SlopeOne so = new SlopeOne(data);
    System.out.println("Here's the data I have accumulated...");
    so.printData();
    // then, I'm going to test it out...
    HashMap<ItemId, Float> user = new HashMap<ItemId, Float>();
    System.out.println("Ok, now we predict...");
    user.put(item5, 0.4f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
    //
    user.put(item4, 0.2f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
  }

  /** Raw ratings: user -> (item -> rating). */
  Map<UserId, Map<ItemId, Float>> mData;

  /** mDiffMatrix.get(k).get(j) is the average of (rating of k - rating of j) over co-rating users. */
  Map<ItemId, Map<ItemId, Float>> mDiffMatrix;

  /** mFreqMatrix.get(k).get(j) is the number of users who rated both k and j. */
  Map<ItemId, Map<ItemId, Integer>> mFreqMatrix;

  /** Optional display order for {@link #printData()}; set by {@link #main(String[])}. */
  static ItemId[] mAllItems;

  /**
   * Stores the rating data and builds the difference and frequency matrices.
   *
   * @param data user -> (item -> rating); must not be null
   * @throws NullPointerException if {@code data} is null
   */
  public SlopeOne(Map<UserId, Map<ItemId, Float>> data) {
    if (data == null) {
      throw new NullPointerException("data must not be null");
    }
    mData = data;
    buildDiffMatrix();
  }

  /**
   * Based on existing data, and using weights, try to predict all missing
   * ratings. Each known rating j contributes (diff(k, j) + rating(j)) weighted
   * by the number of users who rated both k and j.
   *
   * Items for which no co-rating information exists are left out of the result.
   * The user's own ratings are always copied into the result unchanged.
   *
   * @param user the ratings of the user to predict for (item -> rating)
   * @return predicted ratings plus the user's own ratings
   */
  public Map<ItemId, Float> predict(Map<ItemId, Float> user) {
    HashMap<ItemId, Float> weightedSums = new HashMap<ItemId, Float>();
    HashMap<ItemId, Integer> totalWeights = new HashMap<ItemId, Integer>();
    for (ItemId k : mDiffMatrix.keySet()) {
      totalWeights.put(k, 0);
      weightedSums.put(k, 0.0f);
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      Float rating = rated.getValue();
      if (rating == null) {
        continue; // nothing to extrapolate from
      }
      for (ItemId k : mDiffMatrix.keySet()) {
        Float diff = lookup(mDiffMatrix, k, rated.getKey());
        Integer freq = lookup(mFreqMatrix, k, rated.getKey());
        if (diff == null || freq == null) {
          continue; // k and the rated item were never rated by the same user
        }
        float contribution = (diff.floatValue() + rating.floatValue()) * freq.intValue();
        weightedSums.put(k, weightedSums.get(k) + contribution);
        totalWeights.put(k, totalWeights.get(k) + freq.intValue());
      }
    }
    HashMap<ItemId, Float> cleanpredictions = new HashMap<ItemId, Float>();
    for (Map.Entry<ItemId, Float> sum : weightedSums.entrySet()) {
      int weight = totalWeights.get(sum.getKey()).intValue();
      if (weight > 0) {
        cleanpredictions.put(sum.getKey(), sum.getValue().floatValue() / weight);
      }
    }
    // Known ratings win over predictions.
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      cleanpredictions.put(rated.getKey(), rated.getValue());
    }
    return cleanpredictions;
  }

  /**
   * Based on existing data, and not using weights, try to predict all missing
   * ratings. Each prediction is the plain average of (diff(k, j) + rating(j))
   * over the user's rated items j that were co-rated with k.
   *
   * Item pairs that were never co-rated are skipped, and items without any
   * usable pair are left out of the result. The user's own ratings are always
   * copied into the result unchanged.
   *
   * @param user the ratings of the user to predict for (item -> rating)
   * @return predicted ratings plus the user's own ratings
   */
  public Map<ItemId, Float> weightlesspredict(Map<ItemId, Float> user) {
    HashMap<ItemId, Float> sums = new HashMap<ItemId, Float>();
    HashMap<ItemId, Integer> counts = new HashMap<ItemId, Integer>();
    for (ItemId k : mDiffMatrix.keySet()) {
      sums.put(k, 0.0f);
      counts.put(k, 0);
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      Float rating = rated.getValue();
      if (rating == null) {
        continue;
      }
      for (ItemId k : mDiffMatrix.keySet()) {
        Float diff = lookup(mDiffMatrix, k, rated.getKey());
        if (diff == null) {
          continue; // pair never co-rated: no information to add
        }
        sums.put(k, sums.get(k) + (diff.floatValue() + rating.floatValue()));
        counts.put(k, counts.get(k) + 1);
      }
    }
    HashMap<ItemId, Float> predictions = new HashMap<ItemId, Float>();
    for (Map.Entry<ItemId, Float> sum : sums.entrySet()) {
      int count = counts.get(sum.getKey()).intValue();
      if (count > 0) {
        predictions.put(sum.getKey(), sum.getValue().floatValue() / count);
      }
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      predictions.put(rated.getKey(), rated.getValue());
    }
    return predictions;
  }

  /**
   * Prints every user's ratings, then one line per item with its row of the
   * difference and frequency matrices. Items are listed in {@code mAllItems}
   * order when that array is set, otherwise in the matrix's own key order.
   */
  public void printData() {
    for (Map.Entry<UserId, Map<ItemId, Float>> entry : mData.entrySet()) {
      System.out.println(entry.getKey());
      print(entry.getValue());
    }
    for (ItemId item : itemsToPrint()) {
      System.out.print("\n" + item + ":");
      printMatrixes(mDiffMatrix.get(item), mFreqMatrix.get(item));
    }
  }

  /** Prints one matrix row; a missing row or cell is printed as "null". */
  private void printMatrixes(Map<ItemId, Float> ratings,
      Map<ItemId, Integer> frequencies) {
    for (ItemId item : itemsToPrint()) {
      Float diff = (ratings == null) ? null : ratings.get(item);
      Integer freq = (frequencies == null) ? null : frequencies.get(item);
      System.out.format("%10.3f", diff);
      System.out.print(" ");
      System.out.format("%10d", freq);
    }
    System.out.println();
  }

  /** Returns the items to display: mAllItems when set, else the known matrix keys. */
  private ItemId[] itemsToPrint() {
    if (mAllItems != null) {
      return mAllItems;
    }
    return mDiffMatrix.keySet().toArray(new ItemId[0]);
  }

  /** Returns matrix[row][column], or null when either the row or the cell is absent. */
  private static <V> V lookup(Map<ItemId, Map<ItemId, V>> matrix, ItemId row, ItemId column) {
    Map<ItemId, V> cells = matrix.get(row);
    return (cells == null) ? null : cells.get(column);
  }

  /**
   * Prints one "item --> rating" line per entry.
   *
   * @param user ratings to print (item -> rating)
   */
  public static void print(Map<ItemId, Float> user) {
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      System.out.println(" " + rated.getKey() + " --> " + rated.getValue());
    }
  }

  /**
   * Rebuilds mDiffMatrix and mFreqMatrix from mData. For every ordered pair of
   * items rated by the same user it accumulates the rating difference and a
   * co-rating count, then divides each accumulated difference by its count.
   * Users without a rating map and null ratings are ignored.
   */
  public void buildDiffMatrix() {
    mDiffMatrix = new HashMap<ItemId, Map<ItemId, Float>>();
    mFreqMatrix = new HashMap<ItemId, Map<ItemId, Integer>>();
    // first pass: accumulate difference sums and co-rating counts
    for (Map<ItemId, Float> ratings : mData.values()) {
      if (ratings == null) {
        continue;
      }
      for (Map.Entry<ItemId, Float> first : ratings.entrySet()) {
        if (first.getValue() == null) {
          continue;
        }
        Map<ItemId, Float> diffRow = mDiffMatrix.get(first.getKey());
        Map<ItemId, Integer> freqRow = mFreqMatrix.get(first.getKey());
        if (diffRow == null || freqRow == null) {
          diffRow = new HashMap<ItemId, Float>();
          freqRow = new HashMap<ItemId, Integer>();
          mDiffMatrix.put(first.getKey(), diffRow);
          mFreqMatrix.put(first.getKey(), freqRow);
        }
        for (Map.Entry<ItemId, Float> second : ratings.entrySet()) {
          if (second.getValue() == null) {
            continue;
          }
          Integer oldCount = freqRow.get(second.getKey());
          Float oldDiff = diffRow.get(second.getKey());
          float observedDiff = first.getValue().floatValue() - second.getValue().floatValue();
          freqRow.put(second.getKey(), (oldCount == null) ? 1 : oldCount.intValue() + 1);
          diffRow.put(second.getKey(),
              ((oldDiff == null) ? 0.0f : oldDiff.floatValue()) + observedDiff);
        }
      }
    }
    // second pass: turn the sums into averages
    for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
      Map<ItemId, Integer> freqRow = mFreqMatrix.get(row.getKey());
      for (Map.Entry<ItemId, Float> cell : row.getValue().entrySet()) {
        int count = freqRow.get(cell.getKey()).intValue();
        cell.setValue(cell.getValue().floatValue() / count);
      }
    }
  }
}

/** Identifies a user by name; two ids are equal when their content is equal. */
class UserId {
  String content;

  public UserId(String s) {
    content = s;
  }

  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof UserId)) {
      return false;
    }
    String theirs = ((UserId) other).content;
    return (content == null) ? theirs == null : content.equals(theirs);
  }

  @Override
  public int hashCode() {
    return (content == null) ? 0 : content.hashCode();
  }

  @Override
  public String toString() {
    return content;
  }
}

/** Identifies an item by name; two ids are equal when their content is equal. */
class ItemId {
  String content;

  public ItemId(String s) {
    content = s;
  }

  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof ItemId)) {
      return false;
    }
    String theirs = ((ItemId) other).content;
    return (content == null) ? theirs == null : content.equals(theirs);
  }

  @Override
  public int hashCode() {
    return (content == null) ? 0 : content.hashCode();
  }

  @Override
  public String toString() {
    return content;
  }
}

 # This is the code in plain text out of the technical report.
#
# Daniel Lemire, Sean McGrath, Implementing a Rating-Based Item-to-Item
# Recommender System in PHP/SQL, Technical Report D-01, January 2005.
#
# http://www.ondelette.com/lemire/abstracts/TRD01.html
#
# This code is in the public domain, use at your own risks.
# It is assumed that you looked at the report and know some SQL and PHP.
#
# Daniel Lemire, February 3rd 2005
#
# First part is sample SQL code.
#########CUT HERE####################

# One row per (user, item) rating. The key is the pair: with userID alone as
# the primary key a user could store only a single rating.
CREATE TABLE rating (
    userID INT NOT NULL,
    itemID INT NOT NULL,
    ratingValue INT NOT NULL,
    datetimestamp TIMESTAMP NOT NULL,
    PRIMARY KEY (userID, itemID)
);

# For each ordered item pair: how many users rated both (count) and the
# accumulated rating difference (sum), as maintained by the PHP code below.
CREATE TABLE dev (
  itemID1 int(11) NOT NULL default '0',
  itemID2 int(11) NOT NULL default '0',
  count int(11) NOT NULL default '0',
  sum int(11) NOT NULL default '0',
  PRIMARY KEY  (itemID1,itemID2)
);

# simple query to output 10 most liked items
# by people who rated item 1
SELECT itemID2, ( sum / count ) AS average
FROM dev
WHERE count > 2 AND itemID1 = 1
ORDER  BY ( sum / count ) DESC
LIMIT 10;

# Next part is sample PHP code.
#########CUT HERE####################

// This code assumes $itemID is set to that of
// the item that was just rated.
// It also reads $userID and $connection, which are not set in this listing.

// The IDs are interpolated into SQL text below, so force them to integers
// first (both columns are INT).
$userID = (int) $userID;
$itemID = (int) $itemID;

// Get all of the user's rating pairs
$sql = "SELECT DISTINCT r.itemID, r2.ratingValue - r.ratingValue             as rating_difference            FROM rating r, rating r2            WHERE r.userID=$userID AND                     r2.itemID=$itemID AND                     r2.userID=$userID;";
$db_result = mysql_query($sql, $connection);

//For every one of the user's rating pairs,
//update the dev table
while ($row = mysql_fetch_assoc($db_result)) {
    $other_itemID = (int) $row["itemID"];
    $rating_difference = (int) $row["rating_difference"];

    //if the pair ($itemID, $other_itemID) is already in the dev table
    //then we want to update 2 rows.
    if (mysql_num_rows(mysql_query("SELECT itemID1     FROM dev WHERE itemID1=$itemID AND itemID2=$other_itemID",
            $connection)) > 0) {
        $sql = "UPDATE dev SET count=count+1, sum=sum+$rating_difference WHERE itemID1=$itemID AND itemID2=$other_itemID";
        mysql_query($sql, $connection);

        //We only want to update if the items are different
        if ($itemID != $other_itemID) {
            $sql = "UPDATE dev SET count=count+1,     sum=sum-$rating_difference     WHERE (itemID1=$other_itemID AND itemID2=$itemID)";
            mysql_query($sql, $connection);
        }
    }
    else { //we want to insert 2 rows into the dev table
        $sql = "INSERT INTO dev VALUES ($itemID, $other_itemID,        1, $rating_difference)";
        mysql_query($sql, $connection);

        //We only want to insert if the items are different
        if ($itemID != $other_itemID) {
            $sql = "INSERT INTO dev VALUES ($other_itemID,     $itemID, 1, -$rating_difference)";
            mysql_query($sql, $connection);
        }
    }
}

/**
 * Predicts the rating of one item for one user with the weighted scheme:
 * for every other item j the user rated, add count * (sum / count + rating of j)
 * and divide the total by the total count.
 *
 * @param int $userID user to predict for
 * @param int $itemID item whose rating is predicted
 * @return float|int predicted rating, or 0 when no dev row matches
 */
function predict($userID, $itemID) {
    global $connection;

    // Both values end up inside SQL text; keep them numeric.
    $userID = (int) $userID;
    $itemID = (int) $itemID;

    $denom = 0.0; //denominator
    $numer = 0.0; //numerator

    $k = $itemID;

    $sql = "SELECT r.itemID, r.ratingValue     FROM rating r WHERE r.userID=$userID AND r.itemID <> $itemID";
    $db_result = mysql_query($sql, $connection);

    //for all items the user has rated
    while ($row = mysql_fetch_assoc($db_result)) {
        $j = (int) $row["itemID"];
        $ratingValue = $row["ratingValue"];

        //get the number of times k and j have both been rated by the same user
        $sql2 = "SELECT d.count, d.sum FROM dev d WHERE itemID1=$k AND itemID2=$j";
        $count_result = mysql_query($sql2, $connection);

        //skip the calculation if it isn't found
        if (mysql_num_rows($count_result) > 0) {
            $count = mysql_result($count_result, 0, "count");
            $sum = mysql_result($count_result, 0, "sum");

            //calculate the average
            $average = $sum / $count;

            //increment denominator by count
            $denom += $count;

            //increment the numerator
            $numer += $count * ($average + $ratingValue);
        }
    }

    if ($denom == 0)
        return 0;
    else
        return ($numer / $denom);
}

/**
 * Runs one query that yields, for every item the user has not rated, the
 * columns 'item', 'denom' and 'numer' of the weighted prediction.
 *
 * @param int $userID user to predict for
 * @return mixed whatever mysql_query() returns for the statement
 */
function predict_all($userID ) {
    // $connection is a global handle; without this declaration the function
    // would pass an undefined local variable to mysql_query().
    global $connection;

    $userID = (int) $userID;

    $sql2 = "SELECT d.itemID1 as 'item', sum(d.count) as 'denom',     sum(d.sum + d.count*r.ratingValue) as 'numer' FROM rating r,    dev d WHERE r.userID=$userID     AND d.itemID1 NOT IN     (SELECT itemID FROM rating WHERE userID=$userID)      AND d.itemID2=r.itemID GROUP BY d.itemID1";
    return mysql_query($sql2, $connection);
}

/**
 * Runs one query that yields the $n unrated items with the highest predicted
 * rating for the user, as columns 'item' and 'avgrat', best first.
 *
 * @param int $userID user to predict for
 * @param int $n      maximum number of rows to return
 * @return mixed whatever mysql_query() returns for the statement
 */
function predict_best($userID, $n ) {
    // See predict_all(): the connection handle lives in the global scope.
    global $connection;

    $userID = (int) $userID;
    $n = (int) $n;

    $sql2 = "SELECT d.itemID1 as 'item',     sum(d.sum + d.count*r.ratingValue)/sum(d.count) as 'avgrat'     FROM  rating r, dev d     WHERE r.userID=$userID     AND d.itemID1 NOT IN     (SELECT itemID FROM rating WHERE userID=$userID)      AND d.itemID2=r.itemID     GROUP BY d.itemID1 ORDER BY avgrat DESC LIMIT $n";
    return mysql_query($sql2, $connection);
}

转自

http://lemire.me/fr/documents/publications/webpaper.txt

http://lemire.me/fr/documents/publications/SlopeOne.java

原创粉丝点击