A PySpark case
Unit testing Spark applications is not that straightforward. In most cases you’ll need an active Spark session, which means your test cases will take a long time to run and that we’re perhaps tiptoeing around the boundaries of what can be called a unit test. But it is definitely worth doing.
So, should I? Well, yes! Testing your software is always a good thing and will most likely save you from many headaches. Plus, you’ll be forced to implement your code in smaller pieces that are easier to test, which in turn gains you readability and simplicity.
Okay, then what do I need to do this?
Well, I’d say we can start with pip install spark-testing-base and work our way from there. We’ll also need pyspark (of course), plus unittest (unittest2) and pytest for this, even though pytest is a personal preference.
holdenk/spark-testing-base: Base classes to use when writing tests with Spark (github.com)
Spark testing base is a collection of base classes to help with Spark testing. For this example we’ll be using the base SQLTestCase (github.com/holdenk/spark-testing-base/blob/..), which inherits from SparkTestingBaseReuse, a class that creates and reuses a SparkContext.
For more on SparkSession and SparkContext, see: A tale of Spark Session and Spark Context (I have Spark Context, SQL context, Hive context already!) on medium.com.
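Before any adjustments, a bare-bones usage of the base class could look roughly like this. This is only a sketch based on my understanding of the library; the sqlCtx attribute and assertDataFrameEqual come from SQLTestCase itself, and the test data is made up:

from sparktestingbase.sqltestcase import SQLTestCase


class MyFirstSparkTest(SQLTestCase):
    def test_row_count(self):
        # self.sqlCtx is set up by the base class, which also creates and
        # reuses the underlying SparkContext across tests.
        df = self.sqlCtx.createDataFrame([('a', 1), ('b', 2)], ['key', 'value'])
        self.assertEqual(df.count(), 2)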
From personal experience (using the currently latest Spark version, 2.4.+), I’ve found that I needed to make some minor adjustments to SQLTestCase, which is a test case I use quite a lot in my current project. So, here’s an example of the adjustments I’ve made to suit my needs:
import traceback

from pyspark.sql.types import StructType
from sparktestingbase.sqltestcase import SQLTestCase


class SparkSQLTestCase(SQLTestCase):
    def getConf(self):
        from pyspark import SparkConf
        conf = SparkConf()
        conf.set(
            'spark.sql.session.timeZone', 'UTC'
        )
        # set shuffle partitions to a low number, e.g. <= cores * 2 to speed
        # things up, otherwise the tests will use the default 200 partitions
        # and it will take a lot more time to complete
        conf.set('spark.sql.shuffle.partitions', '12')
        return conf

    def setUp(self):
        try:
            from pyspark.sql import SparkSession
            self.session = SparkSession.builder.config(
                conf=self.getConf()
            ).appName(
                self.__class__.__name__
            ).getOrCreate()
            self.sqlCtx = self.session._wrapped
        except Exception:
            # fall back to the SQLContext-based setup of the base class
            traceback.print_exc()
            from pyspark.sql import SQLContext
            self.sqlCtx = SQLContext(self.sc)

    def assertOrderedDataFrameEqual(self, expected, result, tol=0):
        """
        Order both dataframes by the columns of the expected df before
        comparing them.
        """
        expected = expected.select(expected.columns).orderBy(expected.columns)
        result = result.select(expected.columns).orderBy(expected.columns)
        super(SparkSQLTestCase, self).assertDataFrameEqual(
            expected, result, tol
        )

    def schema_nullable_helper(self, df, expected_schema, fields=None):
        """
        Since column nullables cannot be easily changed after a dataframe has
        been created, given a dataframe df, an expected_schema and the fields
        that need the nullable flag to be changed, return a dataframe with the
        schema nullables as in the expected_schema (only for the fields
        specified)

        :param pyspark.sql.DataFrame df: the dataframe that needs schema
            adjustments
        :param pyspark.sql.types.StructType expected_schema: the schema to be
            followed
        :param list[str] fields: the fields that need adjustment of the
            nullable flag
        :return: the dataframe with the corrected nullable flags
        :rtype: pyspark.sql.DataFrame
        """
        new_schema = []
        current_schema = df.schema
        if not fields:
            fields = df.columns

        for item in current_schema:
            if item.name in fields:
                for expected_item in expected_schema:
                    if expected_item.name == item.name:
                        item.nullable = expected_item.nullable
            new_schema.append(item)

        new_schema = StructType(new_schema)
        df = self.session.createDataFrame(
            df.rdd, schema=new_schema
        )
        return df
To sum up the changes I’ve made:
- I added a configuration setting to have the timezone set to UTC for consistency. Timezone consistency is something very basic to have throughout your code, so please make sure you always set spark.sql.session.timeZone.
- Another important thing to set in the configuration is spark.sql.shuffle.partitions, to something reasonable for the machine that will be running the tests, like <= cores * 2. If we don’t do that, spark will use the default value of 200 partitions, and it will unnecessarily and inevitably slow down the whole process. <= cores * 2 is a good general rule, not only for the tests.
- I also added a method that sorts the dataframes to be compared before the comparison. There is a compareRDDWithOrder method in one of the base classes, but I think it is easier to work with dataframes.
- The schema_nullable_helper method should be used with caution, as it may end up sabotaging your test case, depending on what you need to test. Its use case is for when you create dataframes without specifying a schema (which is currently deprecated): because spark tries to infer the data types, you sometimes get inconsistencies in the nullable flag between the two dataframes to be compared, depending on the data used to create them. This method updates the schema of one of the two dataframes to match the other’s, regarding nullables only (see the sketch after this list).
- And lastly, I added a slightly adjusted version of setUp for the appName and the config. The session instantiation is also different in the latest pyspark version. (There is a pending release for the support of the 2.2.+ and 2.3.+ spark versions, still open here and here, so we’ll be subclassing to work around this.)
Note that getOrCreate() will create a spark session once and then reuse it throughout the test suite.
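As a quick illustration of that reuse (a minimal sketch, not from the original post):

from pyspark.sql import SparkSession

# The first call creates the session; subsequent calls return the exact
# same object, even if the builder options differ, so the configuration
# is only applied once for the whole suite.
first = SparkSession.builder.appName('SuiteA').getOrCreate()
second = SparkSession.builder.appName('SuiteB').getOrCreate()
assert first is second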
Now let’s create a simple feature to be tested:
from pyspark.sql import functions as F


class FeatureAToBRatio(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    def calculate(self, df):
        """
        Given a dataframe that contains columns a and b,
        calculate a to b ratio. If b is 0 then the result will be
        the feature's default value.
        """
        df = df.withColumn(
            self.feature_name,
            F.when(
                F.col('b') > 0.,
                (
                    F.col('a').cast('float') / F.col('b').cast('float')
                )
            ).otherwise(
                self.default_value
            )
        ).fillna({self.feature_name: self.default_value})
        return df
This feature simply calculates a/b, where a and b are columns in the input dataframe. Pretty simple and straightforward. I didn’t include a column check here, because if either a or b is missing, we need the calculation process to fail hard enough to stop everything. But in general, it depends on how you want to handle error cases in your application and how important this calculation is to your process.
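If you did want an explicit check that fails fast with a clearer message, a minimal sketch could look like the helper below; the function name and the error message are mine, not part of the original feature:

def ensure_columns(df, required=('a', 'b')):
    # Hypothetical helper: fail fast with an explicit error instead of
    # relying on the AnalysisException raised deeper inside the Spark plan.
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError('Missing input columns: %s' % missing)
    return df

calculate could then call ensure_columns(df) on its first line, turning a late AnalysisException into an immediate, explicit failure.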
Something to note here: even if you mock the input dataframe to calculate, the spark session will still be needed, because in our feature’s calculate implementation we are using pyspark.sql.functions like F.col('a'), which require you to have an active session. If we didn’t have the session, we’d get an error like this:
(Screenshot: AttributeError when using pyspark.sql functions without an active session)
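For reference, this is easy to reproduce in a plain Python shell with no Spark session running; the exact wording may vary slightly between versions:

from pyspark.sql import functions as F

# With no active SparkContext/SparkSession, building a column expression
# fails because the JVM gateway is not available yet, typically with:
# AttributeError: 'NoneType' object has no attribute '_jvm'
F.col('a')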
This is more obvious if, for some reason, we need to declare calculations in the __init__ body (the constructor) of the feature, e.g.:
from pyspark.sql import functions as F


class FeatureAToBRatio(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    def __init__(self):
        self.calculation = F.col('a').cast('float') / F.col('b').cast('float')
Then we’d get the error during the feature’s instantiation, feature = FeatureAToBRatio().
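One way around this, if you really need the expression on the instance, is to build it lazily so that nothing touches pyspark.sql.functions until a session exists. A minimal sketch (the property-based approach is a suggestion of mine, not from the original code):

from pyspark.sql import functions as F


class FeatureAToBRatioLazy(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    @property
    def calculation(self):
        # The column expression is only built when first accessed, by
        # which point the test's SparkSession should already exist.
        return F.col('a').cast('float') / F.col('b').cast('float')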
Let’s now go ahead and add some test cases for this.
from pyspark.sql.utils import AnalysisException

from pyspark_unittesting import SparkSQLTestCase

# FeatureAToBRatio as defined above; import it from wherever the feature
# lives in your project.


class TestFeatureAToBRatio(SparkSQLTestCase):
    def setUp(self):
        super(TestFeatureAToBRatio, self).setUp()
        self.feature = FeatureAToBRatio()

    def helper_compare_actual_with_expected_df(
        self, initial_data, expected_data
    ):
        """
        Creates two dataframes, one from the initial data and one from the
        expected data, runs feature calculation and compares the expected with
        the actual dataframe

        :param dict[str, T] initial_data:
        :param dict[str, T] expected_data:
        :return: None
        """
        df = self.session.createDataFrame([initial_data])
        expected_df = self.session.createDataFrame([expected_data])

        result_df = self.feature.calculate(df)

        self.assertDataFrameEqual(result_df, expected_df)

    def test_calculate_simple_case(self):
        data = {
            "a": 1,
            "b": 4,
            "c": 155,
        }
        expected_data = {
            "a": 1,
            "b": 4,
            "c": 155,
            FeatureAToBRatio.feature_name: 0.25
        }
        self.helper_compare_actual_with_expected_df(data, expected_data)

    def test_calculate_missing_column(self):
        data = {
            "a": 1,
            "f": 4,
            "c": 155,
        }
        df = self.session.createDataFrame([data])

        with self.assertRaises(AnalysisException) as t_err:
            self.feature.calculate(df)

        self.assertTrue(
            'cannot resolve \'`b`\' given input columns' in str(
                t_err.exception)
        )

    def test_calculate_zero_denominator(self):
        data = {
            "a": 1,
            "b": 0,
            "c": 155,
        }
        expected_data = {
            "a": 1,
            "b": 0,
            "c": 155,
            FeatureAToBRatio.feature_name: FeatureAToBRatio.default_value
        }
        self.helper_compare_actual_with_expected_df(data, expected_data)
We are testing for:
- the normal case, where a and b exist and have numeric values,
- the exception case where one of them, e.g. b, does not exist in the dataframe,
- the case where the denominator is 0.
These are just some of the basic test cases one could write. There are many others that come to mind, e.g. what happens if a is null or of a different datatype, but for the example’s sake let’s keep it simple and clean.
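For instance, a null-numerator case could be added to TestFeatureAToBRatio along these lines. This is only a sketch: the expected value of default_value follows from the fillna in calculate, and the explicit schema is needed because spark cannot infer a column type from None alone.

    def test_calculate_null_numerator(self):
        from pyspark.sql.types import IntegerType, StructField, StructType

        # Hypothetical extra case: a is null, so the ratio is null and the
        # fillna in calculate() replaces it with the feature's default value.
        schema = StructType([
            StructField('a', IntegerType(), True),
            StructField('b', IntegerType(), True),
        ])
        df = self.session.createDataFrame([(None, 4)], schema=schema)

        result_df = self.feature.calculate(df)
        row = result_df.collect()[0]

        self.assertEqual(
            row[FeatureAToBRatio.feature_name],
            FeatureAToBRatio.default_value
        )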
To run the test suite:
python -m pytest test_feature_a_to_b_ratio.py
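While iterating on a single case, pytest’s -k expression can narrow the run down to one test, for example:
python -m pytest test_feature_a_to_b_ratio.py -k test_calculate_zero_denominator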
(Screenshot: example output of the tests execution, run with PyCharm)
And that’s it! Notice that it did take 7.58 seconds to run (14.72 seconds when the shuffle partitions are set to the default 200), which is a bit much for unit testing, and it’s only 3 test cases; imagine having a CI/CD pipeline that runs the test suites on every merge or commit…
Of course, there is a lot more complex testing to be done with spark / pyspark, but I think this is a good base to build upon. Let me know if there’s a better way to do this.
I hope this was helpful. Any thoughts, questions, corrections and suggestions are very welcome :)
If you want to know more about how Spark works, take a look at:
- Explaining technical stuff in a non-technical way — Apache Spark: What is Spark and PySpark and what can I do with it? (towardsdatascience.com)
- Adding sequential IDs to a Spark Dataframe: How to do it and is it a good idea? (towardsdatascience.com)