Skip to content

Commit f641eda

Browse files
committed
fix tests
Signed-off-by: Kai Huang <ahkcs@amazon.com>
1 parent 4c199f6 commit f641eda

File tree

3 files changed

+70
-76
lines changed

3 files changed

+70
-76
lines changed

integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteMultisearchCommandIT.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ public void testMultisearchWithDifferentIndicesSchemaMerge() throws IOException
210210
executeQuery(
211211
String.format(
212212
"| multisearch [search source=%s | where age > 35 | fields account_number,"
213-
+ " firstname, age, balance] [search source=%s | where age > 35 | fields"
214-
+ " account_number, balance, age] | stats count() as total_count",
213+
+ " firstname, balance] [search source=%s | where age > 35 | fields"
214+
+ " account_number, balance] | stats count() as total_count",
215215
TEST_INDEX_ACCOUNT, TEST_INDEX_BANK));
216216

217217
verifySchema(result, schema("total_count", null, "bigint"));
@@ -345,18 +345,23 @@ public void testMultisearchCrossIndexFieldSelection() throws IOException {
345345
}
346346

347347
@Test
348-
public void testMultisearchTypeConflictWithStats() throws IOException {
349-
JSONObject result =
350-
executeQuery(
351-
String.format(
352-
"| multisearch "
353-
+ "[search source=%s | fields age] "
354-
+ "[search source=%s | fields age] "
355-
+ "| stats count() as total",
356-
TEST_INDEX_ACCOUNT, TEST_INDEX_LOCATIONS_TYPE_CONFLICT));
357-
358-
verifySchema(result, schema("total", null, "bigint"));
348+
public void testMultisearchTypeConflictWithStats() {
349+
Exception exception =
350+
assertThrows(
351+
ResponseException.class,
352+
() ->
353+
executeQuery(
354+
String.format(
355+
"| multisearch "
356+
+ "[search source=%s | fields age] "
357+
+ "[search source=%s | fields age] "
358+
+ "| stats count() as total",
359+
TEST_INDEX_ACCOUNT, TEST_INDEX_LOCATIONS_TYPE_CONFLICT)));
359360

360-
verifyDataRows(result, rows(1010L));
361+
assertTrue(
362+
"Error message should indicate type conflict",
363+
exception
364+
.getMessage()
365+
.contains("Schema unification failed: field 'age' has conflicting types"));
361366
}
362367
}

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendTest.java

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.List;
1010
import org.apache.calcite.rel.RelNode;
1111
import org.apache.calcite.test.CalciteAssert;
12+
import org.junit.Assert;
1213
import org.junit.Test;
1314

