Collective的Spark ML经验分享:读者模型
Collective成立于2005年,其总部位于纽约,是一家从事 数字广告业务的公司。 该公司的数字广告业务非常依赖于机器学习和预测模型,对于特定的用户在特定的时间应该投放什么样的广告完全是由实时或者离线的机器学习模型决定的。本文来 自Databricks的技术博客,Eugene Zhulenev分享了自己在Collective公司 从事机器学习和读者模型工作的经验。
Collective公司有很多使用机器学习的项目,这些项目可以统称为读者模型,因为这些项目都是基于用户的浏览历史、行为数据等因素预测读者转 化、点击率等信息的。在机器学习库的选择上,Collective公司内部新开发的大部分项目都是基于Spark和Spark MLLib的,对于一些被大家广泛使用而Spark并不具备的工具和类库Collective还专门创建了一个扩展库Spark Ext。在本文中,Eugene Zhulenev介绍了如何使用Spark Ext和Spark ML两个类库基于地理位置信息和浏览历史数据来预测用户转化。
预测数据
预测数据包含两种数据集,虽然这些数据都是使用虚拟的数据生成器生成的,但是它们与数字广告所使用的真实数据非常相似。这两类数据分别是:
用户的浏览历史日志
Cookie | Site | Impressions --------------- |-------------- | ------------- wKgQaV0lHZanDrp | live.com | 24 wKgQaV0lHZanDrp | pinterest.com | 21 rfTZLbQDwbu5mXV | wikipedia.org | 14 rfTZLbQDwbu5mXV | live.com | 1 rfTZLbQDwbu5mXV | amazon.com | 1 r1CSY234HTYdvE3 | 油Tube.com | 10
经纬度地理位置日志
Cookie | Lat | Lng | Impressions --------------- |---------| --------- | ------------ wKgQaV0lHZanDrp | 34.8454 | 77.009742 | 13 wKgQaV0lHZanDrp | 31.8657 | 114.66142 | 1 rfTZLbQDwbu5mXV | 41.1428 | 74.039600 | 20 rfTZLbQDwbu5mXV | 36.6151 | 119.22396 | 4 r1CSY234HTYdvE3 | 42.6732 | 73.454185 | 4 r1CSY234HTYdvE3 | 35.6317 | 120.55839 | 5 20ep6ddsVckCmFy | 42.3448 | 70.730607 | 21 20ep6ddsVckCmFy | 29.8979 | 117.51683 | 1
转换预测数据
正如上面所展示的,预测数据是长格式,对于每一个cookie与之相关的记录有多条,通常情况下,这种格式并不适合于机器学习算法,需要将其转换成“主键——特征向量”的形式。
Gather转换程序
受到了R语音 tidyr和reshape2包的启发,Collective将每一个键对应的值的长数据框(long DataFrame)转换成一个宽数据框(wide DataFrame),如果某个键对应多个值就应用聚合函数。
val gather = new Gather() .setPrimaryKeyCols("cookie") .setKeyCol("site") .setValueCol("impressions") .setValueAgg("sum") //通过key对impression的值求和 .setOutputCol("sites") val gatheredSites = gather.transform(siteLog)
转换后的结果
Cookie | Sites -----------------|---------------------------------------------- wKgQaV0lHZanDrp | [ | { site: live.com, impressions: 24.0 }, | { site: pinterest.com, impressions: 21.0 } | ] rfTZLbQDwbu5mXV | [ | { site: wikipedia.org, impressions: 14.0 }, | { site: live.com, impressions: 1.0 }, | { site: amazon.com, impressions: 1.0 } | ]
Google S2几何单元Id转换程序
Google S2几何类库是一个球面几何类库,该库非常适合于操作球面(通常是地球)上的区域和索引地理数据,它会为地球上的每一个区域分配一个唯一的单元Id。
为了将经纬度信息转换成键值对的形式,Eugene Zhulenev结合使用了S2类库和Gather,转换后数据的键值是S2的单元Id。
// Transform lat/lon into S2 Cell Id val s2Transformer = new S2CellTransformer() .setLevel(5) .setCellCol("s2_cell") // Gather S2 CellId log val gatherS2Cells = new Gather() .setPrimaryKeyCols("cookie") .setKeyCol("s2_cell") .setValueCol("impressions") .setOutputCol("s2_cells") val gatheredCells = gatherS2Cells.transform(s2Transformer.transform(geoDf))
转换后的结果
Cookie | S2 Cells -----------------|---------------------------------------------- wKgQaV0lHZanDrp | [ | { s2_cell: d5dgds, impressions: 5.0 }, | { s2_cell: b8dsgd, impressions: 1.0 } | ] rfTZLbQDwbu5mXV | [ | { s2_cell: d5dgds, impressions: 12.0 }, | { s2_cell: b8dsgd, impressions: 3.0 }, | { s2_cell: g7aeg3, impressions: 5.0 } | ]
生成特征向量
虽然Gather程序将与某个cookie相关的所有信息都组织到了一行中,变成了键值对的形式,但是这种形式依然不能作为机器学习算法的输入。为了能够训练一个模型,预测数据需要表示成double类型的向量。
Gather 编码程序
使用虚拟变量对明确的键值对进行编码。
// Encode S2 Cell data val encodeS2Cells = new GatherEncoder() .setInputCol("s2_cells") .setOutputCol("s2_cells_f") .setKeyCol("s2_cell") .setValueCol("impressions") .setCover(0.95) // dimensionality reduction
原始数据
Cookie | S2 Cells -----------------|---------------------------------------------- wKgQaV0lHZanDrp | [ | { s2_cell: d5dgds, impressions: 5.0 }, | { s2_cell: b8dsgd, impressions: 1.0 } | ] rfTZLbQDwbu5mXV | [ | { s2_cell: d5dgds, impressions: 12.0 }, | { s2_cell: g7aeg3, impressions: 5.0 } | ]
转换后的结果
Cookie | S2 Cells Features -----------------|------------------------ wKgQaV0lHZanDrp | [ 5.0 , 1.0 , 0 ] rfTZLbQDwbu5mXV | [ 12.0 , 0 , 5.0 ]
对于转换后的结果,用户还可以根据场景选择性地使用顶部转换进行降维。首先计算不同用户每个特征的值,然后根据特征值进行降序排序,最后从结果 列表中选择最上面那些数值总和占所有用户总和的百分比超过某个阈值(例如,选择最上面覆盖99%用户的那些网站)的数据作为最终的分类值。
Spark ML 管道
Spark ML 管道是Spark MLLib的一个新的高层API。一个真正的ML管道通常会包含数据预处理、特征提取、模型拟合和验证几个阶段。例如,文本文档的分类可能会涉及到文本分 割与清理、特征提取、使用交叉验证训练分类模型这几步。在使用Spark ML时,用户能够将一个ML管道拆分成多个独立的阶段,然后可以在一个单独的管道中将他们组合到一起,最后使用交叉验证和参数网格运行该管道从而找到最佳 参数集合。
使用Spark ML管道将它们组合到一起
// Encode site data val encodeSites = new GatherEncoder() .setInputCol("sites") .setOutputCol("sites_f") .setKeyCol("site") .setValueCol("impressions") // Encode S2 Cell data val encodeS2Cells = new GatherEncoder() .setInputCol("s2_cells") .setOutputCol("s2_cells_f") .setKeyCol("s2_cell") .setValueCol("impressions") .setCover(0.95) // Assemble feature vectors together val assemble = new VectorAssembler() .setInputCols(Array("sites_f", "s2_cells_f")) .setOutputCol("features") // Build logistic regression val lr = new LogisticRegression() .setFeaturesCol("features") .setLabelCol("response") .setProbabilityCol("probability") // Define pipeline with 4 stages val pipeline = new Pipeline() .setStages(Array(encodeSites, encodeS2Cells, assemble, lr)) val evaluator = new BinaryClassificationEvaluator() .setLabelCol(Response.response) val crossValidator = new CrossValidator() .setEstimator(pipeline) .setEvaluator(evaluator) val paramGrid = new ParamGridBuilder() .addGrid(lr.elasticNetParam, Array(0.1, 0.5)) .build() crossValidator.setEstimatorParamMaps(paramGrid) crossValidator.setNumFolds(2) println(s"Train model on train set") val cvModel = crossValidator.fit(trainSet)
结论
Spark ML API让机器学习变得更加容易。同时,用户还可以通过Spark Ext创建自定义的转换/估计,并对这些自定义的内容进行组装使其成为更大管道中的一部分,此外这些程序还能够很容易地在多个项目中共享和重用。如果想要查看本示例的代码,可以点击这里。
来自:http://www.infoq.com/cn/news/2015/11/collective-audience-model-exampl