PySpark: Using Arrays to Avoid Date Cross Joins
Everybody says to avoid cross joins, but sometimes we’re backed into a corner and don’t see any way to avoid them.
This article will look at a couple of techniques in PySpark to accomplish the same thing as a cross join without actually using a cross join.
Scenario:
Let’s assume we have a Spark DataFrame of business units:
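If you want to follow along, here is one way a small df_original could be built. The four business unit names below are just made-up placeholders, not from any real dataset:

# Hypothetical df_original: one row per business unit (names are placeholders)
df_original = spark.createDataFrame(
    [('Sales',), ('Marketing',), ('Finance',), ('Operations',)],
    ['business_unit']
)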
We need a row for each month ending date for each business_unit. To make it easier on ourselves, let’s assume we only need rows for the last 12 months.
The typical cross join method would involve reading in some kind of Dim Dates DataFrame. Here is an example of what that could look like:
from pyspark.sql.functions import col, current_date, last_day, lit, rank
from pyspark.sql.window import Window

df_dim_dates = (
    spark.read.load('path_to_dim_date_dataset')
    .select(col('end_of_month_date').alias('month_end_date'))
    .distinct()
    # keep only month-end dates that fall before the current month's last day
    .withColumn('current_month_last_day', last_day(current_date()))
    .filter(col('month_end_date') < col('current_month_last_day'))
    # rank from most recent to oldest and keep the 12 most recent month ends
    .withColumn('rank', rank().over(Window.orderBy(col('month_end_date').desc())))
    .filter(col('rank') <= lit(12))
    .drop('rank', 'current_month_last_day')
)
Then you have a df_dim_dates DataFrame that looks something like this:
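For example, if this ran during June 2024 (a made-up run date, purely for illustration), df_dim_dates would contain the 12 most recent completed month-end dates, abbreviated here:

+--------------+
|month_end_date|
+--------------+
|    2024-05-31|
|    2024-04-30|
|    2024-03-31|
|           ...|
|    2023-06-30|
+--------------+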
Now if you want one row per business_unit per month_end_date, you would typically do something like this…
df_cross_joined = df_original.crossJoin(df_dim_dates)
or like this…
df_cross_joined = df_original.join(df_dim_dates)
…which gives you a result like this:
This resulting DataFrame is the Cartesian product of the two DataFrames. In other words, there is one row per business_unit per month_end_date. Since we had 4 rows in df_original, and 12 rows in df_dim_dates, we end up with 48 (4*12) rows in the cross joined DataFrame.
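With the hypothetical business units and run date from above, the result would look something like this (abbreviated; row ordering may vary):

+-------------+--------------+
|business_unit|month_end_date|
+-------------+--------------+
|        Sales|    2024-05-31|
|        Sales|    2024-04-30|
|          ...|           ...|
|   Operations|    2023-06-30|
+-------------+--------------+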
While this method works fine for a small dataset, it typically does not scale well to larger ones. Let’s look at a couple of other methods that accomplish the same thing and may be more performant.
Method 1: Arrays from Python List
Step 1: Generate the list of dates using your favorite Python method (there are several ways this could be done):
from datetime import datetime, timedelta

def get_previous_month_ending_dates(lookback_months):
    """Return the last day of each of the previous lookback_months months as 'YYYY-MM-DD' strings, most recent first."""
    current_date = datetime.now()
    month_ending_dates = []
    for i in range(lookback_months):
        # step back to the last day of the month before current_date
        first_day_of_month = current_date.replace(day=1)
        last_day_of_prev_month = first_day_of_month - timedelta(days=1)
        month_ending_dates.append(last_day_of_prev_month.strftime("%Y-%m-%d"))
        current_date = last_day_of_prev_month
    return month_ending_dates
month_ending_dates = get_previous_month_ending_dates(lookback_months=12)
This should generate a Python list that looks something like this:
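For example, a run in June 2024 (again, a made-up date for illustration) would give:

['2024-05-31', '2024-04-30', '2024-03-31', '2024-02-29',
 '2024-01-31', '2023-12-31', '2023-11-30', '2023-10-31',
 '2023-09-30', '2023-08-31', '2023-07-31', '2023-06-30']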
Step 2: Create an array type column on df_original with this list:
from pyspark.sql.functions import array, lit, to_date
df_array = (df_original
.withColumn('month_end_date', array(*[to_date(lit(date)) for date in month_ending_dates])))
This should generate an array type month_end_date column like this:
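With the same made-up run date and placeholder business units, df_array.show(truncate=False) would display something like this (arrays abbreviated):

+-------------+------------------------------------------+
|business_unit|month_end_date                            |
+-------------+------------------------------------------+
|Sales        |[2024-05-31, 2024-04-30, ..., 2023-06-30] |
|Marketing    |[2024-05-31, 2024-04-30, ..., 2023-06-30] |
|...          |...                                       |
+-------------+------------------------------------------+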
Step 3: Explode the month_end_date column:
from pyspark.sql.functions import col, explode
df_array_exploded = (df_array
.withColumn('month_end_date', explode(col('month_end_date'))))
This should generate a date type month_end_date column like this:
Just like in the cross joined example, we end up with the same 48 rows.
Step 4: That’s it!
Now let’s look at a similar (and, in my opinion, better) method that builds the array with built-in PySpark functions instead of a Python list.
Method 2: Arrays from PySpark Functions
Step 1: Create an array type column of month_end_date on df_original using PySpark functions:
from pyspark.sql.functions import expr
df_array = (df_original
    .withColumn('month_end_date', expr(
        'sequence(add_months(last_day(current_date()), -12), '
        'add_months(last_day(current_date()), -1), '
        'interval 1 month)'
    )))
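If you want to sanity-check the dates this expression produces before attaching them to df_original, you can evaluate it on a throwaway single-row DataFrame, for example:

# Quick check of the sequence expression on its own (one throwaway row)
spark.range(1).select(
    expr('sequence(add_months(last_day(current_date()), -12), '
         'add_months(last_day(current_date()), -1), interval 1 month)')
    .alias('month_end_date')
).show(truncate=False)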
Just like in Method 1, this generates an array type month_end_date column:
Step 2: Explode the month_end_date column:
from pyspark.sql.functions import col, explode
df_array_exploded = (df_array
.withColumn('month_end_date', explode(col('month_end_date'))))
This should generate a date type month_end_date column like this:
Just like in the cross joined example and Method 1 above, we end up with the same 48 rows.
Step 3: That’s it!
Sometimes we end up needing what amounts to a cross join between DataFrames. Hopefully this article showed you a couple of ways to achieve the same result without actually having to perform one.
Thanks for reading!