1415
public class CalcitePPLAppendTest extends CalcitePPLAbstractTest {
@@ -71,15 +72,16 @@ public void testAppendEmptySearchCommand() {
7172
@Test
7273
public void testAppendNested() {
7374
String ppl =
74-
"source=EMP | append [ | where DEPTNO = 10 | append [ source=EMP | where DEPTNO = 20 ] ]";
75+
"source=EMP | fields ENAME, SAL | append [ | append [ source=EMP | where DEPTNO = 20 ] ]";
7576
RelNode root = getRelNode(ppl);
7677
String expectedLogical =
7778
"LogicalUnion(all=[true])\n"
78-
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
79-
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], EMPNO0=[null:SMALLINT])\n"
79+
+ " LogicalProject(ENAME=[$1], SAL=[$5], EMPNO=[null:SMALLINT], JOB=[null:VARCHAR(9)],"
80+
+ " MGR=[null:SMALLINT], HIREDATE=[null:DATE], COMM=[null:DECIMAL(7, 2)],"
81+
+ " DEPTNO=[null:TINYINT])\n"
8082
+ " LogicalTableScan(table=[[scott, EMP]])\n"
81-
+ " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[$1], JOB=[$2], MGR=[$3],"
82-
+ " HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], EMPNO0=[$0])\n"
83+
+ " LogicalProject(ENAME=[$1], SAL=[$5], EMPNO=[$0], JOB=[$2], MGR=[$3],"
84+
+ " HIREDATE=[$4], COMM=[$6], DEPTNO=[$7])\n"
8385
+ " LogicalUnion(all=[true])\n"
8486
+ " LogicalValues(tuples=[[]])\n"
8587
+ " LogicalFilter(condition=[=($7, 20)])\n"
@@ -88,12 +90,12 @@ public void testAppendNested() {
8890
verifyResultCount(root, 19); // 14 original table rows + 5 filtered subquery rows
8991

9092
String expectedSparkSql =
91-
"SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, CAST(NULL AS"
92-
+ " SMALLINT) `EMPNO0`\n"
93+
"SELECT `ENAME`, `SAL`, CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS STRING) `JOB`,"
94+
+ " CAST(NULL AS SMALLINT) `MGR`, CAST(NULL AS DATE) `HIREDATE`, CAST(NULL AS"
95+
+ " DECIMAL(7, 2)) `COMM`, CAST(NULL AS TINYINT) `DEPTNO`\n"
9396
+ "FROM `scott`.`EMP`\n"
9497
+ "UNION ALL\n"
95-
+ "SELECT CAST(NULL AS SMALLINT) `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`,"
96-
+ " `COMM`, `DEPTNO`, `EMPNO` `EMPNO0`\n"
98+
+ "SELECT `ENAME`, `SAL`, `EMPNO`, `JOB`, `MGR`, `HIREDATE`, `COMM`, `DEPTNO`\n"
9799
+ "FROM (SELECT *\n"
98100
+ "FROM (VALUES (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)) `t` (`EMPNO`,"
99101
+ " `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`)\n"
@@ -109,61 +111,63 @@ public void testAppendNested() {
109111
public void testAppendEmptySourceWithJoin() {
110112
List<String> emptySourceWithEmptySourceJoinPPLs =
111113
Arrays.asList(
112-
"source=EMP | append [ | where DEPTNO = 10 | join on ENAME = DNAME DEPT ]",
113-
"source=EMP | append [ | where DEPTNO = 10 | cross join on ENAME = DNAME DEPT ]",
114-
"source=EMP | append [ | where DEPTNO = 10 | left join on ENAME = DNAME DEPT ]",
115-
"source=EMP | append [ | where DEPTNO = 10 | semi join on ENAME = DNAME DEPT ]",
116-
"source=EMP | append [ | where DEPTNO = 10 | anti join on ENAME = DNAME DEPT ]");
114+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | join on ENAME"
115+
+ " = DNAME DEPT ]",
116+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | cross join on"
117+
+ " ENAME = DNAME DEPT ]",
118+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | left join on"
119+
+ " ENAME = DNAME DEPT ]",
120+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | semi join on"
121+
+ " ENAME = DNAME DEPT ]",
122+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | anti join on"
123+
+ " ENAME = DNAME DEPT ]");
117124

118125
for (String ppl : emptySourceWithEmptySourceJoinPPLs) {
119126
RelNode root = getRelNode(ppl);
120127
String expectedLogical =
121128
"LogicalUnion(all=[true])\n"
122-
+ " LogicalTableScan(table=[[scott, EMP]])\n"
129+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2])\n"
130+
+ " LogicalTableScan(table=[[scott, EMP]])\n"
123131
+ " LogicalValues(tuples=[[]])\n";
124132
verifyLogical(root, expectedLogical);
125133
verifyResultCount(root, 14);
126134

127135
String expectedSparkSql =
128-
"SELECT *\n"
136+
"SELECT `EMPNO`, `ENAME`, `JOB`\n"
129137
+ "FROM `scott`.`EMP`\n"
130138
+ "UNION ALL\n"
131139
+ "SELECT *\n"
132-
+ "FROM (VALUES (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)) `t` (`EMPNO`,"
133-
+ " `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`)\n"
140+
+ "FROM (VALUES (NULL, NULL, NULL)) `t` (`EMPNO`, `ENAME`, `JOB`)\n"
134141
+ "WHERE 1 = 0";
135142
verifyPPLToSparkSQL(root, expectedSparkSql);
136143
}
137144

138145
List<String> emptySourceWithRightOrFullJoinPPLs =
139146
Arrays.asList(
140-
"source=EMP | append [ | where DEPTNO = 10 | right join on ENAME = DNAME DEPT ]",
141-
"source=EMP | append [ | where DEPTNO = 10 | full join on ENAME = DNAME DEPT ]");
147+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | right join on"
148+
+ " ENAME = DNAME DEPT ]",
149+
"source=EMP | fields EMPNO, ENAME, JOB | append [ | where DEPTNO = 10 | full join on"
150+
+ " ENAME = DNAME DEPT ]");
142151

