@@ -175,7 +175,8 @@ private Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
175175 }
176176 break ;
177177 case JOIN :
178- return dealJoinNode ((SqlJoin ) sqlNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
178+ Set <Tuple2 <String , String >> joinFieldSet = Sets .newHashSet ();
179+ return dealJoinNode ((SqlJoin ) sqlNode , sideTableSet , queueInfo , parentWhere , parentSelectList , joinFieldSet );
179180 case AS :
180181 SqlNode info = ((SqlBasicCall )sqlNode ).getOperands ()[0 ];
181182 SqlNode alias = ((SqlBasicCall ) sqlNode ).getOperands ()[1 ];
@@ -248,7 +249,7 @@ private SqlBasicCall buildAsSqlNode(String internalTableName, SqlNode newSource)
248249 * @return
249250 */
250251 private JoinInfo dealJoinNode (SqlJoin joinNode , Set <String > sideTableSet , Queue <Object > queueInfo ,
251- SqlNode parentWhere , SqlNodeList parentSelectList ) {
252+ SqlNode parentWhere , SqlNodeList parentSelectList , Set < Tuple2 < String , String >> joinFieldSet ) {
252253 SqlNode leftNode = joinNode .getLeft ();
253254 SqlNode rightNode = joinNode .getRight ();
254255 JoinType joinType = joinNode .getJoinType ();
@@ -261,12 +262,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
261262
262263 //如果是连续join 判断是否已经处理过添加到执行队列
263264 Boolean alreadyOffer = false ;
265+ extractJoinField (joinNode .getCondition (), joinFieldSet );
264266
265267 if (leftNode .getKind () == IDENTIFIER ){
266268 leftTbName = leftNode .toString ();
267269 } else if (leftNode .getKind () == JOIN ) {
268270 //处理连续join
269- Tuple2 <Boolean , SqlBasicCall > nestJoinResult = dealNestJoin ((SqlJoin ) leftNode , sideTableSet , queueInfo , parentWhere , parentSelectList );
271+ Tuple2 <Boolean , SqlBasicCall > nestJoinResult = dealNestJoin ((SqlJoin ) leftNode , sideTableSet ,
272+ queueInfo , parentWhere , parentSelectList , joinFieldSet );
270273 alreadyOffer = nestJoinResult .f0 ;
271274 leftTbName = nestJoinResult .f1 .getOperands ()[0 ].toString ();
272275 leftTbAlias = nestJoinResult .f1 .getOperands ()[1 ].toString ();
@@ -320,7 +323,8 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
320323 }
321324
322325 if (tableInfo .getLeftNode ().getKind () != AS ){
323- extractTemporaryQuery (tableInfo .getLeftNode (), tableInfo .getLeftTableAlias (), (SqlBasicCall ) parentWhere , parentSelectList , queueInfo );
326+ extractTemporaryQuery (tableInfo .getLeftNode (), tableInfo .getLeftTableAlias (), (SqlBasicCall ) parentWhere ,
327+ parentSelectList , queueInfo , joinFieldSet );
324328 }else {
325329 SqlKind asNodeFirstKind = ((SqlBasicCall )tableInfo .getLeftNode ()).operands [0 ].getKind ();
326330 if (asNodeFirstKind == SELECT ){
@@ -331,11 +335,14 @@ private JoinInfo dealJoinNode(SqlJoin joinNode, Set<String> sideTableSet, Queue<
331335 return tableInfo ;
332336 }
333337
338+
334339 //构建新的查询
335- private Tuple2 <Boolean , SqlBasicCall > dealNestJoin (SqlJoin joinNode , Set <String > sideTableSet , Queue <Object > queueInfo , SqlNode parentWhere , SqlNodeList selectList ){
340+ private Tuple2 <Boolean , SqlBasicCall > dealNestJoin (SqlJoin joinNode , Set <String > sideTableSet ,
341+ Queue <Object > queueInfo , SqlNode parentWhere ,
342+ SqlNodeList selectList , Set <Tuple2 <String , String >> joinFieldSet ){
336343 SqlNode rightNode = joinNode .getRight ();
337344 Tuple2 <String , String > rightTableNameAndAlias = parseRightNode (rightNode , sideTableSet , queueInfo , parentWhere , selectList );
338- JoinInfo joinInfo = dealJoinNode (joinNode , sideTableSet , queueInfo , parentWhere , selectList );
345+ JoinInfo joinInfo = dealJoinNode (joinNode , sideTableSet , queueInfo , parentWhere , selectList , joinFieldSet );
339346
340347 String rightTableName = rightTableNameAndAlias .f0 ;
341348 boolean rightIsSide = checkIsSideTable (rightTableName , sideTableSet );
@@ -352,23 +359,23 @@ private Tuple2<Boolean, SqlBasicCall> dealNestJoin(SqlJoin joinNode, Set<String>
352359 return Tuple2 .of (alreadyOffer , TableUtils .buildAsNodeByJoinInfo (joinInfo , null , null ));
353360 }
354361
355- public boolean checkAndRemoveCondition (Set <String > fromTableNameSet , SqlBasicCall parentWhere , List <SqlBasicCall > extractContition ){
362+ public boolean checkAndRemoveCondition (Set <String > fromTableNameSet , SqlBasicCall parentWhere , List <SqlBasicCall > extractCondition ){
356363
357364 if (parentWhere == null ){
358365 return false ;
359366 }
360367
361368 SqlKind kind = parentWhere .getKind ();
362369 if (kind == AND ){
363- boolean removeLeft = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[0 ], extractContition );
364- boolean removeRight = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[1 ], extractContition );
370+ boolean removeLeft = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[0 ], extractCondition );
371+ boolean removeRight = checkAndRemoveCondition (fromTableNameSet , (SqlBasicCall ) parentWhere .getOperands ()[1 ], extractCondition );
365372 //DO remove
366373 if (removeLeft ){
367- extractContition .add (removeWhereConditionNode (parentWhere , 0 ));
374+ extractCondition .add (removeWhereConditionNode (parentWhere , 0 ));
368375 }
369376
370377 if (removeRight ){
371- extractContition .add (removeWhereConditionNode (parentWhere , 1 ));
378+ extractCondition .add (removeWhereConditionNode (parentWhere , 1 ));
372379 }
373380
374381 return false ;
@@ -385,7 +392,8 @@ public boolean checkAndRemoveCondition(Set<String> fromTableNameSet, SqlBasicCal
385392 }
386393
387394 private void extractTemporaryQuery (SqlNode node , String tableAlias , SqlBasicCall parentWhere ,
388- SqlNodeList parentSelectList , Queue <Object > queueInfo ){
395+ SqlNodeList parentSelectList , Queue <Object > queueInfo ,
396+ Set <Tuple2 <String , String >> joinFieldSet ){
389397 try {
390398 //父一级的where 条件中如果只和临时查询相关的条件都截取进来
391399 Set <String > fromTableNameSet = Sets .newHashSet ();
@@ -394,8 +402,9 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
394402 getFromTableInfo (node , fromTableNameSet );
395403 checkAndRemoveCondition (fromTableNameSet , parentWhere , extractCondition );
396404
397- List <String > extractSelectField = extractSelectList (parentSelectList , fromTableNameSet );
398- String extractSelectFieldStr = buildSelectNode (extractSelectField );
405+ Set <String > extractSelectField = extractSelectFields (parentSelectList , fromTableNameSet );
406+ Set <String > fieldFromJoinCondition = extractSelectFieldFromJoinCondition (joinFieldSet , fromTableNameSet );
407+ String extractSelectFieldStr = buildSelectNode (extractSelectField , fieldFromJoinCondition );
399408 String extractConditionStr = buildCondition (extractCondition );
400409
401410 String tmpSelectSql = String .format (SELECT_TEMP_SQL ,
@@ -425,19 +434,50 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall
425434 * @param fromTableNameSet
426435 * @return
427436 */
428- private List <String > extractSelectList (SqlNodeList parentSelectList , Set <String > fromTableNameSet ){
429- List <String > extractFieldList = Lists . newArrayList ();
437+ private Set <String > extractSelectFields (SqlNodeList parentSelectList , Set <String > fromTableNameSet ){
438+ Set <String > extractFieldList = Sets . newHashSet ();
430439 for (SqlNode selectNode : parentSelectList .getList ()){
431440 extractSelectField (selectNode , extractFieldList , fromTableNameSet );
432441 }
433442
434443 return extractFieldList ;
435444 }
436445
437- private void extractSelectField (SqlNode selectNode , List <String > extractFieldList , Set <String > fromTableNameSet ){
446+ private Set <String > extractSelectFieldFromJoinCondition (Set <Tuple2 <String , String >> joinFieldSet , Set <String > fromTableNameSet ){
447+ Set <String > extractFieldList = Sets .newHashSet ();
448+ for (Tuple2 <String , String > field : joinFieldSet ){
449+ if (fromTableNameSet .contains (field .f0 )){
450+ extractFieldList .add (field .f0 + "." + field .f1 );
451+ }
452+ }
453+
454+ return extractFieldList ;
455+ }
456+
457+ /**
458+ * 从join的条件中获取字段信息
459+ * @param condition
460+ * @param joinFieldSet
461+ */
462+ private void extractJoinField (SqlNode condition , Set <Tuple2 <String , String >> joinFieldSet ){
463+ SqlKind joinKind = condition .getKind ();
464+ if ( joinKind == AND ){
465+ extractJoinField (((SqlBasicCall )condition ).operands [0 ], joinFieldSet );
466+ extractJoinField (((SqlBasicCall )condition ).operands [1 ], joinFieldSet );
467+ }else if ( joinKind == EQUALS ){
468+ extractJoinField (((SqlBasicCall )condition ).operands [0 ], joinFieldSet );
469+ extractJoinField (((SqlBasicCall )condition ).operands [1 ], joinFieldSet );
470+ }else {
471+ Preconditions .checkState (((SqlIdentifier )condition ).names .size () == 2 , "join condition must be format table.field" );
472+ Tuple2 <String , String > tuple2 = Tuple2 .of (((SqlIdentifier )condition ).names .get (0 ), ((SqlIdentifier )condition ).names .get (1 ));
473+ joinFieldSet .add (tuple2 );
474+ }
475+ }
476+
477+ private void extractSelectField (SqlNode selectNode , Set <String > extractFieldSet , Set <String > fromTableNameSet ){
438478 if (selectNode .getKind () == AS ) {
439479 SqlNode leftNode = ((SqlBasicCall ) selectNode ).getOperands ()[0 ];
440- extractSelectField (leftNode , extractFieldList , fromTableNameSet );
480+ extractSelectField (leftNode , extractFieldSet , fromTableNameSet );
441481
442482 }else if (selectNode .getKind () == IDENTIFIER ) {
443483 SqlIdentifier sqlIdentifier = (SqlIdentifier ) selectNode ;
@@ -448,7 +488,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
448488
449489 String tableName = sqlIdentifier .names .get (0 );
450490 if (fromTableNameSet .contains (tableName )){
451- extractFieldList .add (sqlIdentifier .toString ());
491+ extractFieldSet .add (sqlIdentifier .toString ());
452492 }
453493
454494 }else if ( AGGREGATE .contains (selectNode .getKind ())
@@ -493,7 +533,7 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
493533 continue ;
494534 }
495535
496- extractSelectField (sqlNode , extractFieldList , fromTableNameSet );
536+ extractSelectField (sqlNode , extractFieldSet , fromTableNameSet );
497537 }
498538
499539 }else if (selectNode .getKind () == CASE ){
@@ -505,15 +545,15 @@ private void extractSelectField(SqlNode selectNode, List<String> extractFieldLis
505545
506546 for (int i =0 ; i <whenOperands .size (); i ++){
507547 SqlNode oneOperand = whenOperands .get (i );
508- extractSelectField (oneOperand , extractFieldList , fromTableNameSet );
548+ extractSelectField (oneOperand , extractFieldSet , fromTableNameSet );
509549 }
510550
511551 for (int i =0 ; i <thenOperands .size (); i ++){
512552 SqlNode oneOperand = thenOperands .get (i );
513- extractSelectField (oneOperand , extractFieldList , fromTableNameSet );
553+ extractSelectField (oneOperand , extractFieldSet , fromTableNameSet );
514554 }
515555
516- extractSelectField (elseNode , extractFieldList , fromTableNameSet );
556+ extractSelectField (elseNode , extractFieldSet , fromTableNameSet );
517557 }else {
518558 //do nothing
519559 }
@@ -566,12 +606,14 @@ public String buildCondition(List<SqlBasicCall> conditionList){
566606 return " where " + StringUtils .join (conditionList , " AND " );
567607 }
568608
569- public String buildSelectNode (List <String > extractSelectField ){
609+ public String buildSelectNode (Set <String > extractSelectField , Set < String > joinFieldSet ){
570610 if (CollectionUtils .isEmpty (extractSelectField )){
571611 throw new RuntimeException ("no field is used" );
572612 }
573613
574- return StringUtils .join (extractSelectField , "," );
614+ Sets .SetView view = Sets .union (extractSelectField , joinFieldSet );
615+
616+ return StringUtils .join (view , "," );
575617 }
576618
577619 public SqlBasicCall buildDefaultCondition (){
0 commit comments