From 8b8132156dc22d0af7e6c744b50aee8c99941666 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Thu, 14 Dec 2023 04:57:36 -0800 Subject: [PATCH 1/4] Add support for Spark Connect DataFrames --- requirements.in | 4 +-- setup.py | 2 +- src/pyspark_test.py | 29 ++++++++++++++----- .../unit_test/test_assert_pyspark_df_equal.py | 18 +++++++++++- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/requirements.in b/requirements.in index 506c151..a538f0e 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,4 @@ black==20.8b1 -pyspark==2.4.7 +pyspark==3.4.0 pytest-testdox==2.0.1 -pytest==6.1.1 \ No newline at end of file +pytest==7.3.1 \ No newline at end of file diff --git a/setup.py b/setup.py index bc579bf..c5d267a 100644 --- a/setup.py +++ b/setup.py @@ -19,5 +19,5 @@ def readme(): license="Apache Software License (Apache 2.0)", py_modules=["pyspark_test"], package_dir={"": "src"}, - install_requires=["pyspark>=2.1.2"], + install_requires=["pyspark>=3.4.0"], ) diff --git a/src/pyspark_test.py b/src/pyspark_test.py index f2c6777..dd3f680 100644 --- a/src/pyspark_test.py +++ b/src/pyspark_test.py @@ -2,15 +2,28 @@ import pyspark +try: + from pyspark.sql.connect.dataframe import DataFrame as CDF + has_connect_deps = True +except ImportError: + has_connect_deps = False -def _check_isinstance(left: Any, right: Any, cls): - assert isinstance( - left, cls - ), f"Left expected type {cls}, found {type(left)} instead" - assert isinstance( - right, cls - ), f"Right expected type {cls}, found {type(right)} instead" +def _check_isinstance_df(left: Any, right: Any): + types_to_test = [pyspark.sql.DataFrame] + if has_connect_deps: + types_to_test.append(CDF) + + left_good = any(map(lambda x: isinstance(left, x), types_to_test)) + right_good = any(map(lambda x: isinstance(right, x), types_to_test)) + assert left_good, \ + f"Left expected type {pyspark.sql.DataFrame} or {CDF}, found {type(left)} instead" + assert right_good, \ + f"Right expected type {pyspark.sql.DataFrame} or {CDF}, found {type(right)} instead" + + # Check that both sides are of the same DataFrame type. 
+    assert type(left) == type(right), \
+        f"Left and right DataFrames are not of the same type: {type(left)} != {type(right)}"


 def _check_columns(
     check_columns_in_order: bool,
@@ -88,7 +101,7 @@ def assert_pyspark_df_equal(
     """

     # Check if
-    _check_isinstance(left_df, right_df, pyspark.sql.DataFrame)
+    _check_isinstance_df(left_df, right_df)

     # Check Column Names
     if check_column_names:
diff --git a/tests/unit_test/test_assert_pyspark_df_equal.py b/tests/unit_test/test_assert_pyspark_df_equal.py
index 90cd2d7..5d5901d 100644
--- a/tests/unit_test/test_assert_pyspark_df_equal.py
+++ b/tests/unit_test/test_assert_pyspark_df_equal.py
@@ -12,6 +12,7 @@
 )

 from src.pyspark_test import assert_pyspark_df_equal
+from src.pyspark_test import _check_isinstance_df


 class TestAssertPysparkDfEqual:
@@ -68,7 +69,7 @@ def test_assert_pyspark_df_equal_one_is_not_pyspark_df(
         right_df = "Demo"
         with pytest.raises(
             AssertionError,
-            match="Right expected type <class 'pyspark.sql.dataframe.DataFrame'>, found <class 'str'> instead",
+            match="Right expected type <class 'pyspark.sql.dataframe.DataFrame'> or .*?, found <class 'str'> instead",
         ):
             assert_pyspark_df_equal(left_df, right_df)

@@ -324,3 +325,18 @@ def test_assert_pyspark_df_equal_different_row_count(
             match="Number of rows are not same.\n \n Actual Rows: 2\n Expected Rows: 3",
         ):
             assert_pyspark_df_equal(left_df, right_df)
+
+    def test_instance_checks_for_spark_connect(
+        self, spark_session: pyspark.sql.SparkSession
+    ):
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+        left_df = spark_session.range(1)
+        right_df = spark_session.range(1)
+        _check_isinstance_df(left_df, right_df)
+
+        left_df = CDF.withPlan(None, None)
+        right_df = CDF.withPlan(None, None)
+        _check_isinstance_df(left_df, right_df)
+
+        with pytest.raises(AssertionError):
+            _check_isinstance_df(spark_session.range(1), right_df)

From 5dce31a548390afb49c93bf01ccc99c53fbcc3ef Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Thu, 14 Dec 2023 07:48:37 -0800
Subject: [PATCH 2/4] requirements

---
 requirements.txt    | 26 ++++++++++++--------------
 setup.py            |  2 +-
 src/pyspark_test.py | 24 +++++++++++++++++-------
 3 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 654b48a..d379428 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,17 @@
 #
-# This file is autogenerated by pip-compile
-# To update, run:
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
 #
 #    pip-compile requirements.in
 #
 appdirs==1.4.4
     # via black
-attrs==20.2.0
-    # via pytest
 black==20.8b1
     # via -r requirements.in
 click==7.1.2
     # via black
+exceptiongroup==1.2.0
+    # via pytest
 iniconfig==1.0.1
     # via pytest
 mypy-extensions==0.4.3
     # via black
 packaging==20.4
     # via pytest
 pathspec==0.8.0
     # via black
 pluggy==0.13.1
     # via pytest
-py4j==0.10.7
+py4j==0.10.9.7
     # via pyspark
-py==1.10.0
-    # via pytest
 pyparsing==2.4.7
     # via packaging
-pyspark==2.4.7
-    # via -r requirements.in
-pytest-testdox==2.0.1
+pyspark==3.4.0
     # via -r requirements.in
-pytest==6.1.1
+pytest==7.3.1
     # via
     #   -r requirements.in
     #   pytest-testdox
+pytest-testdox==2.0.1
+    # via -r requirements.in
 regex==2020.10.28
     # via black
 six==1.15.0
     # via packaging
 toml==0.10.1
-    # via
-    #   black
-    #   pytest
+    # via black
+tomli==2.0.1
+    # via pytest
 typed-ast==1.4.1
     # via black
 typing-extensions==3.7.4.3
diff --git a/setup.py b/setup.py
index c5d267a..bc579bf 100644
--- a/setup.py
+++ b/setup.py
@@ -19,5 +19,5 @@ def readme():
     license="Apache Software License (Apache 2.0)",
     py_modules=["pyspark_test"],
     package_dir={"": "src"},
-    install_requires=["pyspark>=3.4.0"],
+    install_requires=["pyspark>=2.1.2"],
 )
diff --git a/src/pyspark_test.py b/src/pyspark_test.py
index dd3f680..dbcb56c 100644
--- a/src/pyspark_test.py
+++ b/src/pyspark_test.py
@@ -4,6 +4,7 @@

 try:
     from pyspark.sql.connect.dataframe import DataFrame as CDF
+
     has_connect_deps = True
 except ImportError:
     has_connect_deps = False
@@ -11,19 +12,27 @@

 def _check_isinstance_df(left: Any, right: Any):
     types_to_test = [pyspark.sql.DataFrame]
+    msg_string = ""
+    # If Spark Connect dependencies are not available, the input is not going to be a Spark Connect
+    # DataFrame so we can safely skip the validation.
     if has_connect_deps:
         types_to_test.append(CDF)
+        msg_string = f" or {CDF}"

     left_good = any(map(lambda x: isinstance(left, x), types_to_test))
     right_good = any(map(lambda x: isinstance(right, x), types_to_test))
-    assert left_good, \
-        f"Left expected type {pyspark.sql.DataFrame} or {CDF}, found {type(left)} instead"
-    assert right_good, \
-        f"Right expected type {pyspark.sql.DataFrame} or {CDF}, found {type(right)} instead"
+    assert (
+        left_good
+    ), f"Left expected type {pyspark.sql.DataFrame}{msg_string}, found {type(left)} instead"
+    assert (
+        right_good
+    ), f"Right expected type {pyspark.sql.DataFrame}{msg_string}, found {type(right)} instead"

     # Check that both sides are of the same DataFrame type.
-    assert type(left) == type(right), \
-        f"Left and right DataFrames are not of the same type: {type(left)} != {type(right)}"
+    assert type(left) == type(
+        right
+    ), f"Left and right DataFrames are not of the same type: {type(left)} != {type(right)}"
+

 def _check_columns(
     check_columns_in_order: bool,
@@ -52,7 +61,8 @@ def _check_schema(


 def _check_df_content(
-    left_df: pyspark.sql.DataFrame, right_df: pyspark.sql.DataFrame,
+    left_df: pyspark.sql.DataFrame,
+    right_df: pyspark.sql.DataFrame,
 ):
     left_df_list = left_df.collect()
     right_df_list = right_df.collect()

From f3db060ed30b82b5c44b0b25245a4b223a7d0141 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Thu, 14 Dec 2023 08:07:35 -0800
Subject: [PATCH 3/4] docker

---
 Dockerfile | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index d8d7d5a..90a7e8a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,7 +5,7 @@ RUN apk --update add openjdk8-jre gcc musl-dev bash
 ENV JAVA_HOME /usr/

 # Hadoop
-ENV HADOOP_VERSION 2.7.2
+ENV HADOOP_VERSION 3.3.3
 ENV HADOOP_HOME /usr/hadoop-$HADOOP_VERSION
 ENV HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop
 ENV PATH $PATH:$HADOOP_HOME/bin
@@ -14,7 +14,7 @@ RUN wget "http://archive.apache.org/dist/hadoop/common/hadoop-$HADOOP_VERSION/ha
 	&& rm "hadoop-$HADOOP_VERSION.tar.gz"

 # Spark
-ENV SPARK_VERSION 2.4.8
+ENV SPARK_VERSION 3.3.3
 ENV SPARK_PACKAGE spark-$SPARK_VERSION
 ENV SPARK_HOME /usr/$SPARK_PACKAGE-bin-without-hadoop
 ENV PYSPARK_PYTHON python

From 5d9432bd6c6a3534057b39c964a1f35ce856ce8f Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Mon, 18 Dec 2023 13:12:36 +0100
Subject: [PATCH 4/4] deps

---
 requirements.in  |  2 +-
 requirements.txt | 39 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/requirements.in b/requirements.in
index a538f0e..aa6e850 100644
--- a/requirements.in
+++ b/requirements.in
@@ -1,4 +1,4 @@
 black==20.8b1
-pyspark==3.4.0
+pyspark[connect]==3.4.0
 pytest-testdox==2.0.1
 pytest==7.3.1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d379428..5829e86 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,32 +12,63 @@ click==7.1.2
     # via black
 exceptiongroup==1.2.0
     # via pytest
+googleapis-common-protos==1.62.0
+    # via
+    #   grpcio-status
+    #   pyspark
+grpcio==1.60.0
+    # via
+    #   grpcio-status
+    #   pyspark
+grpcio-status==1.60.0
+    # via pyspark
 iniconfig==1.0.1
     # via pytest
 mypy-extensions==0.4.3
     # via black
+numpy==1.26.2
+    # via
+    #   pandas
+    #   pyarrow
+    #   pyspark
 packaging==20.4
     # via pytest
+pandas==2.1.4
+    # via pyspark
 pathspec==0.8.0
     # via black
 pluggy==0.13.1
     # via pytest
+protobuf==4.25.1
+    # via
+    #   googleapis-common-protos
+    #   grpcio-status
 py4j==0.10.9.7
     # via pyspark
+pyarrow==14.0.1
+    # via pyspark
 pyparsing==2.4.7
     # via packaging
-pyspark==3.4.0
-    # via -r requirements.in
+pyspark[connect]==3.4.0
+    # via
+    #   -r requirements.in
+    #   pyspark
 pytest==7.3.1
     # via
     #   -r requirements.in
     #   pytest-testdox
 pytest-testdox==2.0.1
     # via -r requirements.in
+python-dateutil==2.8.2
+    # via pandas
+pytz==2023.3.post1
+    # via pandas
 regex==2020.10.28
     # via black
 six==1.15.0
-    # via packaging
+    # via
+    #   packaging
+    #   python-dateutil
 toml==0.10.1
     # via black
 tomli==2.0.1
     # via pytest
 typed-ast==1.4.1
     # via black
 typing-extensions==3.7.4.3
     # via black
+tzdata==2023.3
+    # via pandas
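
Usage sketch (not part of the patch series above): with these patches applied,
assert_pyspark_df_equal accepts Spark Connect DataFrames in addition to the
classic JVM-backed pyspark.sql.DataFrame. A minimal sketch, assuming
pyspark[connect]==3.4.0 is installed and a Spark Connect server is reachable
at sc://localhost:15002 (both assumptions, not taken from the patches; Spark
3.4 ships sbin/start-connect-server.sh to start one locally):

    from pyspark.sql import SparkSession

    from pyspark_test import assert_pyspark_df_equal

    # builder.remote() returns a Spark Connect session, so range() below
    # yields pyspark.sql.connect.dataframe.DataFrame objects, which the
    # patched _check_isinstance_df accepts alongside pyspark.sql.DataFrame.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    left = spark.range(3)
    right = spark.range(3)
    assert_pyspark_df_equal(left, right)  # same type, same content: passes

Because _check_isinstance_df also asserts type(left) == type(right), comparing
a Connect DataFrame against a classic one fails fast with an AssertionError
instead of collecting both sides and diffing rows.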