2000 | 2000 | "SquareNumbers(lit(1), lit(3)).show()" |
2001 | 2001 | ] |
2002 | 2002 | }, |
| 2003 | + { |
| 2004 | + "cell_type": "markdown", |
| 2005 | + "id": "216e9fc0-12a9-4f45-85b1-8e791755b1d3", |
| 2006 | + "metadata": {}, |
| 2007 | + "source": [ |
| 2008 | + "### Best Practices for PySpark DataFrame Comparison Testing" |
| 2009 | + ] |
| 2010 | + }, |
| 2011 | + { |
| 2012 | + "cell_type": "code", |
| 2013 | + "execution_count": null, |
| 2014 | + "id": "9badb0ee-16ec-4291-9477-8a38ebd7e876", |
| 2015 | + "metadata": { |
| 2016 | + "editable": true, |
| 2017 | + "slideshow": { |
| 2018 | + "slide_type": "" |
| 2019 | + }, |
| 2020 | + "tags": [ |
| 2021 | + "hide-cell" |
| 2022 | + ] |
| 2023 | + }, |
| 2024 | + "outputs": [], |
| 2025 | + "source": [ |
| 2026 | + "!pip install \"pyspark[sql]\"" |
| 2027 | + ] |
| 2028 | + }, |
| 2029 | + { |
| 2030 | + "cell_type": "code", |
| 2031 | + "execution_count": 19, |
| 2032 | + "id": "d2adcd65-5197-404f-88d6-c368a863cf75", |
| 2033 | + "metadata": { |
| 2034 | + "editable": true, |
| 2035 | + "slideshow": { |
| 2036 | + "slide_type": "" |
| 2037 | + }, |
| 2038 | + "tags": [ |
| 2039 | + "hide-cell" |
| 2040 | + ] |
| 2041 | + }, |
| 2042 | + "outputs": [], |
| 2043 | + "source": [ |
| 2044 | + "from pyspark.sql import SparkSession\n", |
| 2045 | + "\n", |
| 2046 | + "# Create SparkSession\n", |
| 2047 | + "spark = SparkSession.builder.getOrCreate()" |
| 2048 | + ] |
| 2049 | + }, |
| 2050 | + { |
| 2051 | + "cell_type": "markdown", |
| 2052 | + "id": "1002536a", |
| 2053 | + "metadata": {}, |
| 2054 | + "source": [ |
| 2055 | + "Manually comparing PySpark DataFrame outputs using `collect()` and equality comparison leads to brittle tests due to ordering issues and unclear error messages when data doesn't match expectations.\n", |
| 2056 | + "\n", |
| 2057 | + "For example, the following test will fail due to ordering issues, resulting in an unclear error message.\n" |
| 2058 | + ] |
| 2059 | + }, |
| 2060 | + { |
| 2061 | + "cell_type": "code", |
| 2062 | + "execution_count": 31, |
| 2063 | + "id": "e4299f30", |
| 2064 | + "metadata": {}, |
| 2065 | + "outputs": [ |
| 2066 | + { |
| 2067 | + "name": "stderr", |
| 2068 | + "output_type": "stream", |
| 2069 | + "text": [ |
| 2070 | + " \r" |
| 2071 | + ] |
| 2072 | + }, |
| 2073 | + { |
| 2074 | + "name": "stdout", |
| 2075 | + "output_type": "stream", |
| 2076 | + "text": [ |
| 2077 | + "assert [Row(id=1, name='Alice', value=100), Row(id=2, name='Bob', value=200)] == [Row(id=2, name='Bob', value=200), Row(id=1, name='Alice', value=100)]\n", |
| 2078 | + " + where [Row(id=1, name='Alice', value=100), Row(id=2, name='Bob', value=200)] = <bound method DataFrame.collect of DataFrame[id: bigint, name: string, value: bigint]>()\n", |
| 2079 | + " + where <bound method DataFrame.collect of DataFrame[id: bigint, name: string, value: bigint]> = DataFrame[id: bigint, name: string, value: bigint].collect\n", |
| 2080 | + " + and [Row(id=2, name='Bob', value=200), Row(id=1, name='Alice', value=100)] = <bound method DataFrame.collect of DataFrame[id: bigint, name: string, value: bigint]>()\n", |
| 2081 | + " + where <bound method DataFrame.collect of DataFrame[id: bigint, name: string, value: bigint]> = DataFrame[id: bigint, name: string, value: bigint].collect\n" |
| 2082 | + ] |
| 2083 | + } |
| 2084 | + ], |
| 2085 | + "source": [ |
| 2086 | + "# Manual DataFrame comparison\n", |
| 2087 | + "result_df = spark.createDataFrame(\n", |
| 2088 | + " [(1, \"Alice\", 100), (2, \"Bob\", 200)], [\"id\", \"name\", \"value\"]\n", |
| 2089 | + ")\n", |
| 2090 | + "\n", |
| 2091 | + "expected_df = spark.createDataFrame(\n", |
| 2092 | + " [(2, \"Bob\", 200), (1, \"Alice\", 100)], [\"id\", \"name\", \"value\"]\n", |
| 2093 | + ")\n", |
| 2094 | + "\n", |
| 2095 | + "try:\n", |
| 2096 | + " assert result_df.collect() == expected_df.collect()\n", |
| 2097 | + "except AssertionError as e:\n", |
| 2098 | + " print(e)" |
| 2099 | + ] |
| 2100 | + }, |
| 2101 | + { |
| 2102 | + "cell_type": "markdown", |
| 2103 | + "id": "7c4f8fd8-c2c2-4804-8e42-6fd3eb6aec27", |
| 2104 | + "metadata": {}, |
| 2105 | + "source": [ |
| 2106 | + "`assertDataFrameEqual` provides a robust way to compare DataFrames, allowing for order-independent comparison.\n" |
| 2107 | + ] |
| 2108 | + }, |
| 2109 | + { |
| 2110 | + "cell_type": "code", |
| 2111 | + "execution_count": null, |
| 2112 | + "id": "73b0d483-8b00-44ab-9279-4c7765ca1ff6", |
| 2113 | + "metadata": {}, |
| 2114 | + "outputs": [], |
| 2115 | + "source": [ |
| 2116 | + "# Testing with DataFrame equality\n", |
| 2117 | + "from pyspark.testing.utils import assertDataFrameEqual" |
| 2118 | + ] |
| 2119 | + }, |
| 2120 | + { |
| 2121 | + "cell_type": "code", |
| 2122 | + "execution_count": 7, |
| 2123 | + "id": "7c46ae8a", |
| 2124 | + "metadata": {}, |
| 2125 | + "outputs": [], |
| 2126 | + "source": [ |
| 2127 | + "assertDataFrameEqual(result_df, expected_df)" |
| 2128 | + ] |
| 2129 | + }, |
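| | + {
| | + "cell_type": "markdown",
| | + "id": "added-checkroworder-note",
| | + "metadata": {},
| | + "source": [
| | + "A minimal sketch, assuming PySpark 3.5 or later: `assertDataFrameEqual` also accepts a `checkRowOrder` flag to enforce row order, plus `rtol`/`atol` tolerances for approximate numeric comparison. With `checkRowOrder=True`, the two DataFrames above no longer compare equal."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "id": "added-checkroworder-demo",
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "# Sketch (assumes PySpark 3.5+): enforce row order explicitly\n",
| | + "try:\n",
| | + "    assertDataFrameEqual(result_df, expected_df, checkRowOrder=True)\n",
| | + "except AssertionError as e:\n",
| | + "    print(e)  # rows match, but their order differs"
| | + ]
| | + },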
| 2130 | + { |
| 2131 | + "cell_type": "markdown", |
| 2132 | + "id": "085f150d-20ff-4b0a-a4ab-1ee452598e9e", |
| 2133 | + "metadata": {}, |
| 2134 | + "source": [ |
| 2135 | + "Using `collect()` for comparison cannot detect type mismatch, whereas `assertDataFrameEqual` can.\n", |
| 2136 | + "\n", |
| 2137 | + "For example, the following test will pass, even though there is a type mismatch.\n" |
| 2138 | + ] |
| 2139 | + }, |
| 2140 | + { |
| 2141 | + "cell_type": "code", |
| 2142 | + "execution_count": 27, |
| 2143 | + "id": "da7494c0-c05f-4a2f-a411-805c8f2f73ba", |
| 2144 | + "metadata": {}, |
| 2145 | + "outputs": [], |
| 2146 | + "source": [ |
| 2147 | + "# Manual DataFrame comparison\n", |
| 2148 | + "result_df = spark.createDataFrame(\n", |
| 2149 | + " [(1, \"Alice\", 100), (2, \"Bob\", 200)], [\"id\", \"name\", \"value\"]\n", |
| 2150 | + ")\n", |
| 2151 | + "\n", |
| 2152 | + "expected_df = spark.createDataFrame(\n", |
| 2153 | + " [(1, \"Alice\", 100.0), (2, \"Bob\", 200.0)], [\"id\", \"name\", \"value\"]\n", |
| 2154 | + ")\n", |
| 2155 | + "\n", |
| 2156 | + "assert result_df.collect() == expected_df.collect()" |
| 2157 | + ] |
| 2158 | + }, |
| 2159 | + { |
| 2160 | + "cell_type": "markdown", |
| 2161 | + "id": "82914b3b-69c0-4c68-9d72-d2ce31417397", |
| 2162 | + "metadata": {}, |
| 2163 | + "source": [ |
| 2164 | + "The error message produced by `assertDataFrameEqual` is clear and informative, highlighting the difference in schemas." |
| 2165 | + ] |
| 2166 | + }, |
| 2167 | + { |
| 2168 | + "cell_type": "code", |
| 2169 | + "execution_count": 30, |
| 2170 | + "id": "3faa1dbc-887a-4c36-ace8-c621411c3fb7", |
| 2171 | + "metadata": {}, |
| 2172 | + "outputs": [ |
| 2173 | + { |
| 2174 | + "name": "stdout", |
| 2175 | + "output_type": "stream", |
| 2176 | + "text": [ |
| 2177 | + "[DIFFERENT_SCHEMA] Schemas do not match.\n", |
| 2178 | + "--- actual\n", |
| 2179 | + "+++ expected\n", |
| 2180 | + "- StructType([StructField('id', LongType(), True), StructField('name', StringType(), True), StructField('value', LongType(), True)])\n", |
| 2181 | + "? ^ ^^\n", |
| 2182 | + "\n", |
| 2183 | + "+ StructType([StructField('id', LongType(), True), StructField('name', StringType(), True), StructField('value', DoubleType(), True)])\n", |
| 2184 | + "? ^ ^^^^\n", |
| 2185 | + "\n" |
| 2186 | + ] |
| 2187 | + } |
| 2188 | + ], |
| 2189 | + "source": [ |
| 2190 | + "try:\n", |
| 2191 | + " assertDataFrameEqual(result_df, expected_df)\n", |
| 2192 | + "except AssertionError as e:\n", |
| 2193 | + " print(e)" |
| 2194 | + ] |
| 2195 | + }, |
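| | + {
| | + "cell_type": "markdown",
| | + "id": "added-assertschemaequal-note",
| | + "metadata": {},
| | + "source": [
| | + "A minimal sketch, assuming PySpark 3.5 or later: when only the schema matters, `pyspark.testing.utils` also exposes `assertSchemaEqual`, which compares two `StructType` objects directly and reports the same kind of readable diff."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "id": "added-assertschemaequal-demo",
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "# Sketch (assumes PySpark 3.5+): compare only the schemas\n",
| | + "from pyspark.testing.utils import assertSchemaEqual\n",
| | + "\n",
| | + "try:\n",
| | + "    assertSchemaEqual(result_df.schema, expected_df.schema)\n",
| | + "except AssertionError as e:\n",
| | + "    print(e)  # value column: LongType vs DoubleType"
| | + ]
| | + },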
2003 | 2196 | { |
2004 | 2197 | "cell_type": "markdown", |
2005 | 2198 | "id": "9da7e800", |