16
16
package org .springframework .data .jdbc .core .convert ;
17
17
18
18
import java .util .*;
19
+ import java .util .function .BiFunction ;
19
20
import java .util .function .Function ;
20
21
import java .util .function .Predicate ;
21
22
import java .util .stream .Collectors ;
@@ -115,7 +116,7 @@ public class SqlGenerator {
115
116
116
117
/**
117
118
* Create a basic select structure with all the necessary joins
118
- *
119
+ *
119
120
* @param table the table to base the select on
120
121
* @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
121
122
* {@literal true} will be skipped when determining columns to select.
@@ -185,6 +186,8 @@ private Condition getSubselectCondition(AggregatePath path,
185
186
Table subSelectTable = Table .create (parentPathTableInfo .qualifiedTableName ());
186
187
187
188
Map <AggregatePath , Column > selectFilterColumns = new TreeMap <>();
189
+
190
+ // TODO: cannot we simply pass on the columnInfos?
188
191
parentPathTableInfo .effectiveIdColumnInfos ().forEach ( //
189
192
(ap , ci ) -> //
190
193
selectFilterColumns .put (ap , subSelectTable .column (ci .name ())) //
@@ -468,6 +471,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
468
471
* @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
469
472
*/
470
473
String createDeleteByPath (PersistentPropertyPath <RelationalPersistentProperty > path ) {
474
+ // TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
475
+ // See SqlGeneratorEmbeddedUnitTests.deleteByPath
471
476
return createDeleteByPathAndCriteria (mappingContext .getAggregatePath (path ), this ::equalityCondition );
472
477
}
473
478
@@ -487,12 +492,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
487
492
*/
488
493
private Condition inCondition (Map <AggregatePath , Column > columnMap ) {
489
494
490
- List <Column > columns = List . copyOf ( columnMap .values () );
495
+ Collection <Column > columns = columnMap .values ();
491
496
492
- if (columns .size () == 1 ) {
493
- return Conditions .in (columns .get (0 ), getBindMarker (IDS_SQL_PARAMETER ));
494
- }
495
- return Conditions .in (TupleExpression .create (columns ), getBindMarker (IDS_SQL_PARAMETER ));
497
+ return Conditions .in (columns .size () == 1 ? columns .iterator ().next () : TupleExpression .create (columns ),
498
+ getBindMarker (IDS_SQL_PARAMETER ));
496
499
}
497
500
498
501
/**
@@ -501,44 +504,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
501
504
*/
502
505
private Condition equalityCondition (Map <AggregatePath , Column > columnMap ) {
503
506
504
- AggregatePath . ColumnInfos idColumnInfos = mappingContext . getAggregatePath ( entity ). getTableInfo (). idColumnInfos ( );
507
+ Assert . isTrue (! columnMap . isEmpty (), "Column map must not be empty" );
505
508
506
- Condition result = null ;
507
- for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
508
- BindMarker bindMarker = getBindMarker (idColumnInfos .get (entry .getKey ()).name ());
509
- Comparison singleCondition = entry .getValue ().isEqualTo (bindMarker );
509
+ AggregatePath .ColumnInfos idColumnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
510
510
511
- result = result == null ? singleCondition : result .and (singleCondition );
512
- }
513
- Assert .state (result != null , "We need at least one condition" );
514
- return result ;
511
+ return createPredicate (columnMap , (aggregatePath , column ) -> {
512
+ return column .isEqualTo (getBindMarker (idColumnInfos .get (aggregatePath ).name ()));
513
+ });
515
514
}
516
515
517
516
/**
518
517
* Constructs a function for constructing where a condition. The where condition will be of the form
519
518
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
520
519
*/
521
520
private Condition isNotNullCondition (Map <AggregatePath , Column > columnMap ) {
521
+ return createPredicate (columnMap , (aggregatePath , column ) -> column .isNotNull ());
522
+ }
523
+
524
+ /**
525
+ * Constructs a function for constructing where a condition. The where condition will be of the form
526
+ * {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
527
+ */
528
+ private static Condition createPredicate (Map <AggregatePath , Column > columnMap ,
529
+ BiFunction <AggregatePath , Column , Condition > conditionFunction ) {
522
530
523
531
Condition result = null ;
524
- for (Column column : columnMap .values ()) {
525
- Condition singleCondition = column .isNotNull ();
532
+ for (Map .Entry <AggregatePath , Column > entry : columnMap .entrySet ()) {
526
533
534
+ Condition singleCondition = conditionFunction .apply (entry .getKey (), entry .getValue ());
527
535
result = result == null ? singleCondition : result .and (singleCondition );
528
536
}
529
537
Assert .state (result != null , "We need at least one condition" );
530
538
return result ;
531
539
}
532
540
533
541
private String createFindOneSql () {
534
-
535
542
return render (selectBuilder ().where (equalityIdWhereCondition ()).build ());
536
543
}
537
544
538
545
private Condition equalityIdWhereCondition () {
546
+ return equalityIdWhereCondition (getIdColumns ());
547
+ }
548
+
549
+ private Condition equalityIdWhereCondition (Iterable <Column > columns ) {
550
+
551
+ Assert .isTrue (columns .iterator ().hasNext (), "Identifier columns must not be empty" );
539
552
540
553
Condition aggregate = null ;
541
- for (Column column : getIdColumns () ) {
554
+ for (Column column : columns ) {
542
555
543
556
Comparison condition = column .isEqualTo (getBindMarker (column .getName ()));
544
557
aggregate = aggregate == null ? condition : aggregate .and (condition );
@@ -711,19 +724,13 @@ Join getJoin(AggregatePath path) {
711
724
Table parentTable = sqlContext .getTable (idDefiningParentPath );
712
725
AggregatePath .ColumnInfos idColumnInfos = idDefiningParentPath .getTableInfo ().idColumnInfos ();
713
726
714
- final Condition [] joinCondition = { null };
715
- backRefColumnInfos .forEach ((ap , ci ) -> {
716
-
717
- Condition elementalCondition = currentTable .column (ci .name ())
718
- .isEqualTo (parentTable .column (idColumnInfos .get (ap ).name ()));
719
- joinCondition [0 ] = joinCondition [0 ] == null ? elementalCondition : joinCondition [0 ].and (elementalCondition );
720
- });
727
+ Condition joinCondition = backRefColumnInfos .reduce (Conditions .unrestricted (), (aggregatePath , columnInfo ) -> {
721
728
722
- return new Join ( //
723
- currentTable , //
724
- joinCondition [0 ] //
725
- );
729
+ return currentTable .column (columnInfo .name ())
730
+ .isEqualTo (parentTable .column (idColumnInfos .get (aggregatePath ).name ()));
731
+ }, Condition ::and );
726
732
733
+ return new Join (currentTable , joinCondition );
727
734
}
728
735
729
736
private String createFindAllInListSql () {
@@ -862,6 +869,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
862
869
863
870
Map <AggregatePath , Column > columns = new TreeMap <>();
864
871
AggregatePath .ColumnInfos columnInfos = path .getTableInfo ().backReferenceColumnInfos ();
872
+
873
+ // TODO: cannot we simply pass on the columnInfos?
865
874
columnInfos .forEach ((ag , ci ) -> columns .put (ag , table .column (ci .name ())));
866
875
867
876
if (isFirstNonRoot (path )) {
@@ -915,17 +924,20 @@ private Table getTable() {
915
924
*/
916
925
private Column getSingleNonNullColumn () {
917
926
927
+ // getColumn() is slightly different from the code in any(…). Why?
928
+ // AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
929
+ // return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
930
+
918
931
AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
919
932
return columnInfos .any ((ap , ci ) -> sqlContext .getTable (columnInfos .fullPath (ap )).column (ci .name ()).as (ci .alias ()));
920
933
}
921
934
922
935
private List <Column > getIdColumns () {
923
936
924
937
AggregatePath .ColumnInfos columnInfos = mappingContext .getAggregatePath (entity ).getTableInfo ().idColumnInfos ();
925
- List <Column > result = new ArrayList <>(columnInfos .size ());
926
- columnInfos .forEach ((ap , ci ) -> result .add (sqlContext .getColumn (columnInfos .fullPath (ap ))));
927
938
928
- return result ;
939
+ return columnInfos
940
+ .toColumnList ((aggregatePath , columnInfo ) -> sqlContext .getColumn (columnInfos .fullPath (aggregatePath )));
929
941
}
930
942
931
943
private Column getVersionColumn () {
0 commit comments