Written by Enrico Minack, Open Source Software Contributor.
Apache Spark is very popular when it comes to processing tabular data of arbitrary size. One common operation is to group the data by some columns in order to further process those grouped data. Spark has two ways of grouping data: groupBy and groupByKey. While the latter works, it may cause performance issues in some cases. As good practice, avoid groupByKey whenever possible to prevent those performance issues.
Grouping data
Spark provides two ways to group and process data. Grouping can be done via groupBy and groupByKey. These functions return a RelationalGroupedDataset and a KeyValueGroupedDataset[K, V], respectively.
If you are already familiar with the differences between these two types of grouped datasets, you can jump right into the performance implications section further down. If you are not familiar with these two types, you may rightly ask:
Why are there two different types of grouped data?
Each type of grouped dataset provides different operations on groups. But before we dive into processing groups, here is an example dataset ds that we use throughout this article:
// just here to make sure our tiny example dataset does not optimize joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

case class Val(id: Long, number: Int)

val ds = Seq((1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)).toDF("id", "number").as[Val]

ds.show
+---+------+
| id|number|
+---+------+
|  1|     1|
|  1|     2|
|  1|     3|
|  2|     2|
|  2|     3|
|  3|     3|
+---+------+
We will use column "id" as the grouping column, so we will get three groups.
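To make the difference concrete, here is a minimal sketch (assuming spark.implicits._ is in scope, as in spark-shell) of the types these two calls return for ds:

import org.apache.spark.sql.{KeyValueGroupedDataset, RelationalGroupedDataset}

// untyped grouping: Spark knows that column "id" is the grouping key
val byColumn: RelationalGroupedDataset = ds.groupBy("id")

// typed grouping: the key is computed by a Scala function
val byKey: KeyValueGroupedDataset[Long, Val] = ds.groupByKey(row => row.id)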
Aggregating groups
Both types of grouped datasets, RelationalGroupedDataset and KeyValueGroupedDataset[K, V], allow for aggregating groups. This is the most common way to process grouped data. An aggregate function returns a single row per group.
Grouping with groupBy and aggregating the groups returns a DataFrame with schema id: int, sum: bigint:
ds.groupBy("id") .agg(sum("number").as("sum")) .show()
+---+---+
| id|sum|
+---+---+
|  1|  6|
|  2|  5|
|  3|  3|
+---+---+
Grouping with groupByKey and aggregating the groups returns a Dataset[(Int, Int)] with schema key: int, sum: bigint:
ds.groupByKey(row => row.id)
  .agg(sum("number").as("sum").as[Int])
  .show()
+---+---+
|key|sum|
+---+---+
|  1|  6|
|  2|  5|
|  3|  3|
+---+---+
We have seen that both types of grouped data are pretty similar when it comes to aggregating groups, but only one allows us to iterate group values.
Iterate group values
Only KeyValueGroupedDataset[K, V] allows us to process groups with a user-defined function. That function obtains an iterator over the group's values and can return an arbitrary number of rows. Hence, groupByKey grouped data can be processed into zero, one or many rows per group. The user has more possibilities for processing the groups with groupByKey than with groupBy.
// return first and last element of iterator
def firstAndLast[T](id: Long, it: Iterator[T]): Iterator[T] = {
  if (it.hasNext) {
    val first = it.next
    if (it.hasNext) {
      Iterator(first, it.reduceLeft((b, n) => n))
    } else {
      Iterator.single(first)
    }
  } else {
    Iterator.empty
  }
}

// now get the first and last row of each group
// group by row id
ds.groupByKey(row => row.id)
  // call firstAndLast for each id and group iterator
  .flatMapGroups(firstAndLast)
  .show()
+---+-----+
| id|value|
+---+-----+
|  1|    1|
|  1|    3|
|  2|    2|
|  2|    3|
|  3|    3|
+---+-----+
We have seen that the two grouped datasets RelationalGroupedDataset and KeyValueGroupedDataset[K, V] provide similar, but also differing, functions on grouped data. But are there also differences beyond the functional (API) ones? In fact, there can be a significant performance penalty when using one over the other.
Performance Considerations
Before Spark can process individual groups, it first has to rearrange the data. This is expensive and involves partitioning the data by the group columns. Calling one of the mapGroups or flatMapGroups methods additionally involves sorting the individual partitions.
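As a quick illustration (a sketch; the exact plan output varies with Spark version and configuration), grouping the plain ds dataset already shows both steps when printing the physical plan:

// grouping an unpartitioned dataset: the physical plan contains an
// Exchange (hash partitioning by the grouping key) and, for mapGroups,
// a Sort of each partition
ds.groupByKey(row => row.id)
  .mapGroups((id, it) => id * it.size)
  .explain()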
If data is already partitioned and sorted, Spark skips these steps and processes the groups right away. This saves a lot of time and processing power. But it can only do so if it knows which columns are used for grouping.
A common situation where data comes already partitioned and sorted occurs after performing a join on a prospective group column:
// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupByKey(row => row.id)
  // process groups via iterator
  .mapGroups((id, it) => id * it.size)
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [input[0, bigint, false] AS value#824L]
   +- MapGroups …, [value#821L], [id#806L, number#807], obj#823: bigint
      +- Sort [value#821L ASC NULLS FIRST], false, 0                                      <=== ❌
         +- Exchange hashpartitioning(value#821L, 200), ENSURE_REQUIREMENTS, [id=#1283]   <=== ❌
            +- AppendColumns …, [input[0, bigint, false] AS value#821L]
               +- Project [id#806L, number#807]
                  +- SortMergeJoin [id#806L], [id#811L], Inner
                     :- Sort [id#806L ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1275]
                     :     +- LocalTableScan [id#806L, number#807]
                     +- Sort [id#811L ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#811L, 200), …, [id=#1276]
                           +- Range (0, 3, step=1, splits=16)
We can read the following from this query plan:
The join (SortMergeJoin) triggers the partitioning (Exchange hashpartitioning(id#806L, 200)) and sorting (Sort [id#806L ASC NULLS FIRST]) of dataset ds (and of the second dataset, Range (0, 3, step=1, splits=16)) by column "id". Our joined dataset (Project [id#806L, number#807]) is partitioned and sorted by column "id". Grouping that dataset should not require another partitioning or sorting step.
Then, the grouping key is added as a new column "value" (AppendColumns … AS value#821L) by executing the function row => row.id for each row. Spark cannot know that the values of this new column are equivalent to column "id", because the expression row => row.id is a Scala function that is opaque to Spark. So Spark partitions and sorts all data by column value#821L, not knowing that this is redundant.
Avoid groupByKey(...), use groupBy(...).as[...] instead
If we were to use groupBy("id") instead, Spark would know the missing bit. But how can we access the mapGroups and flatMapGroups methods when using groupBy rather than groupByKey?
We can get from a RelationalGroupedDataset to a KeyValueGroupedDataset[K, V] via as[K, V]:
// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, (Long, Int)]
  .as[Long, (Long, Int)]
  // process groups via iterator
  .mapGroups((id, it) => id * it.size)
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [input[0, bigint, false] AS value#845L]
   +- MapGroups …, [id#806L], [id#806L, number#807], obj#844: bigint
      +- Project [id#806L, number#807]
         +- SortMergeJoin [id#806L], [id#827L], Inner
            :- Sort [id#806L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1312]
            :     +- LocalTableScan [id#806L, number#807]
            +- Sort [id#827L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#827L, 200), …, [id=#1313]
                  +- Range (0, 3, step=1, splits=16)
Now that Spark knows we want to group by column "id", it skips the outer partitioning and sorting.
Any other method of KeyValueGroupedDataset[K, V] also benefits from using groupBy(...).as[...]. For instance, aggregating groups skips the partitioning step (no sorting involved):
First with groupByKey:
// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupByKey(row => row.id)
  // aggregate the groups
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[value#917L], functions=[sum(number#807)])
   +- Exchange hashpartitioning(value#917L, 200), ENSURE_REQUIREMENTS, [id=#1397]   <=== ❌
      +- HashAggregate(keys=[value#917L], functions=[partial_sum(number#807)])
         +- Project [number#807, value#917L]
            +- AppendColumns …, [input[0, bigint, false] AS value#917L]
               +- Project [id#806L, number#807]
                  +- SortMergeJoin [id#806L], [id#907L], Inner
                     :- Sort [id#806L ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1387]
                     :     +- LocalTableScan [id#806L, number#807]
                     +- Sort [id#907L ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#907L, 200), …, [id=#1388]
                           +- Range (0, 3, step=1, splits=16)
We again see a redundant partitioning (Exchange hashpartitioning(value#917L, 200)).
Now with groupBy(...).as[...]:
// join dataset ds with another dataset on column id
ds.join(spark.range(4), "id").as[Val]
  // group by column id
  .groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, (Long, Int)]
  .as[Long, (Long, Int)]
  // aggregate the groups
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[id#806L], functions=[sum(number#807)])
   +- HashAggregate(keys=[id#806L], functions=[partial_sum(number#807)])
      +- Project [id#806L, number#807]
         +- SortMergeJoin [id#806L], [id#881L], Inner
            :- Sort [id#806L ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(id#806L, 200), …, [id=#1347]
            :     +- LocalTableScan [id#806L, number#807]
            +- Sort [id#881L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#881L, 200), …, [id=#1348]
                  +- Range (0, 3, step=1, splits=16)
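The iterator-based methods benefit in the same way. As a sketch (reusing the firstAndLast function defined earlier and the same joined dataset), flatMapGroups can also be driven through groupBy(...).as[...]:

ds.join(spark.range(4), "id").as[Val]
  // group by column id, keeping the grouping column visible to Spark
  .groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, Val]
  .as[Long, Val]
  // first and last row of each group, without the redundant partitioning
  .flatMapGroups(firstAndLast)
  .show()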
Working with DataFrames
Using groupBy(...).as[...] with DataFrames is a bit tricky, as you need to provide an encoder for the Row values. It is easiest to reuse the encoder of the DataFrame that is being grouped:
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.Row

// join dataset ds with another dataset on column id, creates a DataFrame
val df = ds.join(spark.range(4), "id")

// group dataframe by column id
df.groupBy("id")
  // turn into a KeyValueGroupedDataset[Long, Row]
  .as[Long, Row](Encoders.scalaLong, df.encoder)
  // aggregate the groups
  .agg(sum("number").as("sum").as[Int])
  // show the query plan
  .explain
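The Row values of each group can then be processed like any other group values, for example with mapGroups; here is a minimal sketch (the toDF column names are just for display):

// iterate the Row values of each group and sum the "number" field
df.groupBy("id")
  .as[Long, Row](Encoders.scalaLong, df.encoder)
  .mapGroups((id, rows) => (id, rows.map(_.getAs[Int]("number")).sum))
  .toDF("id", "sum")
  .show()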
Summary
We have seen that using groupByKey can have a significant impact on performance when data is already partitioned and sorted. It can be considered good practice to prefer groupBy(...) (RelationalGroupedDataset) over groupByKey(...) (KeyValueGroupedDataset[K, V]).
If you really need a KeyValueGroupedDataset[K, V], use groupBy(...).as[K, V] instead of groupByKey(...). This keeps the grouping columns visible to Spark's query optimisation.