参考:
http://zhangyi.farbox.com/post/framework/udf-and-udaf-in-spark
https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.functions$
UDF的引入极大地丰富了Spark SQL的表现力。一方面,它让我们享受了利用Scala(当然,也包括Java或Python)更为自然地编写代码实现函数的福利,另一方面,又能精简SQL(或者DataFrame的API),更加写意自如地完成复杂的数据分析。尤其采用SQL语句去执行数据分析时,UDF帮助我们在SQL函数与Scala函数之间左右逢源,还可以在一定程度上化解不同数据源具有歧异函数的尴尬。想想不同关系数据库处理日期或时间的函数名称吧!
UDF
注册制
用Scala编写的UDF与普通的Scala函数没有任何区别,唯一需要多执行的一个步骤是要让SQLContext注册它。例如:
1 |
|
编写的UDF可以放到SQL语句的fields部分,也可以作为where、groupBy或者having子句的一部分。
若使用DataFrame的API,则可以以字符串的形式将UDF传入:1
val booksWithLongTitle = dataFrame.filter("longLength(title, 10)")
非注册制
DataFrame的API也可以接收Column对象,可以用$符号来包裹一个字符串表示一个Column。$是定义在SQLContext对象implicits中的一个隐式转换。此时,UDF的定义也不相同,不能直接定义Scala函数,而是要用定义在org.apache.spark.sql.functions中的udf方法来接收一个函数。这种方式无需register:1
2
3
4
5
6import org.apache.spark.sql.functions._
val longLength = udf((bookTitle: String, length: Int) => bookTitle.length > length)
import sqlContext.implicits._
val booksWithLongTitle = dataFrame.filter(longLength($"title", $"10"))
不幸,运行这段代码会抛出异常:
1 | cannot resolve '10' given input columns id, title, author, price, publishedDate; |
因为采用$来包裹一个常量,会让Spark错以为这是一个Column。这时,需要定义在org.apache.spark.sql.functions中的lit函数来帮助:1
val booksWithLongTitle = dataFrame.filter(longLength($"title", lit(10)))
UDAF
参考 https://databricks.com/blog/2015/09/16/apache-spark-1-5-dataframe-api-highlights.html
普通的UDF却也存在一个缺陷,就是无法在函数内部支持对表数据的聚合运算。例如,当我要对销量执行年度同比计算,就需要对当年和上一年的销量分别求和,然后再利用同比公式进行计算。此时,UDF就无能为力了。该UDAF(User Defined Aggregate Function)粉墨登场的时候了。
Spark为所有的UDAF定义了一个父类UserDefinedAggregateFunction。要继承这个类,需要实现父类的几个抽象方法:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15def inputSchema: StructType 输入参数类型,映射为每一个Field
def bufferSchema: StructType 中间结果类型
def dataType: DataType 返回结果
def deterministic: Boolean 对于一组输入是否输出相同的结果
def initialize(buffer: MutableAggregationBuffer): Unit 初始化buffer
def update(buffer: MutableAggregationBuffer, input: Row): Unit 更新row和buffer
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit merge两个buffer
def evaluate(buffer: Row): Any 计算最终结果
可以将inputSchema理解为UDAF与DataFrame列有关的输入样式。例如年同比函数需要对某个可以运算的指标与时间维度进行处理,就需要在inputSchema中定义它们。1
2
3def inputSchema: StructType = {
StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)
}
代码创建了拥有两个StructField的StructType。StructField的名字并没有特别要求,完全可以认为是两个内部结构的列名占位符。至于UDAF具体要操作DataFrame的哪个列,取决于调用者,但前提是数据类型必须符合事先的设置,如这里的DoubleType与DateType类型。这两个类型被定义在org.apache.spark.sql.types中。
bufferSchema用于定义存储聚合运算时产生的中间数据结果的Schema,例如我们需要存储当年与上一年的销量总和,就需要定义两个StructField:1
2
3def bufferSchema: StructType = {
StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)
}
dataType标明了UDAF函数的返回值类型,deterministic是一个布尔值,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
顾名思义,initialize就是对聚合运算中间结果的初始化,在我们这个例子中,两个求和的中间值都被初始化为0d:1
2
3
4def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0.0)
}
update函数的第一个参数为bufferSchema中两个Field的索引,默认以0开始,所以第一行就是针对“sumOfCurrent”的求和值进行初始化。
UDAF的核心计算都发生在update函数中。在我们这个例子中,需要用户设置计算同比的时间周期。这个时间周期值属于外部输入,但却并非inputSchema的一部分,所以应该从UDAF对应类的构造函数中传入。我为时间周期定义了一个样例类,且对于同比函数,我们只要求输入当年的时间周期,上一年的时间周期可以通过对年份减1来完成:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17case class DateRange(startDate: Timestamp, endDate: Timestamp) {
def in(targetDate: Date): Boolean = {
targetDate.before(endDate) && targetDate.after(startDate)
}
}
class YearOnYearBasis(current: DateRange) extends UserDefinedAggregateFunction {
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (current.in(input.getAs[Date](1))) {
buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
}
val previous = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))
if (previous.in(input.getAs[Date](1))) {
buffer(1) = buffer.getAs[Double](0) + input.getAs[Double](0)
}
}
}
update函数的第二个参数input: Row对应的并非DataFrame的行,而是被inputSchema投影了的行。以本例而言,每一个input就应该只有两个Field的值。倘若我们在调用这个UDAF函数时,分别传入了销量和销售日期两个列的话,则input(0)代表的就是销量,input(1)代表的就是销售日期。
merge函数负责合并两个聚合运算的buffer,再将其存储到MutableAggregationBuffer中:1
2
3
4def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
}
最后,由evaluate函数完成对聚合Buffer值的运算,得到最后的结果:1
2
3
4
5
6def evaluate(buffer: Row): Any = {
if (buffer.getDouble(1) == 0.0)
0.0
else
(buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100
}
假设我们创建了这样一个简单的DataFrame:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17val conf = new SparkConf().setAppName("TestUDF").setMaster("local[*]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val sales = Seq(
(1, "Widget Co", 1000.00, 0.00, "AZ", "2014-01-01"),
(2, "Acme Widgets", 2000.00, 500.00, "CA", "2014-02-01"),
(3, "Widgetry", 1000.00, 200.00, "CA", "2015-01-11"),
(4, "Widgets R Us", 2000.00, 0.0, "CA", "2015-02-19"),
(5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2015-02-28")
)
val salesRows = sc.parallelize(sales, 4)
val salesDF = salesRows.toDF("id", "name", "sales", "discount", "state", "saleDate")
salesDF.registerTempTable("sales")
那么,要使用之前定义的UDAF,则需要实例化该UDAF类,然后再通过udf进行注册:1
2
3
4
5
6val current = DateRange(Timestamp.valueOf("2015-01-01 00:00:00"), Timestamp.valueOf("2015-12-31 00:00:00"))
val yearOnYear = new YearOnYearBasis(current)
sqlContext.udf.register("yearOnYear", yearOnYear)
val dataFrame = sqlContext.sql("select yearOnYear(sales, saleDate) as yearOnYear from sales")
dataFrame.show()
在使用上,除了需要对UDAF进行实例化之外,与普通的UDF使用没有任何区别。但显然,UDAF更加地强大和灵活。如果Spark自身没有提供符合你需求的函数,且需要进行较为复杂的聚合运算,UDAF是一个不错的选择。
通过Spark提供的UDF与UDAF,你可以慢慢实现属于自己行业的函数库,让Spark SQL变得越来越强大,对于使用者而言,却能变得越来越简单。