From e363be76d9260c6c2197f7b6e6c458f59850139b Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Fri, 11 Feb 2022 09:44:40 -0600 Subject: [PATCH 1/8] Add JdbcOperations#batchUpdate method that takes PreparedStatementCreator. --- .../jdbc/core/JdbcOperations.java | 15 ++++ .../jdbc/core/JdbcTemplate.java | 86 +++++++++++-------- 2 files changed, 63 insertions(+), 38 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java index 7910753806aa..e8bb59cec561 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java @@ -990,6 +990,21 @@ List queryForList(String sql, Object[] args, int[] argTypes, Class ele */ int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException; + /** + * Issue multiple update statements on a single PreparedStatement, + * using batch updates and a BatchPreparedStatementSetter to set values. + *

Will fall back to separate updates on a single PreparedStatement + * if the JDBC driver does not support batch updates. + * @param psc a callback that creates a PreparedStatement given a Connection + * @param pss object to set parameters on the PreparedStatement + * created by this method + * @return an array of the number of rows affected by each statement + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + */ + int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss) throws DataAccessException; + /** * Execute a batch using the supplied SQL statement with the batch of supplied arguments. * @param sql the SQL statement to execute diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 179a6b79e166..d6116dbb346d 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1025,50 +1025,21 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti return update(sql, newArgPreparedStatementSetter(args)); } + @Override + public int[] batchUpdate(PreparedStatementCreator psc, final BatchPreparedStatementSetter pss) throws DataAccessException { + int[] result = execute(psc, getPreparedStatementCallback(pss)); + + Assert.state(result != null, "No result array"); + return result; + } + @Override public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException { if (logger.isDebugEnabled()) { logger.debug("Executing SQL batch update [" + sql + "]"); } - int[] result = execute(sql, (PreparedStatementCallback) ps -> { - try { - int batchSize = pss.getBatchSize(); - InterruptibleBatchPreparedStatementSetter ipss = - (pss instanceof InterruptibleBatchPreparedStatementSetter ? - (InterruptibleBatchPreparedStatementSetter) pss : null); - if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) { - for (int i = 0; i < batchSize; i++) { - pss.setValues(ps, i); - if (ipss != null && ipss.isBatchExhausted(i)) { - break; - } - ps.addBatch(); - } - return ps.executeBatch(); - } - else { - List rowsAffected = new ArrayList<>(); - for (int i = 0; i < batchSize; i++) { - pss.setValues(ps, i); - if (ipss != null && ipss.isBatchExhausted(i)) { - break; - } - rowsAffected.add(ps.executeUpdate()); - } - int[] rowsAffectedArray = new int[rowsAffected.size()]; - for (int i = 0; i < rowsAffectedArray.length; i++) { - rowsAffectedArray[i] = rowsAffected.get(i); - } - return rowsAffectedArray; - } - } - finally { - if (pss instanceof ParameterDisposer) { - ((ParameterDisposer) pss).cleanupParameters(); - } - } - }); + int[] result = execute(sql, getPreparedStatementCallback(pss)); Assert.state(result != null, "No result array"); return result; @@ -1567,6 +1538,45 @@ private static int updateCount(@Nullable Integer result) { return result; } + private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss) { + return ps -> { + try { + int batchSize = pss.getBatchSize(); + InterruptibleBatchPreparedStatementSetter ipss = + (pss instanceof InterruptibleBatchPreparedStatementSetter ? + (InterruptibleBatchPreparedStatementSetter) pss : null); + if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) { + for (int i = 0; i < batchSize; i++) { + pss.setValues(ps, i); + if (ipss != null && ipss.isBatchExhausted(i)) { + break; + } + ps.addBatch(); + } + return ps.executeBatch(); + } else { + List rowsAffected = new ArrayList<>(); + for (int i = 0; i < batchSize; i++) { + pss.setValues(ps, i); + if (ipss != null && ipss.isBatchExhausted(i)) { + break; + } + rowsAffected.add(ps.executeUpdate()); + } + int[] rowsAffectedArray = new int[rowsAffected.size()]; + for (int i = 0; i < rowsAffectedArray.length; i++) { + rowsAffectedArray[i] = rowsAffected.get(i); + } + return rowsAffectedArray; + } + } finally { + if (pss instanceof ParameterDisposer) { + ((ParameterDisposer) pss).cleanupParameters(); + } + } + }; + } + /** * Invocation handler that suppresses close calls on JDBC Connections. From 21401951418a57206456a7ae612ad90d594f49fe Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Fri, 11 Feb 2022 10:28:07 -0600 Subject: [PATCH 2/8] Add NamedParameterJdbcOperations#batchUpdate method that takes KeyHolder. This batch update uses the form of JdbcTemplate#batchUpdate that takes PreparedStatementCreator with the flag set to return generated keys. --- .../NamedParameterJdbcOperations.java | 14 ++++++++++ .../NamedParameterJdbcTemplate.java | 27 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java index b308e06f735b..6c2abffc3d74 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java @@ -549,4 +549,18 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol */ int[] batchUpdate(String sql, SqlParameterSource[] batchArgs); + /** + * Execute a batch using the supplied SQL statement with the batch of supplied arguments, + * returning generated keys. + * @param sql the SQL statement to execute + * @param batchArgs the array of {@link SqlParameterSource} containing the batch of + * arguments for the query + * @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys + * @return an array containing the numbers of rows affected by each update in the batch + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder + */ + int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index ef7b6567dfcc..9b4f109f1682 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -385,6 +385,33 @@ public int getBatchSize() { }); } + @Override + public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder keyHolder) { + if (batchArgs.length == 0) { + return new int[0]; + } + + ParsedSql parsedSql = getParsedSql(sql); + SqlParameterSource paramSource = batchArgs[0]; + PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource); + pscf.setReturnGeneratedKeys(true); + Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null); + PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params); + return getJdbcOperations().batchUpdate( + psc, + new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null); + pscf.newPreparedStatementSetter(values).setValues(ps); + } + @Override + public int getBatchSize() { + return batchArgs.length; + } + }); + } + /** * Build a {@link PreparedStatementCreator} based on the given SQL and named parameters. From a472fe98653eb61f05524c8c64cad65f82c100fd Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Fri, 11 Feb 2022 10:48:22 -0600 Subject: [PATCH 3/8] Pass KeyHolder param from NamedParameterJdbcOperations to JdbcOperations. --- .../org/springframework/jdbc/core/JdbcOperations.java | 8 +++++++- .../java/org/springframework/jdbc/core/JdbcTemplate.java | 2 +- .../jdbc/core/namedparam/NamedParameterJdbcTemplate.java | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java index e8bb59cec561..4bd49112fde8 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java @@ -993,17 +993,23 @@ List queryForList(String sql, Object[] args, int[] argTypes, Class ele /** * Issue multiple update statements on a single PreparedStatement, * using batch updates and a BatchPreparedStatementSetter to set values. + * Generated keys will be put into the given KeyHolder. + *

Note that the given PreparedStatementCreator has to create a statement + * with activated extraction of generated keys (a JDBC 3.0 feature). This can + * either be done directly or through using a PreparedStatementCreatorFactory. *

Will fall back to separate updates on a single PreparedStatement * if the JDBC driver does not support batch updates. * @param psc a callback that creates a PreparedStatement given a Connection * @param pss object to set parameters on the PreparedStatement * created by this method + * @param generatedKeyHolder a KeyHolder that will hold the generated keys * @return an array of the number of rows affected by each statement * (may also contain special JDBC-defined negative values for affected rows such as * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder */ - int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss) throws DataAccessException; + int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss, KeyHolder generatedKeyHolder) throws DataAccessException; /** * Execute a batch using the supplied SQL statement with the batch of supplied arguments. diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index d6116dbb346d..900aca92afd7 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1026,7 +1026,7 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti } @Override - public int[] batchUpdate(PreparedStatementCreator psc, final BatchPreparedStatementSetter pss) throws DataAccessException { + public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException { int[] result = execute(psc, getPreparedStatementCallback(pss)); Assert.state(result != null, "No result array"); diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index 9b4f109f1682..e2343ac6297e 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -409,7 +409,8 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { public int getBatchSize() { return batchArgs.length; } - }); + }, + keyHolder); } From 1a8bebd59c9ad17eaee5026934e21edd26285de2 Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Fri, 11 Feb 2022 12:13:24 -0600 Subject: [PATCH 4/8] Store generated keys in KeyHolder after executing batch update. --- .../jdbc/core/JdbcTemplate.java | 45 +++++++++++-------- .../jdbc/core/JdbcTemplateTests.java | 38 ++++++++++++++++ 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 900aca92afd7..6e1d2721f5b7 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -990,21 +990,9 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK return updateCount(execute(psc, ps -> { int rows = ps.executeUpdate(); - List> generatedKeys = generatedKeyHolder.getKeyList(); - generatedKeys.clear(); - ResultSet keys = ps.getGeneratedKeys(); - if (keys != null) { - try { - RowMapperResultSetExtractor> rse = - new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); - generatedKeys.addAll(result(rse.extractData(keys))); - } - finally { - JdbcUtils.closeResultSet(keys); - } - } + storeGeneratedKeys(generatedKeyHolder).doInPreparedStatement(ps); if (logger.isTraceEnabled()) { - logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys"); + logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys"); } return rows; }, true)); @@ -1027,7 +1015,7 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti @Override public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException { - int[] result = execute(psc, getPreparedStatementCallback(pss)); + int[] result = execute(psc, getPreparedStatementCallback(pss, storeGeneratedKeys(generatedKeyHolder))); Assert.state(result != null, "No result array"); return result; @@ -1039,7 +1027,7 @@ public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) thr logger.debug("Executing SQL batch update [" + sql + "]"); } - int[] result = execute(sql, getPreparedStatementCallback(pss)); + int[] result = execute(sql, getPreparedStatementCallback(pss, ps -> null)); Assert.state(result != null, "No result array"); return result; @@ -1538,7 +1526,26 @@ private static int updateCount(@Nullable Integer result) { return result; } - private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss) { + private PreparedStatementCallback storeGeneratedKeys(KeyHolder generatedKeyHolder) { + return ps -> { + List> generatedKeys = generatedKeyHolder.getKeyList(); + generatedKeys.clear(); + ResultSet keys = ps.getGeneratedKeys(); + if (keys != null) { + try { + RowMapperResultSetExtractor> rse = + new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); + generatedKeys.addAll(result(rse.extractData(keys))); + } + finally { + JdbcUtils.closeResultSet(keys); + } + } + return null; + }; + } + + private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss, PreparedStatementCallback afterUpdateCallback) { return ps -> { try { int batchSize = pss.getBatchSize(); @@ -1553,7 +1560,9 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa } ps.addBatch(); } - return ps.executeBatch(); + int[] results = ps.executeBatch(); + afterUpdateCallback.doInPreparedStatement(ps); + return results; } else { List rowsAffected = new ArrayList<>(); for (int i = 0; i < batchSize; i++) { diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index 456d59dd5bfd..33408205ecf2 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -47,6 +47,8 @@ import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter; import org.springframework.jdbc.datasource.ConnectionProxy; import org.springframework.jdbc.datasource.SingleConnectionDataSource; +import org.springframework.jdbc.support.GeneratedKeyHolder; +import org.springframework.jdbc.support.KeyHolder; import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator; import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator; import org.springframework.util.LinkedCaseInsensitiveMap; @@ -1085,6 +1087,42 @@ public void testEquallyNamedColumn() throws SQLException { assertThat(map.get("x")).isEqualTo("first value"); } + @Test + void testBatchUpdateReturnsGeneratedKeys() throws SQLException { + final int[] rowsAffected = new int[] {1, 2}; + given(this.preparedStatement.executeBatch()).willReturn(rowsAffected); + DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); + given(databaseMetaData.supportsBatchUpdates()).willReturn(true); + given(this.connection.getMetaData()).willReturn(databaseMetaData); + ResultSet generatedKeysResultSet = mock(ResultSet.class); + ResultSetMetaData rsmd = mock(ResultSetMetaData.class); + given(rsmd.getColumnCount()).willReturn(1); + given(rsmd.getColumnLabel(1)).willReturn("someId"); + given(generatedKeysResultSet.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet.getObject(1)).willReturn(123, 456); + given(generatedKeysResultSet.next()).willReturn(true, true, false); + given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet); + + int[] values = new int[]{100, 200}; + BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + ps.setObject(i, values[i]); + } + + @Override + public int getBatchSize() { + return 2; + } + }; + + KeyHolder keyHolder = new GeneratedKeyHolder(); + this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder); + + assertThat(keyHolder.getKeyList()).containsExactly( + Collections.singletonMap("someId", 123), + Collections.singletonMap("someId", 456)); + } private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException { DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); From ef028b5158ee7b7dde99e9abb2f5efc771b0a32f Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Fri, 11 Feb 2022 09:44:40 -0600 Subject: [PATCH 5/8] Add NamedParameterJdbcOperations#batchUpdate method that takes KeyHolder and keyColumnNames. --- .../jdbc/core/JdbcTemplate.java | 1 + .../NamedParameterJdbcOperations.java | 16 ++++++++ .../NamedParameterJdbcTemplate.java | 15 +++++-- .../jdbc/core/JdbcTemplateTests.java | 39 ++++++++++++++++++- 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 6e1d2721f5b7..7576a19df411 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1573,6 +1573,7 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa rowsAffected.add(ps.executeUpdate()); } int[] rowsAffectedArray = new int[rowsAffected.size()]; + afterUpdateCallback.doInPreparedStatement(ps); for (int i = 0; i < rowsAffectedArray.length; i++) { rowsAffectedArray[i] = rowsAffected.get(i); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java index 6c2abffc3d74..e0a51e9f106b 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java @@ -563,4 +563,20 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol * @see org.springframework.jdbc.support.GeneratedKeyHolder */ int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder); + + /** + * Execute a batch using the supplied SQL statement with the batch of supplied arguments, + * returning generated keys. + * @param sql the SQL statement to execute + * @param batchArgs the array of {@link SqlParameterSource} containing the batch of + * arguments for the query + * @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys + * @param keyColumnNames names of the columns that will have keys generated for them + * @return an array containing the numbers of rows affected by each update in the batch + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder + */ + int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index e2343ac6297e..b1e4b898dc63 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -386,7 +386,12 @@ public int getBatchSize() { } @Override - public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder keyHolder) { + public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) { + return batchUpdate(sql, batchArgs, generatedKeyHolder, null); + } + + @Override + public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames) { if (batchArgs.length == 0) { return new int[0]; } @@ -394,7 +399,11 @@ public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder k ParsedSql parsedSql = getParsedSql(sql); SqlParameterSource paramSource = batchArgs[0]; PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource); - pscf.setReturnGeneratedKeys(true); + if (keyColumnNames != null) { + pscf.setGeneratedKeysColumnNames(keyColumnNames); + } else { + pscf.setReturnGeneratedKeys(true); + } Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null); PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params); return getJdbcOperations().batchUpdate( @@ -410,7 +419,7 @@ public int getBatchSize() { return batchArgs.length; } }, - keyHolder); + generatedKeyHolder); } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index 33408205ecf2..d48351116ff6 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -1088,7 +1088,7 @@ public void testEquallyNamedColumn() throws SQLException { } @Test - void testBatchUpdateReturnsGeneratedKeys() throws SQLException { + void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException { final int[] rowsAffected = new int[] {1, 2}; given(this.preparedStatement.executeBatch()).willReturn(rowsAffected); DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); @@ -1124,6 +1124,43 @@ public int getBatchSize() { Collections.singletonMap("someId", 456)); } + @Test + void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException { + final int[] rowsAffected = new int[] {1, 2}; + given(this.preparedStatement.executeBatch()).willReturn(rowsAffected); + DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); + given(databaseMetaData.supportsBatchUpdates()).willReturn(false); + given(this.connection.getMetaData()).willReturn(databaseMetaData); + ResultSet generatedKeysResultSet = mock(ResultSet.class); + ResultSetMetaData rsmd = mock(ResultSetMetaData.class); + given(rsmd.getColumnCount()).willReturn(1); + given(rsmd.getColumnLabel(1)).willReturn("someId"); + given(generatedKeysResultSet.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet.getObject(1)).willReturn(123, 456); + given(generatedKeysResultSet.next()).willReturn(true, true, false); + given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet); + + int[] values = new int[]{100, 200}; + BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + ps.setObject(i, values[i]); + } + + @Override + public int getBatchSize() { + return 2; + } + }; + + KeyHolder keyHolder = new GeneratedKeyHolder(); + this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder); + + assertThat(keyHolder.getKeyList()).containsExactly( + Collections.singletonMap("someId", 123), + Collections.singletonMap("someId", 456)); + } + private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException { DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL"); From 61bb5454036ed0409e6bf9733a3d91f3fb7fae09 Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Tue, 15 Mar 2022 12:51:15 -0500 Subject: [PATCH 6/8] Resolve checkstyle violations. --- .../java/org/springframework/jdbc/core/JdbcTemplate.java | 6 ++++-- .../jdbc/core/namedparam/NamedParameterJdbcTemplate.java | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 7576a19df411..f6145ef9ad90 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -1563,7 +1563,8 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa int[] results = ps.executeBatch(); afterUpdateCallback.doInPreparedStatement(ps); return results; - } else { + } + else { List rowsAffected = new ArrayList<>(); for (int i = 0; i < batchSize; i++) { pss.setValues(ps, i); @@ -1579,7 +1580,8 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa } return rowsAffectedArray; } - } finally { + } + finally { if (pss instanceof ParameterDisposer) { ((ParameterDisposer) pss).cleanupParameters(); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index b1e4b898dc63..94779631f15e 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -401,7 +401,8 @@ public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder g PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource); if (keyColumnNames != null) { pscf.setGeneratedKeysColumnNames(keyColumnNames); - } else { + } + else { pscf.setReturnGeneratedKeys(true); } Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null); From d8fb8bef302da9d22fc3136ea19f6a58a3f05fe2 Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Tue, 15 Mar 2022 13:40:51 -0500 Subject: [PATCH 7/8] Must extract generated keys with each update when db does not support batch operations. --- .../jdbc/core/JdbcTemplate.java | 48 ++++++++++--------- .../jdbc/core/JdbcTemplateTests.java | 14 ++++-- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index f6145ef9ad90..7a127b26b786 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -990,7 +990,8 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK return updateCount(execute(psc, ps -> { int rows = ps.executeUpdate(); - storeGeneratedKeys(generatedKeyHolder).doInPreparedStatement(ps); + generatedKeyHolder.getKeyList().clear(); + storeGeneratedKeys(generatedKeyHolder, ps); if (logger.isTraceEnabled()) { logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys"); } @@ -1015,7 +1016,7 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti @Override public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException { - int[] result = execute(psc, getPreparedStatementCallback(pss, storeGeneratedKeys(generatedKeyHolder))); + int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder)); Assert.state(result != null, "No result array"); return result; @@ -1027,7 +1028,7 @@ public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) thr logger.debug("Executing SQL batch update [" + sql + "]"); } - int[] result = execute(sql, getPreparedStatementCallback(pss, ps -> null)); + int[] result = execute(sql, getPreparedStatementCallback(pss, null)); Assert.state(result != null, "No result array"); return result; @@ -1526,32 +1527,31 @@ private static int updateCount(@Nullable Integer result) { return result; } - private PreparedStatementCallback storeGeneratedKeys(KeyHolder generatedKeyHolder) { - return ps -> { - List> generatedKeys = generatedKeyHolder.getKeyList(); - generatedKeys.clear(); - ResultSet keys = ps.getGeneratedKeys(); - if (keys != null) { - try { - RowMapperResultSetExtractor> rse = - new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); - generatedKeys.addAll(result(rse.extractData(keys))); - } - finally { - JdbcUtils.closeResultSet(keys); - } + private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps) throws SQLException { + List> generatedKeys = generatedKeyHolder.getKeyList(); + ResultSet keys = ps.getGeneratedKeys(); + if (keys != null) { + try { + RowMapperResultSetExtractor> rse = + new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); + generatedKeys.addAll(result(rse.extractData(keys))); } - return null; - }; + finally { + JdbcUtils.closeResultSet(keys); + } + } } - private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss, PreparedStatementCallback afterUpdateCallback) { + private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) { return ps -> { try { int batchSize = pss.getBatchSize(); InterruptibleBatchPreparedStatementSetter ipss = (pss instanceof InterruptibleBatchPreparedStatementSetter ? (InterruptibleBatchPreparedStatementSetter) pss : null); + if (generatedKeyHolder != null) { + generatedKeyHolder.getKeyList().clear(); + } if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) { for (int i = 0; i < batchSize; i++) { pss.setValues(ps, i); @@ -1561,7 +1561,9 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa ps.addBatch(); } int[] results = ps.executeBatch(); - afterUpdateCallback.doInPreparedStatement(ps); + if (generatedKeyHolder != null) { + storeGeneratedKeys(generatedKeyHolder, ps); + } return results; } else { @@ -1572,9 +1574,11 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa break; } rowsAffected.add(ps.executeUpdate()); + if (generatedKeyHolder != null) { + storeGeneratedKeys(generatedKeyHolder, ps); + } } int[] rowsAffectedArray = new int[rowsAffected.size()]; - afterUpdateCallback.doInPreparedStatement(ps); for (int i = 0; i < rowsAffectedArray.length; i++) { rowsAffectedArray[i] = rowsAffected.get(i); } diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index d48351116ff6..044076b4622a 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -1131,14 +1131,18 @@ void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates( DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); given(databaseMetaData.supportsBatchUpdates()).willReturn(false); given(this.connection.getMetaData()).willReturn(databaseMetaData); - ResultSet generatedKeysResultSet = mock(ResultSet.class); ResultSetMetaData rsmd = mock(ResultSetMetaData.class); given(rsmd.getColumnCount()).willReturn(1); given(rsmd.getColumnLabel(1)).willReturn("someId"); - given(generatedKeysResultSet.getMetaData()).willReturn(rsmd); - given(generatedKeysResultSet.getObject(1)).willReturn(123, 456); - given(generatedKeysResultSet.next()).willReturn(true, true, false); - given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet); + ResultSet generatedKeysResultSet1 = mock(ResultSet.class); + given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet1.getObject(1)).willReturn(123); + given(generatedKeysResultSet1.next()).willReturn(true, false); + ResultSet generatedKeysResultSet2 = mock(ResultSet.class); + given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet2.getObject(1)).willReturn(456); + given(generatedKeysResultSet2.next()).willReturn(true, false); + given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2); int[] values = new int[]{100, 200}; BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() { From 15d3e1e27c3983dd9801ba5667ade0663b279346 Mon Sep 17 00:00:00 2001 From: Chirag Tailor Date: Tue, 15 Mar 2022 13:45:15 -0500 Subject: [PATCH 8/8] Parameterize the number of expected rows when storing generated keys. This is to avoid frequent resizing of the row mapper ArrayList when extracted keys for a batch operation. --- .../org/springframework/jdbc/core/JdbcTemplate.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 7a127b26b786..cd7d45306e0f 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -991,7 +991,7 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK return updateCount(execute(psc, ps -> { int rows = ps.executeUpdate(); generatedKeyHolder.getKeyList().clear(); - storeGeneratedKeys(generatedKeyHolder, ps); + storeGeneratedKeys(generatedKeyHolder, ps, 1); if (logger.isTraceEnabled()) { logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys"); } @@ -1527,13 +1527,13 @@ private static int updateCount(@Nullable Integer result) { return result; } - private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps) throws SQLException { + private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException { List> generatedKeys = generatedKeyHolder.getKeyList(); ResultSet keys = ps.getGeneratedKeys(); if (keys != null) { try { RowMapperResultSetExtractor> rse = - new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); + new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected); generatedKeys.addAll(result(rse.extractData(keys))); } finally { @@ -1562,7 +1562,7 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa } int[] results = ps.executeBatch(); if (generatedKeyHolder != null) { - storeGeneratedKeys(generatedKeyHolder, ps); + storeGeneratedKeys(generatedKeyHolder, ps, batchSize); } return results; } @@ -1575,7 +1575,7 @@ private PreparedStatementCallback getPreparedStatementCallback(BatchPrepa } rowsAffected.add(ps.executeUpdate()); if (generatedKeyHolder != null) { - storeGeneratedKeys(generatedKeyHolder, ps); + storeGeneratedKeys(generatedKeyHolder, ps, 1); } } int[] rowsAffectedArray = new int[rowsAffected.size()];