22
22
import java .util .ArrayList ;
23
23
import java .util .Collections ;
24
24
import java .util .List ;
25
- import java .util .Map ;
26
- import java .util .function .Predicate ;
25
+ import java .util .Optional ;
27
26
28
- import org .springframework .dao .DataRetrievalFailureException ;
29
27
import org .springframework .dao .EmptyResultDataAccessException ;
30
- import org .springframework .dao .InvalidDataAccessApiUsageException ;
31
28
import org .springframework .dao .OptimisticLockingFailureException ;
32
29
import org .springframework .data .domain .Pageable ;
33
30
import org .springframework .data .domain .Sort ;
34
31
import org .springframework .data .jdbc .core .mapping .JdbcValue ;
35
32
import org .springframework .data .jdbc .support .JdbcUtil ;
36
- import org .springframework .data .mapping .PersistentProperty ;
37
- import org .springframework .data .mapping .PersistentPropertyAccessor ;
38
33
import org .springframework .data .mapping .PersistentPropertyPath ;
39
- import org .springframework .data .relational .core .dialect .IdGeneration ;
40
34
import org .springframework .data .relational .core .mapping .PersistentPropertyPathExtension ;
41
35
import org .springframework .data .relational .core .mapping .RelationalMappingContext ;
42
36
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
47
41
import org .springframework .jdbc .core .RowMapper ;
48
42
import org .springframework .jdbc .core .namedparam .NamedParameterJdbcOperations ;
49
43
import org .springframework .jdbc .core .namedparam .SqlParameterSource ;
50
- import org .springframework .jdbc .support .GeneratedKeyHolder ;
51
44
import org .springframework .jdbc .support .JdbcUtils ;
52
- import org .springframework .jdbc .support .KeyHolder ;
53
45
import org .springframework .lang .Nullable ;
54
46
import org .springframework .util .Assert ;
55
47
@@ -75,8 +67,8 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
75
67
private final RelationalMappingContext context ;
76
68
private final JdbcConverter converter ;
77
69
private final NamedParameterJdbcOperations operations ;
78
- private final BatchJdbcOperations batchOperations ;
79
70
private final SqlParametersFactory sqlParametersFactory ;
71
+ private final InsertStrategyFactory insertStrategyFactory ;
80
72
81
73
/**
82
74
* Creates a {@link DefaultDataAccessStrategy}
@@ -88,26 +80,21 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
88
80
* @since 1.1
89
81
*/
90
82
public DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
91
- JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ) {
92
- this (sqlGeneratorSource , context , converter , operations , new BatchJdbcOperations (operations .getJdbcOperations ()), sqlParametersFactory );
93
- }
94
-
95
- DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
96
- JdbcConverter converter , NamedParameterJdbcOperations operations ,
97
- BatchJdbcOperations batchOperations , SqlParametersFactory sqlParametersFactory ) {
98
-
83
+ JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ,
84
+ InsertStrategyFactory insertStrategyFactory ) {
99
85
Assert .notNull (sqlGeneratorSource , "SqlGeneratorSource must not be null" );
100
86
Assert .notNull (context , "RelationalMappingContext must not be null" );
101
87
Assert .notNull (converter , "JdbcConverter must not be null" );
102
88
Assert .notNull (operations , "NamedParameterJdbcOperations must not be null" );
103
89
Assert .notNull (sqlParametersFactory , "SqlParametersFactory must not be null" );
90
+ Assert .notNull (insertStrategyFactory , "InsertStrategyFactory must not be null" );
104
91
105
92
this .sqlGeneratorSource = sqlGeneratorSource ;
106
93
this .context = context ;
107
94
this .converter = converter ;
108
95
this .operations = operations ;
109
- this .batchOperations = batchOperations ;
110
96
this .sqlParametersFactory = sqlParametersFactory ;
97
+ this .insertStrategyFactory = insertStrategyFactory ;
111
98
}
112
99
113
100
/*
@@ -121,14 +108,7 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier,
121
108
122
109
String insertSql = sql (domainType ).getInsert (parameterSource .getIdentifiers ());
123
110
124
- if (!includeId ) {
125
- RelationalPersistentEntity <T > persistentEntity = getRequiredPersistentEntity (domainType );
126
- return executeInsertAndReturnGeneratedId (persistentEntity , parameterSource , insertSql );
127
- } else {
128
-
129
- operations .update (insertSql , parameterSource );
130
- return null ;
131
- }
111
+ return insertStrategyFactory .insertStrategy (!includeId , getIdColumn (domainType )).execute (insertSql , parameterSource );
132
112
}
133
113
134
114
@ Override
@@ -141,68 +121,7 @@ public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T>
141
121
142
122
String insertSql = sql (domainType ).getInsert (sqlParameterSources [0 ].getIdentifiers ());
143
123
144
- if (includeId ) {
145
- operations .batchUpdate (insertSql , sqlParameterSources );
146
- return new Object [sqlParameterSources .length ];
147
- }
148
- GeneratedKeyHolder holder = new GeneratedKeyHolder ();
149
-
150
- IdGeneration idGeneration = sqlGeneratorSource .getDialect ().getIdGeneration ();
151
-
152
- RelationalPersistentEntity <T > persistentEntity = getRequiredPersistentEntity (domainType );
153
- if (idGeneration .driverRequiresKeyColumnNames ()) {
154
-
155
- String [] keyColumnNames = getKeyColumnNames (persistentEntity .getType ());
156
- if (keyColumnNames .length == 0 ) {
157
- batchOperations .insert (insertSql , sqlParameterSources , holder );
158
- } else {
159
- batchOperations .insert (insertSql , sqlParameterSources , holder , keyColumnNames );
160
- }
161
- } else {
162
- batchOperations .insert (insertSql , sqlParameterSources , holder );
163
- }
164
- // TODO: Is this needed?
165
- if (!persistentEntity .hasIdProperty ()) {
166
- return new Object [sqlParameterSources .length ];
167
- }
168
- // TODO: Duplicated in #getIdFromHolder - consider refactoring
169
- Object [] ids = new Object [sqlParameterSources .length ];
170
- List <Map <String , Object >> keyList = holder .getKeyList ();
171
- for (int i = 0 ; i < keyList .size (); i ++) {
172
- Map <String , Object > keys = keyList .get (i );
173
- Object id ;
174
- if (keys .size () > 1 ) {
175
- id = keys .get (persistentEntity .getIdColumn ().getReference (getIdentifierProcessing ()));
176
- } else {
177
- id = keys .entrySet ().stream ().findFirst () //
178
- .map (Map .Entry ::getValue ) //
179
- .orElseThrow (() -> new IllegalStateException ("KeyHolder contains an empty key list." ));
180
- }
181
- ids [i ] = id ;
182
- }
183
- return ids ;
184
- }
185
-
186
- @ Nullable
187
- private <T > Object executeInsertAndReturnGeneratedId (RelationalPersistentEntity <T > persistentEntity , SqlIdentifierParameterSource parameterSource , String insertSql ) {
188
-
189
- KeyHolder holder = new GeneratedKeyHolder ();
190
-
191
- IdGeneration idGeneration = sqlGeneratorSource .getDialect ().getIdGeneration ();
192
-
193
- if (idGeneration .driverRequiresKeyColumnNames ()) {
194
-
195
- String [] keyColumnNames = getKeyColumnNames (persistentEntity .getType ());
196
- if (keyColumnNames .length == 0 ) {
197
- operations .update (insertSql , parameterSource , holder );
198
- } else {
199
- operations .update (insertSql , parameterSource , holder , keyColumnNames );
200
- }
201
- } else {
202
- operations .update (insertSql , parameterSource , holder );
203
- }
204
-
205
- return getIdFromHolder (holder , persistentEntity );
124
+ return insertStrategyFactory .insertStrategy (!includeId , getIdColumn (domainType )).execute (insertSql , sqlParameterSources );
206
125
}
207
126
208
127
/*
@@ -475,26 +394,6 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
475
394
return operations .query (sql (domainType ).getFindAll (pageable ), (RowMapper <T >) getEntityRowMapper (domainType ));
476
395
}
477
396
478
- @ Nullable
479
- private <S > Object getIdFromHolder (KeyHolder holder , RelationalPersistentEntity <S > persistentEntity ) {
480
-
481
- try {
482
- // MySQL just returns one value with a special name
483
- return holder .getKey ();
484
- } catch (DataRetrievalFailureException | InvalidDataAccessApiUsageException e ) {
485
- // Postgres returns a value for each column
486
- // MS SQL Server returns a value that might be null.
487
-
488
- Map <String , Object > keys = holder .getKeys ();
489
-
490
- if (keys == null || persistentEntity .getIdProperty () == null ) {
491
- return null ;
492
- }
493
-
494
- return keys .get (persistentEntity .getIdColumn ().getReference (getIdentifierProcessing ()));
495
- }
496
- }
497
-
498
397
private EntityRowMapper <?> getEntityRowMapper (Class <?> domainType ) {
499
398
return new EntityRowMapper <>(getRequiredPersistentEntity (domainType ), converter );
500
399
}
@@ -587,17 +486,10 @@ private SqlGenerator sql(Class<?> domainType) {
587
486
return sqlGeneratorSource .getSqlGenerator (domainType );
588
487
}
589
488
590
- private <T > String [] getKeyColumnNames (Class <T > domainType ) {
591
-
592
- RelationalPersistentEntity <?> requiredPersistentEntity = context .getRequiredPersistentEntity (domainType );
593
-
594
- if (!requiredPersistentEntity .hasIdProperty ()) {
595
- return new String [0 ];
596
- }
597
-
598
- SqlIdentifier idColumn = requiredPersistentEntity .getIdColumn ();
599
-
600
- return new String [] { idColumn .getReference (getIdentifierProcessing ()) };
489
+ @ Nullable
490
+ private <T > SqlIdentifier getIdColumn (Class <T > domainType ) {
491
+ return Optional .ofNullable (context .getRequiredPersistentEntity (domainType ).getIdProperty ())
492
+ .map (RelationalPersistentProperty ::getColumnName )
493
+ .orElse (null );
601
494
}
602
-
603
495
}
0 commit comments