PySpark: Trim All String Columns

Kyle Gibson
3 min read · Nov 29, 2023


Do you ever have string columns in your Spark DataFrames with extra whitespace around their values? If you’re like me, you’ve sometimes had to apply the PySpark trim function to each column that needed the whitespace removed. Like this:

from pyspark.sql.functions import col, trim

df_original = spark.read.format('delta').load('your_file_path')

df_trimmed = df_original\
    .withColumn('col1', trim(col('col1')))\
    .withColumn('col2', trim(col('col2')))

As you can imagine, that method gets tedious quickly. There is an easier way.

Solution

If you just want the solution, and not the explanation, here it is:

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, trim

def trim_all_string_columns(df: DataFrame) -> DataFrame:
    # Trim string columns; pass non-string columns through unchanged
    return df\
        .select(
            *[trim(col(c[0])).alias(c[0]) if c[1] == 'string' else col(c[0]) for c in df.dtypes]
        )

df_original = spark.read.format('delta').load('your_file_path')

df_trimmed = df_original\
    .transform(trim_all_string_columns)

That’s it. Just drop the trim_all_string_columns function into your code and you should be good to go.

If you want further explanation, keep reading.

Explanation

Suppose you have a DataFrame with this schema:

[Image: display of the df_original DataFrame, which has 9 columns]
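If you want to follow along without the screenshot, here is a hypothetical stand-in DataFrame with the same shape (the four non-string column types are assumptions):

from pyspark.sql.types import (StructType, StructField, StringType,
                               IntegerType, DoubleType, DateType, TimestampType)

# Hypothetical schema matching the screenshot: col1-col5 are strings,
# col6-col9 are assumed non-string types
schema = StructType(
    [StructField(f'col{i}', StringType()) for i in range(1, 6)] +
    [StructField('col6', IntegerType()),
     StructField('col7', DoubleType()),
     StructField('col8', DateType()),
     StructField('col9', TimestampType())]
)

df_original = spark.createDataFrame([], schema)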

As you can see, col1, col2, col3, col4, and col5 are string columns, while the others are not. We only need to trim those first five columns and leave the others as they are.

The dtypes attribute of the DataFrame will help us accomplish this.

Here is the output of df_original.dtypes:

[Image: printed dtypes of the df_original DataFrame]
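With the hypothetical stand-in schema above, that output would look like this:

df_original.dtypes
# [('col1', 'string'), ('col2', 'string'), ('col3', 'string'),
#  ('col4', 'string'), ('col5', 'string'), ('col6', 'int'),
#  ('col7', 'double'), ('col8', 'date'), ('col9', 'timestamp')]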

It is a list of tuples, where the first element in each tuple is the column name, and the second element is the column data type.

If we print out the results of iterating through this list, we can see how to access each index in the tuple:

[Image: printed results for each index item of dtypes]
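A minimal version of that loop, run against the stand-in DataFrame:

for c in df_original.dtypes:
    print(f'index 0: {c[0]}, index 1: {c[1]}')
# index 0: col1, index 1: string
# index 0: col2, index 1: string
# ...
# index 0: col9, index 1: timestamp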

So we can programmatically determine the column name and column data type of each column in our DataFrame.

By using a Python list comprehension to iterate through the .dtypes attribute, we can trim each column if it’s a string type, and return the original column (without trimming) if it’s not:

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, trim

def trim_all_string_columns(df: DataFrame) -> DataFrame:
    # Trim string columns; pass non-string columns through unchanged
    return df\
        .select(
            *[trim(col(c[0])).alias(c[0]) if c[1] == 'string' else col(c[0]) for c in df.dtypes]
        )

df_original = spark.read.format('delta').load('your_file_path')

df_trimmed = df_original\
    .transform(trim_all_string_columns)

The transform method of the DataFrame lets us chain the trim_all_string_columns function onto the original DataFrame. This pays off when we have multiple transformations to apply to df_original, since each step can be chained in sequence, as in the sketch below.
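For example, with a second (hypothetical) transformation that upper-cases every string column, the steps chain cleanly:

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, upper

def uppercase_all_string_columns(df: DataFrame) -> DataFrame:
    # Hypothetical second transformation, built the same way as
    # trim_all_string_columns, just to illustrate chaining
    return df\
        .select(
            *[upper(col(c[0])).alias(c[0]) if c[1] == 'string' else col(c[0]) for c in df.dtypes]
        )

df_clean = df_original\
    .transform(trim_all_string_columns)\
    .transform(uppercase_all_string_columns)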

You could also do this if you wanted:

df_trimmed = trim_all_string_columns(df_original)
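A quick sanity check with a small in-memory DataFrame (hypothetical data) shows the trimming in action:

df_sample = spark.createDataFrame([('  a  ', ' b ', 1)], ['col1', 'col2', 'col6'])

trim_all_string_columns(df_sample).collect()
# [Row(col1='a', col2='b', col6=1)]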

That’s it! Now you can apply this function to any Spark DataFrame you have, and it will automatically trim every string column.

Hope this helps.

Thanks for reading!