143152
for (String ppl : emptySourceWithRightOrFullJoinPPLs) {
144153
RelNode root = getRelNode(ppl);
145154
String expectedLogical =
146155
"LogicalUnion(all=[true])\n"
147-
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
148-
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPTNO0=[null:TINYINT],"
156+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], DEPTNO=[null:TINYINT],"
149157
+ " DNAME=[null:VARCHAR(14)], LOC=[null:VARCHAR(13)])\n"
150158
+ " LogicalTableScan(table=[[scott, EMP]])\n"
151159
+ " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)],"
152-
+ " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE],"
153-
+ " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[null:TINYINT],"
154-
+ " DEPTNO0=[$0], DNAME=[$1], LOC=[$2])\n"
160+
+ " JOB=[null:VARCHAR(9)], DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
155161
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
156162
verifyLogical(root, expectedLogical);
157163

158164
String expectedSparkSql =
159-
"SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, CAST(NULL AS"
160-
+ " TINYINT) `DEPTNO0`, CAST(NULL AS STRING) `DNAME`, CAST(NULL AS STRING) `LOC`\n"
165+
"SELECT `EMPNO`, `ENAME`, `JOB`, CAST(NULL AS TINYINT) `DEPTNO`, CAST(NULL AS STRING)"
166+
+ " `DNAME`, CAST(NULL AS STRING) `LOC`\n"
161167
+ "FROM `scott`.`EMP`\n"
162168
+ "UNION ALL\n"
163169
+ "SELECT CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS STRING) `ENAME`, CAST(NULL AS"
164-
+ " STRING) `JOB`, CAST(NULL AS SMALLINT) `MGR`, CAST(NULL AS DATE) `HIREDATE`,"
165-
+ " CAST(NULL AS DECIMAL(7, 2)) `SAL`, CAST(NULL AS DECIMAL(7, 2)) `COMM`, CAST(NULL"
166-
+ " AS TINYINT) `DEPTNO`, `DEPTNO` `DEPTNO0`, `DNAME`, `LOC`\n"
170+
+ " STRING) `JOB`, `DEPTNO`, `DNAME`, `LOC`\n"
167171
+ "FROM `scott`.`DEPT`";
168172
verifyPPLToSparkSQL(root, expectedSparkSql);
169173
}
@@ -172,27 +176,27 @@ public void testAppendEmptySourceWithJoin() {
172176
@Test
173177
public void testAppendDifferentIndex() {
174178
String ppl =
175-
"source=EMP | fields EMPNO, DEPTNO | append [ source=DEPT | fields DEPTNO, DNAME | where"
179+
"source=EMP | fields EMPNO, ENAME | append [ source=DEPT | fields DEPTNO, DNAME | where"
176180
+ " DEPTNO = 20 ]";
177181
RelNode root = getRelNode(ppl);
178182
String expectedLogical =
179183
"LogicalUnion(all=[true])\n"
180-
+ " LogicalProject(EMPNO=[$0], DEPTNO=[$7], DEPTNO0=[null:TINYINT],"
184+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[null:TINYINT],"
181185
+ " DNAME=[null:VARCHAR(14)])\n"
182186
+ " LogicalTableScan(table=[[scott, EMP]])\n"
183-
+ " LogicalProject(EMPNO=[null:SMALLINT], DEPTNO=[null:TINYINT], DEPTNO0=[$0],"
187+
+ " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)], DEPTNO=[$0],"
184188
+ " DNAME=[$1])\n"
185189
+ " LogicalFilter(condition=[=($0, 20)])\n"
186190
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1])\n"
187191
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
188192
verifyLogical(root, expectedLogical);
189193

