Skip to content

Commit 757904d

Browse files
opensearch-trigger-bot[bot]github-actions[bot]aaarone90LantaoJin
authored
[Backport 2.19-dev] [SQL/PPL] Fix the count(*) and dc(field) to be capped at MAX_INTEGER #4416 (#4656)
* [SQL/PPL] Fix the `count(*)` and `dc(field)` to be capped at MAX_INTEGER #4416 (#4418) Co-authored-by: Aaron Alvarez <aaarone@amazon.com> (cherry picked from commit d7b2c35) Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Fix IT Signed-off-by: Lantao Jin <ltjin@amazon.com> * Fix IT Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: Lantao Jin <ltjin@amazon.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Aaron Alvarez <aaarone@amazon.com> Co-authored-by: Lantao Jin <ltjin@amazon.com>
1 parent f6eaef3 commit 757904d

File tree

16 files changed

+138
-125
lines changed

16 files changed

+138
-125
lines changed

core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunctions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ private static DefaultFunctionResolver count() {
9696
new FunctionSignature(functionName, Collections.singletonList(type)),
9797
type ->
9898
(functionProperties, arguments) ->
99-
new CountAggregator(arguments, INTEGER))));
99+
new CountAggregator(arguments, LONG))));
100100
return functionResolver;
101101
}
102102

core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public String toString() {
4444

4545
/** Count State. */
4646
protected static class CountState implements AggregationState {
47-
protected int count;
47+
protected long count;
4848

4949
CountState() {
5050
this.count = 0;
@@ -56,7 +56,7 @@ public void count(ExprValue value) {
5656

5757
@Override
5858
public ExprValue result() {
59-
return ExprValueUtils.integerValue(count);
59+
return ExprValueUtils.longValue(count);
6060
}
6161
}
6262

core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,7 @@ public void named_aggregator_with_condition() {
12831283
emptyList()),
12841284
DSL.named(
12851285
"count(string_value) filter(where integer_value > 1)",
1286-
DSL.ref("count(string_value) filter(where integer_value > 1)", INTEGER))),
1286+
DSL.ref("count(string_value) filter(where integer_value > 1)", LONG))),
12871287
AstDSL.project(
12881288
AstDSL.agg(
12891289
AstDSL.relation("schema"),

core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,43 +30,43 @@ class CountAggregatorTest extends AggregationTest {
3030
@Test
3131
public void count_integer_field_expression() {
3232
ExprValue result = aggregation(DSL.count(DSL.ref("integer_value", INTEGER)), tuples);
33-
assertEquals(4, result.value());
33+
assertEquals(4L, result.value());
3434
}
3535

3636
@Test
3737
public void count_long_field_expression() {
3838
ExprValue result = aggregation(DSL.count(DSL.ref("long_value", LONG)), tuples);
39-
assertEquals(4, result.value());
39+
assertEquals(4L, result.value());
4040
}
4141

4242
@Test
4343
public void count_float_field_expression() {
4444
ExprValue result = aggregation(DSL.count(DSL.ref("float_value", FLOAT)), tuples);
45-
assertEquals(4, result.value());
45+
assertEquals(4L, result.value());
4646
}
4747

4848
@Test
4949
public void count_double_field_expression() {
5050
ExprValue result = aggregation(DSL.count(DSL.ref("double_value", DOUBLE)), tuples);
51-
assertEquals(4, result.value());
51+
assertEquals(4L, result.value());
5252
}
5353

5454
@Test
5555
public void count_date_field_expression() {
5656
ExprValue result = aggregation(DSL.count(DSL.ref("date_value", DATE)), tuples);
57-
assertEquals(4, result.value());
57+
assertEquals(4L, result.value());
5858
}
5959

6060
@Test
6161
public void count_timestamp_field_expression() {
6262
ExprValue result = aggregation(DSL.count(DSL.ref("timestamp_value", TIMESTAMP)), tuples);
63-
assertEquals(4, result.value());
63+
assertEquals(4L, result.value());
6464
}
6565

6666
@Test
6767
public void count_datetime_field_expression() {
6868
ExprValue result = aggregation(DSL.count(DSL.ref("datetime_value", DATETIME)), tuples);
69-
assertEquals(4, result.value());
69+
assertEquals(4L, result.value());
7070
}
7171

7272
@Test
@@ -75,34 +75,33 @@ public void count_arithmetic_expression() {
7575
aggregation(
7676
DSL.count(
7777
DSL.multiply(
78-
DSL.ref("integer_value", INTEGER),
79-
DSL.literal(ExprValueUtils.integerValue(10)))),
78+
DSL.ref("long_value", LONG), DSL.literal(ExprValueUtils.longValue(10L)))),
8079
tuples);
81-
assertEquals(4, result.value());
80+
assertEquals(4L, result.value());
8281
}
8382

8483
@Test
8584
public void count_string_field_expression() {
8685
ExprValue result = aggregation(DSL.count(DSL.ref("string_value", STRING)), tuples);
87-
assertEquals(4, result.value());
86+
assertEquals(4L, result.value());
8887
}
8988

9089
@Test
9190
public void count_boolean_field_expression() {
9291
ExprValue result = aggregation(DSL.count(DSL.ref("boolean_value", BOOLEAN)), tuples);
93-
assertEquals(1, result.value());
92+
assertEquals(1L, result.value());
9493
}
9594

9695
@Test
9796
public void count_struct_field_expression() {
9897
ExprValue result = aggregation(DSL.count(DSL.ref("struct_value", STRUCT)), tuples);
99-
assertEquals(1, result.value());
98+
assertEquals(1L, result.value());
10099
}
101100

102101
@Test
103102
public void count_array_field_expression() {
104103
ExprValue result = aggregation(DSL.count(DSL.ref("array_value", ARRAY)), tuples);
105-
assertEquals(1, result.value());
104+
assertEquals(1L, result.value());
106105
}
107106

108107
@Test
@@ -112,14 +111,14 @@ public void filtered_count() {
112111
DSL.count(DSL.ref("integer_value", INTEGER))
113112
.condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
114113
tuples);
115-
assertEquals(3, result.value());
114+
assertEquals(3L, result.value());
116115
}
117116

