22
22
import java .util .ArrayList ;
23
23
import java .util .Collections ;
24
24
import java .util .List ;
25
- import java .util .Map ;
26
- import java .util .function .Predicate ;
25
+ import java .util .Optional ;
27
26
28
- import org .springframework .dao .DataRetrievalFailureException ;
29
27
import org .springframework .dao .EmptyResultDataAccessException ;
30
- import org .springframework .dao .InvalidDataAccessApiUsageException ;
31
28
import org .springframework .dao .OptimisticLockingFailureException ;
32
29
import org .springframework .data .domain .Pageable ;
33
30
import org .springframework .data .domain .Sort ;
34
31
import org .springframework .data .jdbc .core .mapping .JdbcValue ;
35
32
import org .springframework .data .jdbc .support .JdbcUtil ;
36
- import org .springframework .data .mapping .PersistentProperty ;
37
- import org .springframework .data .mapping .PersistentPropertyAccessor ;
38
33
import org .springframework .data .mapping .PersistentPropertyPath ;
39
- import org .springframework .data .relational .core .dialect .IdGeneration ;
40
34
import org .springframework .data .relational .core .mapping .PersistentPropertyPathExtension ;
41
35
import org .springframework .data .relational .core .mapping .RelationalMappingContext ;
42
36
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
47
41
import org .springframework .jdbc .core .RowMapper ;
48
42
import org .springframework .jdbc .core .namedparam .NamedParameterJdbcOperations ;
49
43
import org .springframework .jdbc .core .namedparam .SqlParameterSource ;
50
- import org .springframework .jdbc .support .GeneratedKeyHolder ;
51
44
import org .springframework .jdbc .support .JdbcUtils ;
52
- import org .springframework .jdbc .support .KeyHolder ;
53
45
import org .springframework .lang .Nullable ;
54
46
import org .springframework .util .Assert ;
55
47
@@ -75,8 +67,8 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
75
67
private final RelationalMappingContext context ;
76
68
private final JdbcConverter converter ;
77
69
private final NamedParameterJdbcOperations operations ;
78
- private final BatchJdbcOperations batchOperations ;
79
70
private final SqlParametersFactory sqlParametersFactory ;
71
+ private final InsertStrategyFactory insertStrategyFactory ;
80
72
81
73
/**
82
74
* Creates a {@link DefaultDataAccessStrategy}
@@ -88,26 +80,21 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
88
80
* @since 1.1
89
81
*/
90
82
public DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
91
- JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ) {
92
- this (sqlGeneratorSource , context , converter , operations , new BatchJdbcOperations (operations .getJdbcOperations ()), sqlParametersFactory );
93
- }
94
-
95
- DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
96
- JdbcConverter converter , NamedParameterJdbcOperations operations ,
97
- BatchJdbcOperations batchOperations , SqlParametersFactory sqlParametersFactory ) {
98
-
83
+ JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ,
84
+ InsertStrategyFactory insertStrategyFactory ) {
99
85
Assert .notNull (sqlGeneratorSource , "SqlGeneratorSource must not be null" );
100
86
Assert .notNull (context , "RelationalMappingContext must not be null" );
101
87
Assert .notNull (converter , "JdbcConverter must not be null" );
102
88
Assert .notNull (operations , "NamedParameterJdbcOperations must not be null" );
103
89
Assert .notNull (sqlParametersFactory , "SqlParametersFactory must not be null" );
90
+ Assert .notNull (insertStrategyFactory , "InsertStrategyFactory must not be null" );
104
91
105
92
this .sqlGeneratorSource = sqlGeneratorSource ;
106
93
this .context = context ;
107
94
this .converter = converter ;
108
95
this .operations = operations ;
109
- this .batchOperations = batchOperations ;
110
96
this .sqlParametersFactory = sqlParametersFactory ;
97
+ this .insertStrategyFactory = insertStrategyFactory ;
111
98
}
112
99
113
100
/*
@@ -121,14 +108,7 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier,
121
108
122
109
String insertSql = sql (domainType ).getInsert (parameterSource .getIdentifiers ());
123
110
124
- if (!includeId ) {
125
- RelationalPersistentEntity <T > persistentEntity = getRequiredPersistentEntity (domainType );
126
- return executeInsertAndReturnGeneratedId (persistentEntity , parameterSource , insertSql );
127
- } else {
128
-
129
- operations .update (insertSql , parameterSource );
130
- return null ;
131
- }
111
+ return insertStrategyFactory .insertStrategy (!includeId , getIdColumn (domainType )).execute (insertSql , parameterSource );
132
112
}
133
113
134
114
@ Override
@@ -142,68 +122,7 @@ public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T>
142
122
143
123
String insertSql = sql (domainType ).getInsert (sqlParameterSources [0 ].getIdentifiers ());
144
124
145
- if (includeId ) {
146
- operations .batchUpdate (insertSql , sqlParameterSources );
147
- return new Object [sqlParameterSources .length ];
148
- }
149
- GeneratedKeyHolder holder = new GeneratedKeyHolder ();
150
-
151
- IdGeneration idGeneration = sqlGeneratorSource .getDialect ().getIdGeneration ();
152
-
153
- RelationalPersistentEntity <T > persistentEntity = getRequiredPersistentEntity (domainType );
154
- if (idGeneration .driverRequiresKeyColumnNames ()) {
155
-
156
- String [] keyColumnNames = getKeyColumnNames (persistentEntity .getType ());
157
- if (keyColumnNames .length == 0 ) {
158
- batchOperations .insert (insertSql , sqlParameterSources , holder );
159
- } else {
160
- batchOperations .insert (insertSql , sqlParameterSources , holder , keyColumnNames );
161
- }
162
- } else {
163
- batchOperations .insert (insertSql , sqlParameterSources , holder );
164
- }
165
- // TODO: Is this needed?
166
- if (!persistentEntity .hasIdProperty ()) {
167
- return new Object [sqlParameterSources .length ];
168
- }
169
- // TODO: Duplicated in #getIdFromHolder - consider refactoring
170
- Object [] ids = new Object [sqlParameterSources .length ];
171
- List <Map <String , Object >> keyList = holder .getKeyList ();
172
- for (int i = 0 ; i < keyList .size (); i ++) {
173
- Map <String , Object > keys = keyList .get (i );
174
- Object id ;
175
- if (keys .size () > 1 ) {
176
- id = keys .get (persistentEntity .getIdColumn ().getReference (getIdentifierProcessing ()));
177
- } else {
178
- id = keys .entrySet ().stream ().findFirst () //
179
- .map (Map .Entry ::getValue ) //
180
- .orElseThrow (() -> new IllegalStateException ("KeyHolder contains an empty key list." ));
181
- }
182
- ids [i ] = id ;
183
- }
184
- return ids ;
185
- }
186
-
187
- @ Nullable
188
- private <T > Object executeInsertAndReturnGeneratedId (RelationalPersistentEntity <T > persistentEntity , SqlIdentifierParameterSource parameterSource , String insertSql ) {
189
-
190
- KeyHolder holder = new GeneratedKeyHolder ();
191
-
192
- IdGeneration idGeneration = sqlGeneratorSource .getDialect ().getIdGeneration ();
193
-
194
- if (idGeneration .driverRequiresKeyColumnNames ()) {
195
-
196
- String [] keyColumnNames = getKeyColumnNames (persistentEntity .getType ());
197
- if (keyColumnNames .length == 0 ) {
198
- operations .update (insertSql , parameterSource , holder );
199
- } else {
200
- operations .update (insertSql , parameterSource , holder , keyColumnNames );
201
- }
202
- } else {
203
- operations .update (insertSql , parameterSource , holder );
204
- }
205
-
206
- return getIdFromHolder (holder , persistentEntity );
125
+ return insertStrategyFactory .insertStrategy (!includeId , getIdColumn (domainType )).execute (insertSql , sqlParameterSources );
207
126
}
208
127
209
128
/*
@@ -476,26 +395,6 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
476
395
return operations .query (sql (domainType ).getFindAll (pageable ), (RowMapper <T >) getEntityRowMapper (domainType ));
477
396
}
478
397
479
- @ Nullable
480
- private <S > Object getIdFromHolder (KeyHolder holder , RelationalPersistentEntity <S > persistentEntity ) {
481
-
482
- try {
483
- // MySQL just returns one value with a special name
484
- return holder .getKey ();
485
- } catch (DataRetrievalFailureException | InvalidDataAccessApiUsageException e ) {
486
- // Postgres returns a value for each column
487
- // MS SQL Server returns a value that might be null.
488
-
489
- Map <String , Object > keys = holder .getKeys ();
490
-
491
- if (keys == null || persistentEntity .getIdProperty () == null ) {
492
- return null ;
493
- }
494
-
495
- return keys .get (persistentEntity .getIdColumn ().getReference (getIdentifierProcessing ()));
496
- }
497
- }
498
-
499
398
private EntityRowMapper <?> getEntityRowMapper (Class <?> domainType ) {
500
399
return new EntityRowMapper <>(getRequiredPersistentEntity (domainType ), converter );
501
400
}
@@ -588,17 +487,10 @@ private SqlGenerator sql(Class<?> domainType) {
588
487
return sqlGeneratorSource .getSqlGenerator (domainType );
589
488
}
590
489
591
- private <T > String [] getKeyColumnNames (Class <T > domainType ) {
592
-
593
- RelationalPersistentEntity <?> requiredPersistentEntity = context .getRequiredPersistentEntity (domainType );
594
-
595
- if (!requiredPersistentEntity .hasIdProperty ()) {
596
- return new String [0 ];
597
- }
598
-
599
- SqlIdentifier idColumn = requiredPersistentEntity .getIdColumn ();
600
-
601
- return new String [] { idColumn .getReference (getIdentifierProcessing ()) };
490
+ @ Nullable
491
+ private <T > SqlIdentifier getIdColumn (Class <T > domainType ) {
492
+ return Optional .ofNullable (context .getRequiredPersistentEntity (domainType ).getIdProperty ())
493
+ .map (RelationalPersistentProperty ::getColumnName )
494
+ .orElse (null );
602
495
}
603
-
604
496
}
0 commit comments