190194
String expectedSparkSql =
191-
"SELECT `EMPNO`, `DEPTNO`, CAST(NULL AS TINYINT) `DEPTNO0`, CAST(NULL AS STRING) `DNAME`\n"
195+
"SELECT `EMPNO`, `ENAME`, CAST(NULL AS TINYINT) `DEPTNO`, CAST(NULL AS STRING) `DNAME`\n"
192196
+ "FROM `scott`.`EMP`\n"
193197
+ "UNION ALL\n"
194-
+ "SELECT CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS TINYINT) `DEPTNO`, `DEPTNO`"
195-
+ " `DEPTNO0`, `DNAME`\n"
198+
+ "SELECT CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS STRING) `ENAME`, `DEPTNO`,"
199+
+ " `DNAME`\n"
196200
+ "FROM (SELECT `DEPTNO`, `DNAME`\n"
197201
+ "FROM `scott`.`DEPT`) `t0`\n"
198202
+ "WHERE `DEPTNO` = 20";
@@ -227,22 +231,9 @@ public void testAppendWithMergedColumns() {
227231
public void testAppendWithConflictTypeColumn() {
228232
String ppl =
229233
"source=EMP | fields DEPTNO | append [ source=EMP | fields DEPTNO | eval DEPTNO = 20 ]";
230-
RelNode root = getRelNode(ppl);
231-
String expectedLogical =
232-
"LogicalUnion(all=[true])\n"
233-
+ " LogicalProject(DEPTNO=[$7], DEPTNO0=[null:INTEGER])\n"
234-
+ " LogicalTableScan(table=[[scott, EMP]])\n"
235-
+ " LogicalProject(DEPTNO=[null:TINYINT], DEPTNO0=[20])\n"
236-
+ " LogicalTableScan(table=[[scott, EMP]])\n";
237-
verifyLogical(root, expectedLogical);
238-
verifyResultCount(root, 28);
239-
240-
String expectedSparkSql =
241-
"SELECT `DEPTNO`, CAST(NULL AS INTEGER) `DEPTNO0`\n"
242-
+ "FROM `scott`.`EMP`\n"
243-
+ "UNION ALL\n"
244-
+ "SELECT CAST(NULL AS TINYINT) `DEPTNO`, 20 `DEPTNO0`\n"
245-
+ "FROM `scott`.`EMP`";
246-
verifyPPLToSparkSQL(root, expectedSparkSql);
234+
Exception exception =
235+
Assert.assertThrows(IllegalArgumentException.class, () -> getRelNode(ppl));
236+
verifyErrorMessageContains(
237+
exception, "Schema unification failed: field 'DEPTNO' has conflicting types");
247238
}
248239
}

ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLMultisearchTest.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,30 +142,28 @@ public void testMultisearchCrossIndices() {
142142
// Test multisearch with different tables (indices)
143143
String ppl =
144144
"| multisearch [search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME,"
145-
+ " DEPTNO] [search source=DEPT | where DEPTNO = 10 | fields DEPTNO, DNAME | eval EMPNO"
146-
+ " = DEPTNO, ENAME = DNAME]";
145+
+ " JOB] [search source=DEPT | where DEPTNO = 10 | fields DEPTNO, DNAME, LOC]";
147146
RelNode root = getRelNode(ppl);
148147
String expectedLogical =
149148
"LogicalUnion(all=[true])\n"
150-
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[$7], DEPTNO0=[null:TINYINT],"
151-
+ " DNAME=[null:VARCHAR(14)], EMPNO0=[null:TINYINT], ENAME0=[null:VARCHAR(14)])\n"
149+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], DEPTNO=[null:TINYINT],"
150+
+ " DNAME=[null:VARCHAR(14)], LOC=[null:VARCHAR(13)])\n"
152151
+ " LogicalFilter(condition=[=($7, 10)])\n"
153152
+ " LogicalTableScan(table=[[scott, EMP]])\n"
154153
+ " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)],"
155-
+ " DEPTNO=[null:TINYINT], DEPTNO0=[$0], DNAME=[$1], EMPNO0=[$0], ENAME0=[$1])\n"
154+
+ " JOB=[null:VARCHAR(9)], DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
156155
+ " LogicalFilter(condition=[=($0, 10)])\n"
157156
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
158157
verifyLogical(root, expectedLogical);
159158

160159
String expectedSparkSql =
161-
"SELECT `EMPNO`, `ENAME`, `DEPTNO`, CAST(NULL AS TINYINT) `DEPTNO0`, CAST(NULL AS STRING)"
162-
+ " `DNAME`, CAST(NULL AS TINYINT) `EMPNO0`, CAST(NULL AS STRING) `ENAME0`\n"
160+
"SELECT `EMPNO`, `ENAME`, `JOB`, CAST(NULL AS TINYINT) `DEPTNO`, CAST(NULL AS STRING)"
161+
+ " `DNAME`, CAST(NULL AS STRING) `LOC`\n"
163162
+ "FROM `scott`.`EMP`\n"
164163
+ "WHERE `DEPTNO` = 10\n"
165164
+ "UNION ALL\n"
166165
+ "SELECT CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS STRING) `ENAME`, CAST(NULL AS"
167-
+ " TINYINT) `DEPTNO`, `DEPTNO` `DEPTNO0`, `DNAME`, `DEPTNO` `EMPNO0`, `DNAME`"
168-
+ " `ENAME0`\n"
166+
+ " STRING) `JOB`, `DEPTNO`, `DNAME`, `LOC`\n"
169167
+ "FROM `scott`.`DEPT`\n"
170168
+ "WHERE `DEPTNO` = 10";
171169
verifyPPLToSparkSQL(root, expectedSparkSql);

0 commit comments

Comments (0)