118117
@Test
119118
public void distinct_count() {
120119
ExprValue result =
121120
aggregation(DSL.distinctCount(DSL.ref("integer_value", INTEGER)), tuples_with_duplicates);
122-
assertEquals(3, result.value());
121+
assertEquals(3L, result.value());
123122
}
124123

125124
@Test
@@ -129,47 +128,47 @@ public void filtered_distinct_count() {
129128
DSL.distinctCount(DSL.ref("integer_value", INTEGER))
130129
.condition(DSL.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
131130
tuples_with_duplicates);
132-
assertEquals(2, result.value());
131+
assertEquals(2L, result.value());
133132
}
134133

135134
@Test
136135
public void distinct_count_map() {
137136
ExprValue result =
138137
aggregation(DSL.distinctCount(DSL.ref("struct_value", STRUCT)), tuples_with_duplicates);
139-
assertEquals(3, result.value());
138+
assertEquals(3L, result.value());
140139
}
141140

142141
@Test
143142
public void distinct_count_array() {
144143
ExprValue result =
145144
aggregation(DSL.distinctCount(DSL.ref("array_value", ARRAY)), tuples_with_duplicates);
146-
assertEquals(3, result.value());
145+
assertEquals(3L, result.value());
147146
}
148147

149148
@Test
150149
public void count_with_missing() {
151150
ExprValue result =
152151
aggregation(DSL.count(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing);
153-
assertEquals(2, result.value());
152+
assertEquals(2L, result.value());
154153
}
155154

156155
@Test
157156
public void count_with_null() {
158157
ExprValue result =
159158
aggregation(DSL.count(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing);
160-
assertEquals(2, result.value());
159+
assertEquals(2L, result.value());
161160
}
162161

163162
@Test
164163
public void count_star_with_null_and_missing() {
165164
ExprValue result = aggregation(DSL.count(DSL.literal("*")), tuples_with_null_and_missing);
166-
assertEquals(3, result.value());
165+
assertEquals(3L, result.value());
167166
}
168167

169168
@Test
170169
public void count_literal_with_null_and_missing() {
171170
ExprValue result = aggregation(DSL.count(DSL.literal(1)), tuples_with_null_and_missing);
172-
assertEquals(3, result.value());
171+
assertEquals(3L, result.value());
173172
}
174173

175174
@Test

integ-test/src/test/java/org/opensearch/sql/bwc/SQLBackwardsCompatibilityIT.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,21 @@ private void verifySQLQueries(String endpoint) throws IOException {
175175
executeSQLQuery(
176176
endpoint,
177177
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TestsConstants.TEST_INDEX_ACCOUNT);
178-
verifySchema(filterResponse, schema("COUNT(*) FILTER(WHERE age > 35)", null, "integer"));
178+
// Accept both integer and long types for backwards compatibility
179+
String actualType =
180+
(String) filterResponse.getJSONArray("schema").getJSONObject(0).query("/type");
181+
String expectedType = actualType.equals("integer") ? "integer" : "long";
182+
verifySchema(filterResponse, schema("COUNT(*) FILTER(WHERE age > 35)", null, expectedType));
179183
verifyDataRows(filterResponse, rows(238));
180184

181185
JSONObject aggResponse =
182186
executeSQLQuery(
183187
endpoint, "SELECT COUNT(DISTINCT age) FROM " + TestsConstants.TEST_INDEX_ACCOUNT);
184-
verifySchema(aggResponse, schema("COUNT(DISTINCT age)", null, "integer"));
188+
// Accept both integer and long types for backwards compatibility
189+
String actualType2 =
190+
(String) aggResponse.getJSONArray("schema").getJSONObject(0).query("/type");
191+
String expectedType2 = actualType2.equals("integer") ? "integer" : "long";
192+
verifySchema(aggResponse, schema("COUNT(DISTINCT age)", null, expectedType2));
185193
verifyDataRows(aggResponse, rows(21));
186194

187195
JSONObject groupByResponse =

integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/DBResult.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ public class DBResult {
3434
/** Possible types for varchar. H2 2.x use CHARACTER VARYING instead of VARCHAR. */
3535
private static final Set<String> VARCHAR = ImmutableSet.of("CHARACTER VARYING", "VARCHAR");
3636

37+
/**
38+
* Possible types for integer numbers.<br>
39+
* Different databases may return INTEGER or BIGINT for count operations.
40+
*/
41+
private static final Set<String> INTEGER_TYPES = ImmutableSet.of("INTEGER", "BIGINT");
42+
3743
/** Database name for display */
3844
private final String databaseName;
3945

@@ -74,6 +80,8 @@ public void addColumn(String name, String type) {
7480
type = FLOAT_TYPES.toString();
7581
} else if (VARCHAR.contains(type)) {
7682
type = "VARCHAR";
83+
} else if (INTEGER_TYPES.contains(type)) {
84+
type = INTEGER_TYPES.toString();
7785
}
7886
schema.add(new Type(StringUtils.toUpper(name), type));
7987
}

integ-test/src/test/java/org/opensearch/sql/correctness/runner/resultset/Row.java

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,19 @@
1010
import java.util.ArrayList;
1111
import java.util.Collection;
1212
import java.util.List;
13-
import lombok.EqualsAndHashCode;
13+
import java.util.Objects;
1414
import lombok.Getter;
1515
import lombok.ToString;
1616

1717
/** Row in result set. */
18-
@EqualsAndHashCode
1918
@ToString
2019
@Getter
2120
public class Row implements Comparable<Row> {
2221

2322
private final Collection<Object> values;
2423

2524
public Row() {
26-
this(new ArrayList<>()); // values in order by default
25+
this(new ArrayList<>());
2726
}
2827

2928
public Row(Collection<Object> values) {
@@ -37,7 +36,7 @@ public void add(Object value) {
3736
private Object roundFloatNum(Object value) {
3837
if (value instanceof Float) {
3938
BigDecimal decimal = BigDecimal.valueOf((Float) value).setScale(2, RoundingMode.CEILING);
40-
value = decimal.doubleValue(); // Convert to double too
39+
value = decimal.doubleValue();
4140
} else if (value instanceof Double) {
4241
BigDecimal decimal = BigDecimal.valueOf((Double) value).setScale(2, RoundingMode.CEILING);
4342
value = decimal.doubleValue();
@@ -70,8 +69,54 @@ public int compareTo(Row other) {
7069
if (result != 0) {
7170
return result;
7271
}
73-
} // Ignore incomparable field silently?
72+
}
7473
}
7574
return 0;
7675
}
76+
77+
@Override
78+
public boolean equals(Object o) {
79+
if (this == o) return true;
80+
if (!(o instanceof Row)) return false;
81+
Row other = (Row) o;
82+
return valuesEqual(this.values, other.values);
83+
}
84+
85+
private boolean valuesEqual(Collection<Object> values1, Collection<Object> values2) {
86+
if (values1.size() != values2.size()) return false;
87+
88+
List<Object> list1 = new ArrayList<>(values1);
89+
List<Object> list2 = new ArrayList<>(values2);
90+
91+
for (int i = 0; i < list1.size(); i++) {
92+
if (!isValueEqual(list1.get(i), list2.get(i))) {
93+
return false;
94+
}
95+
}
96+
return true;
97+
}
98+
99+
private boolean isValueEqual(Object val1, Object val2) {
100+
if (Objects.equals(val1, val2)) return true;
101+
102+
if (isIntegerOrLong(val1) && isIntegerOrLong(val2)) {
103+
return ((Number) val1).longValue() == ((Number) val2).longValue();
104+
}
105+
106+
return false;
107+
}
108+
109+
private boolean isIntegerOrLong(Object value) {
110+
return value instanceof Integer || value instanceof Long;
111+
}
112+
113+
@Override
114+
public int hashCode() {
115+
116+
List<Object> normalizedValues = new ArrayList<>();
117+
for (Object value : values) {
118+
normalizedValues.add(value instanceof Integer ? ((Integer) value).longValue() : value);
119+
}
120+
return normalizedValues.hashCode();
121+
}
77122
}

integ-test/src/test/java/org/opensearch/sql/legacy/AggregationExpressionIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ public void groupByDateShouldPass() {
204204
Index.BANK.getName()));
205205

206206
verifySchema(
207-
response, schema("birthdate", null, "timestamp"), schema("count(*)", "count", "integer"));
207+
response, schema("birthdate", null, "timestamp"), schema("count(*)", "count", "long"));
208208
verifyDataRows(response, rows("2018-06-23 00:00:00", 1));
209209
}
210210

@@ -220,9 +220,7 @@ public void groupByDateWithAliasShouldPass() {
220220
Index.BANK.getName()));
221221

222222
verifySchema(
223-
response,
224-
schema("birthdate", "birth", "timestamp"),
225-
schema("count(*)", "count", "integer"));
223+
response, schema("birthdate", "birth", "timestamp"), schema("count(*)", "count", "long"));
226224
verifyDataRows(response, rows("2018-06-23 00:00:00", 1));
227225
}
228226

0 commit comments

Comments
 (0)