UDAFs In Spark

Today, we’re going to talk about User-Defined Aggregation Functions (UDAFs) and Dataset Aggregators in Spark, similar but different ways of adding custom aggregation abilities to Spark’s DataFrames. This may be of interest to some. For the rest of you, umm…I suggest a cup of tea and some digestives instead.

User-Defined Aggregation Functions

UDAFs as a concept come from the world of databases, and are related to User-Defined Functions (UDFs). While UDFs operate on a column in a row in an independent fashion (e.g. transforming an entry of ‘5/2/2016’ into ‘Monday’), UDAFs operate across rows to produce a result. The simplest example would be COUNT, an aggregator that just increments a value for every row in the database table that it sees and then returns that number as a result. Or SUM, which adds up a column’s value across every row in the table.

Or, to put it another way: UDFs are map(), UDAFs are reduce().
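
If you want to see that analogy in miniature, plain Scala collections make the point (nothing Spark-specific here, just an illustration):

// A UDF is a per-row transformation, like map():
Seq(1.0, 2.0, 4.0).map(x => 1 / x)                    // List(1.0, 0.5, 0.25)

// A UDAF folds all the rows down to a single result, like reduce():
Seq(1.0, 2.0, 4.0).map(x => 1 / x).reduce(_ + _)      // 1.75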

Normally things like SUM and COUNT will be built into the data manipulation framework you’re using; UDAFs come into their own for implementing custom logic in a reusable manner. If you and your team often need to generate a custom probability distribution for your warehouses’ delivery times, maybe you can implement it as a UDAF once and then everybody can use it without having to reimplement the logic in every query.

UDAFs in Spark

Adding a UDF in Spark is simply a matter of registering a function. UDAFs, however, are a little more complicated. Instead of a function, you have to implement a class that extends UserDefinedAggregateFunction. Here’s a UDAF that implements harmonic mean:


import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class HarmonicMean() extends UserDefinedAggregateFunction {

  // The same inputs always produce the same result
  def deterministic: Boolean = true

  // We take a single Double column in, and produce a Double out
  def inputSchema: StructType = StructType(Array(StructField("value", DoubleType)))
  def dataType: DataType = DoubleType

  // The intermediate state: a running sum of reciprocals and a count of rows seen
  def bufferSchema: StructType = StructType(Array(
    StructField("sum", DoubleType),
    StructField("count", LongType)
  ))

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
  }

  // Fold a new input row into the buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + (1.0 / input.getDouble(0))
    buffer(1) = buffer.getLong(1) + 1
  }

  // Combine two partially-aggregated buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // The final result: the count divided by the sum of the reciprocals
  def evaluate(buffer: Row): Double = {
    buffer.getLong(1).toDouble / buffer.getDouble(0)
  }
}

Let’s walk through the class to see what each of the methods does. Firstly, deterministic() is a simple flag that tells Spark whether this UDAF will always return the same value when given the same inputs (in this example, the harmonic mean should always be the same given the same inputs, or else we’ve got bigger problems). The next two methods, inputSchema() and dataType(), specify the input and output data formats. We’re not doing anything crazy here - just requiring that our input column is a double and that our output will also be a double. You’re free to create UDAFs with weird and wonderful type signatures though, bringing in multiple columns and outputting anything you like.
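
For instance, a weighted variant of this aggregator might declare two input columns - here’s a sketch of what that inputSchema could look like (the weight column is hypothetical, not part of the example above):

def inputSchema: StructType = StructType(Array(
  StructField("value", DoubleType),
  StructField("weight", DoubleType)
))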

With those out of the way, we now need to specify the schema of our buffer. The buffer is a mutable object that will hold our in-progress calculations. For calculating the harmonic mean, we need a running sum of the reciprocals, plus a counter of how many values we’ve seen for the final calculation. The bufferSchema is defined as a StructType, here with an array of two StructFields, one of type Double and the other of type Long.

Having finally set up all the types (come on, this is Scala! The typing is all the fun, right? Right? Anybody?), we can implement the methods that will calculate our mean. As you might expect, initialize() is called first. Here, we’re making sure that both the sum and the count fields are initialized with zero values.

update() and merge() are where your aggregations happen. update() takes two arguments: a MutableAggregationBuffer where aggregation has already taken place, and a new Row which needs to be processed. In this example, we add the reciprocal of the incoming row’s value to the buffer (note that we don’t take the reciprocal of the buffer itself, because its contents have already been processed).

merge(), on the other hand, merges two already-aggregated buffers together. This is needed because Spark will likely split the execution of the UDAF across many executors (which is what you’d want, of course!) and it needs a way of combining those partial aggregations into the final result. Here, as in many UDAF examples that compute a mean, merge() is very straightforward: we just sum the two counts, and the two sums of reciprocals. Your custom merging logic may be more complicated than this.1
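
As a tiny concrete example: if one executor has seen the rows 1 and 2 (buffer: reciprocal sum 1.5, count 2) and another has seen the row 4 (buffer: reciprocal sum 0.25, count 1), merge() combines them into a single buffer with a sum of 1.75 and a count of 3.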

Finally, there’s evaluate(). This gets called at the end of the UDAF’s processing. In this example, evaluate() actually produces the harmonic mean result we’re looking for by dividing the count by the sum of the reciprocals.
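
To make that concrete with the numbers above: a reciprocal sum of 1.75 and a count of 3 gives 3 / 1.75 ≈ 1.71, which is indeed the harmonic mean of 1, 2 and 4.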

Using UDAFs

Having defined the UDAF, how do you actually use it? Well, just like UDFs, you get two choices. Firstly, there’s the fairly obvious method of using it in DataFrame aggregations, like this:

import sqlContext.implicits._                 // for toDF()
import org.apache.spark.sql.functions.col

val hm = new HarmonicMean()
val df = sc.parallelize(Seq(1.0, 2.0, 4.0)).toDF("value")
df.agg(hm(col("value"))).show
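
Assuming everything is wired up correctly, that show() call should print a single row containing roughly 1.71 - the harmonic mean of 1, 2 and 4 that we calculated by hand earlier.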

But you can also register the UDAF and use it transparently within SparkSQL queries:

df.registerTempTable("df")   // make the DataFrame queryable by name
sqlContext.udf.register("hm", hm)
sqlContext.sql("SELECT hm(value) AS hm FROM df").show

As you can imagine, the latter method is a great way of adding extra functionality to your Spark platform: your analytics team can pick it up without having to step outside their SQL comfort zone.

Behind the scenes

UDAFs are implemented as ScalaUDAF, a class that extends ImperativeAggregate (one of the two flavours of AggregateFunction available in Spark - the other being DeclarativeAggregate, which works directly with Catalyst expressions rather than the row-based approach of ImperativeAggregate).

You can trace through the AggregationIterator class to see how Spark walks through the execution of the aggregators - it’s not especially pretty, but it does work!

What Happened to Dataset Aggregators?

I spent a bit longer on the UDAFs than I planned, so I’ll do a separate follow-up post where I look at Dataset Aggregators.


  1. Essentially ‘may you live in interesting times’, but for Spark. ↩︎