Unit Testing Apache Spark Applications

A PySpark case

Unit testing Spark applications is not that straightforward. For most cases you’ll need an active Spark session, which means your test cases will take a long time to run and that we’re perhaps tiptoeing around the boundaries of what can be called a unit test. Still, it is definitely worth doing.

So, should I? Well, yes! Testing your software is always a good thing and will most likely save you from many headaches. On top of that, you’ll be forced to implement your code in smaller, more testable pieces, which pays off in readability and simplicity.

Okay, then what do I need to do this? Well, I’d say we can start with pip install spark-testing-base and work our way from there. We’ll also need pyspark (of course), unittest (unittest2) and pytest, although pytest is a personal preference. The spark-testing-base project (holdenk/spark-testing-base on GitHub) provides base classes to use when writing tests with Spark.

spark-testing-base is a collection of base classes to help with Spark testing. For this example we’ll be using [SQLTestCase](github.com/holdenk/spark-testing-base/blob/..), which inherits from SparkTestingBaseReuse, the class that creates and reuses a SparkContext.

For more on SparkSession and SparkContext, see “A tale of Spark Session and Spark Context” on medium.com.

From personal experience (using the currently latest Spark version, 2.4.+), I’ve found that I needed to make some minor adjustments to SQLTestCase, which is a test case I use quite a lot in my current project. So, here’s an example of the adjustments I’ve made to suit my needs:

import traceback

from pyspark.sql.types import StructType
from sparktestingbase.sqltestcase import SQLTestCase



class SparkSQLTestCase(SQLTestCase):
    def getConf(self):
        from pyspark import SparkConf
        conf = SparkConf()
        conf.set(
            'spark.sql.session.timeZone', 'UTC'
        )
        # set shuffle partitions to a low number, e.g. <= cores * 2 to speed
        # things up, otherwise the tests will use the default 200 partitions
        # and it will take a lot more time to complete
        conf.set('spark.sql.shuffle.partitions', '12')
        return conf

    def setUp(self):
        try:
            from pyspark.sql import SparkSession
            self.session = SparkSession.builder.config(
                conf=self.getConf()
            ).appName(
                self.__class__.__name__
            ).getOrCreate()
            self.sqlCtx = self.session._wrapped
        except Exception:
            traceback.print_exc()
            from pyspark.sql import SQLContext
            self.sqlCtx = SQLContext(self.sc)

    def assertOrderedDataFrameEqual(self, expected, result, tol=0):
        """
        Order both dataframes by the columns of the expected df before
        comparing them.
        """
        expected = expected.select(expected.columns).orderBy(expected.columns)
        result = result.select(expected.columns).orderBy(expected.columns)
        super(SparkSQLTestCase, self).assertDataFrameEqual(
            expected, result, tol
        )

    def schema_nullable_helper(self, df, expected_schema, fields=None):
        """
        Since column nullables cannot be easily changed after dataframe has
        been created, given a dataframe df, an expected_schema and the fields
        that need the nullable flag to be changed, return a dataframe with the
        schema nullables as in the expected_schema (only for the fields
        specified)
        :param pyspark.sql.DataFrame df: the dataframe that needs schema
        adjustments
        :param pyspark.sql.types.StructType expected_schema: the schema to be followed
        :param list[str] fields: the fields that need adjustment of the
        nullable flag
        :return: the dataframe with the corrected nullable flags
        :rtype: pyspark.sql.DataFrame
        """
        new_schema = []
        current_schema = df.schema
        if not fields:
            fields = df.columns

        for item in current_schema:
            if item.name in fields:
                for expected_item in expected_schema:
                    if expected_item.name == item.name:
                        item.nullable = expected_item.nullable
            new_schema.append(item)

        new_schema = StructType(new_schema)
        df = self.session.createDataFrame(
            df.rdd, schema=new_schema
        )
        return df

To sum up the changes I’ve made:

  • I added a configuration entry to have the timezone set to UTC for consistency. Timezone consistency is something very basic to have throughout your code, so please make sure you always set spark.sql.session.timeZone.

  • Another important thing to set in the configuration is spark.sql.shuffle.partitions, to something reasonable for the machine that will be running the tests, like <= cores * 2. If we don’t do that, Spark will use the default value, which is 200 partitions, and it will unnecessarily and inevitably slow down the whole process. <= cores * 2 is a good general rule, not only for the tests.

  • I also added a method that sorts both dataframes by the columns of the expected one before comparing them (there is a short usage sketch after this list). There is a compareRDDWithOrder method in one of the base classes, but I think it is easier to work with dataframes.

  • The schema_nullable_helper method should be used with caution, as it may end up sabotaging your test case, depending on what you need to test. Its use case is when you create dataframes without specifying a schema (which is currently deprecated): because Spark tries to infer the data types, you sometimes end up with inconsistencies in the nullable flag between the two dataframes being compared, depending on the data used to create them. This method updates the schema of one of the two dataframes to match the other’s, for the nullable flags only (and only for the fields specified).

  • And lastly, I added a slightly adjusted version of setUp for the appName and the config. The session instantiation is also different in the latest pyspark version. (There is a pending release for the support of the 2.2.+ and 2.3.+ Spark versions, still open here and here, so we’ll be subclassing to work around this.)

Note that getOrCreate() will create a spark session once and then reuse it throughout the test suite.
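To make the two helper methods above a bit more concrete, here is a minimal, hypothetical usage sketch (the test class, column names and data are made up for illustration, and it assumes the base class lives in a pyspark_unittesting module, as in the test file further down):

from pyspark.sql.types import LongType, StringType, StructField, StructType

from pyspark_unittesting import SparkSQLTestCase


class TestHelpersUsage(SparkSQLTestCase):

    def test_ordered_comparison(self):
        # Row order does not matter: both dataframes are sorted by the
        # expected dataframe's columns before the comparison.
        expected = self.session.createDataFrame(
            [(1, 'a'), (2, 'b')], ['id', 'value']
        )
        result = self.session.createDataFrame(
            [(2, 'b'), (1, 'a')], ['id', 'value']
        )
        self.assertOrderedDataFrameEqual(expected, result)

    def test_nullable_alignment(self):
        expected_schema = StructType([
            StructField('id', LongType(), nullable=True),
            StructField('value', StringType(), nullable=False),
        ])
        # Schema inference marks 'value' as nullable, so the schemas differ
        # even though the data is identical.
        result = self.session.createDataFrame([(1, 'a')], ['id', 'value'])
        result = self.schema_nullable_helper(
            result, expected_schema, fields=['value']
        )
        self.assertEqual(expected_schema, result.schema)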

Now let’s create a simple feature to be tested:

from pyspark.sql import functions as F


class FeatureAToBRatio(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    def calculate(self, df):
        """
        Given a dataframe that contains columns a and b,
        calculate a to b ratio. If b is 0 then the result will be
        the feature's default value.
        """
        df = df.withColumn(
            self.feature_name,
            F.when(
                F.col('b') > 0.,
                (
                    F.col('a').cast('float') / F.col('b').cast('float')
                )
            ).otherwise(
                self.default_value
            )
        ).fillna({self.feature_name: self.default_value})

        return df

This feature simply calculates a/b, where a and b are columns in the input dataframe. Pretty simple and straightforward. I didn’t include a column check here, because if either a or b is missing, we need the calculation process to fail hard enough to stop everything. But in general, it depends on how you want to handle error cases in your application and how important this calculation is to your process.
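As a quick illustration (a hypothetical snippet; session stands for an active SparkSession, e.g. the self.session of a test case):

feature = FeatureAToBRatio()
df = session.createDataFrame([(1, 4), (1, 0)], ['a', 'b'])
result = feature.calculate(df)
# The new a_to_b_ratio column holds 0.25 for the first row and the
# default value 0.0 for the second, since b == 0 there.
result.show()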

Something to note here: even if you mock the input dataframe of calculate, the spark session will still be needed, because our feature’s calculate implementation uses pyspark.sql.functions like F.col('a'), which require an active session. In case we didn’t have one, we’d get an error like this:

Attribute error when using pyspark sql functions without an active session

This is more obvious if, for some reason, we need to declare calculations in the __init__ body (the constructor) of the feature, e.g.:

class FeatureAToBRatio(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    def __init__(self):
        self.calculation = F.col('a').cast('float') / F.col('b').cast('float')

Then we’d get the error during the feature instantiation feature = FeatureAToBRatio().
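One possible workaround (a sketch of my own, not something the feature above does) is to build the column expression lazily, for example behind a property, so that nothing from pyspark.sql.functions is evaluated until calculate actually runs with a session available:

class FeatureAToBRatioLazy(object):
    feature_name = 'a_to_b_ratio'
    default_value = 0.

    @property
    def calculation(self):
        # Evaluated on first access instead of at instantiation time, so
        # creating the feature object no longer needs an active session.
        return F.col('a').cast('float') / F.col('b').cast('float')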

Let’s now go ahead and add some test cases for this.

from pyspark.sql.utils import AnalysisException
from pyspark_unittesting import SparkSQLTestCase


class TestFeatureAToBRatio(SparkSQLTestCase):

    def setUp(self):
        super(TestFeatureAToBRatio, self).setUp()
        self.feature = FeatureAToBRatio()

    def helper_compare_actual_with_expected_df(
            self, initial_data, expected_data
    ):
        """
        Creates two dataframes, one from the initial data and one from the
        expected data, runs feature calculation and compares the expected with
        the actual dataframe
        :param dict[str, T] initial_data:
        :param dict[str, T] expected_data:
        :return: None
        """
        df = self.session.createDataFrame([initial_data])
        expected_df = self.session.createDataFrame([expected_data])
        result_df = self.feature.calculate(df)

        self.assertDataFrameEqual(result_df, expected_df)

    def test_calculate_simple_case(self):
        data = {
            "a": 1,
            "b": 4,
            "c": 155,
        }
        expected_data = {
            "a": 1,
            "b": 4,
            "c": 155,
            FeatureAToBRatio.feature_name: 0.25
        }
        self.helper_compare_actual_with_expected_df(data, expected_data)

    def test_calculate_missing_column(self):
        data = {
            "a": 1,
            "f": 4,
            "c": 155,
        }
        df = self.session.createDataFrame([data])

        with self.assertRaises(AnalysisException) as t_err:
            self.feature.calculate(df)

        self.assertTrue(
            'cannot resolve \'`b`\' given input columns' in str(
                t_err.exception)
        )

    def test_calculate_zero_denominator(self):
        data = {
            "a": 1,
            "b": 0,
            "c": 155,
        }
        expected_data = {
            "a": 1,
            "b": 0,
            "c": 155,
            FeatureAToBRatio.feature_name: FeatureAToBRatio.default_value
        }
        self.helper_compare_actual_with_expected_df(data, expected_data)

We are testing for:

  • the normal case, in which a and b exist and have numeric values

  • the exception raised when one of them, e.g. b, does not exist in the dataframe

  • the case where the denominator is 0.

These are just some of the basic test cases one could test for. There are many others that come to mind, e.g. what happens if a is null or of a different data type (a sketch of such a test follows), but for the example’s sake, let’s keep it simple and clean.
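For instance, a null-numerator test added to TestFeatureAToBRatio could look roughly like this (a sketch; it builds the dataframe with an explicit schema, since a None value would otherwise break schema inference):

    def test_calculate_null_numerator(self):
        from pyspark.sql.types import LongType, StructField, StructType

        schema = StructType([
            StructField('a', LongType(), nullable=True),
            StructField('b', LongType(), nullable=True),
        ])
        df = self.session.createDataFrame([(None, 4)], schema)

        result_df = self.feature.calculate(df)
        row = result_df.collect()[0]

        # a / b is null when a is null, so the fillna in calculate replaces
        # the ratio with the feature's default value.
        self.assertEqual(
            row[FeatureAToBRatio.feature_name],
            FeatureAToBRatio.default_value
        )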

To run the test suite:

python -m pytest test_feature_a_to_b_ratio.py

Example output of tests execution

Example output of tests execution — ran with PyCharm

And that’s it! Notice that it did take 7.58 seconds to run (14.72 seconds when the shuffle partitions are left at the default 200), which is a bit much for unit testing, and it’s only 3 test cases; imagine having a CI/CD pipeline that runs the test suites on every merge or commit…

Of course, there is a lot more complex testing to be done with spark / pyspark, but I think this is a good base to build upon. Let me know if there’s a better way to do this.

I hope this was helpful. Any thoughts, questions, corrections and suggestions are very welcome :)

If you want to know more about how Spark works, take a look at “Explaining technical stuff in a non-technical way — Apache Spark” and “Adding sequential IDs to a Spark Dataframe” on towardsdatascience.com.