Skip to content

Commit 87d20ec

Browse files
DATAMONGO-2130 - Fix Repository count inside transaction.
We now make sure invocations on repository count methods delegate to countDocuments when inside an transaction.
1 parent 010e653 commit 87d20ec

File tree

8 files changed

+121
-19
lines changed

8 files changed

+121
-19
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDatabaseUtils.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,34 @@ private static MongoDatabase doGetMongoDatabase(@Nullable String dbName, MongoDb
110110

111111
ClientSession session = doGetSession(factory, sessionSynchronization);
112112

113-
if(session == null) {
113+
if (session == null) {
114114
return StringUtils.hasText(dbName) ? factory.getDb(dbName) : factory.getDb();
115115
}
116116

117117
MongoDbFactory factoryToUse = factory.withSession(session);
118118
return StringUtils.hasText(dbName) ? factoryToUse.getDb(dbName) : factoryToUse.getDb();
119119
}
120120

121+
/**
122+
* Check if the {@link MongoDbFactory} is actually bound to a {@link ClientSession} that has an active transaction, or
123+
* if a {@link TransactionSynchronization} has been registered for the {@link MongoDbFactory resource} and if the
124+
* associated {@link ClientSession} has an {@link ClientSession#hasActiveTransaction() active transaction}.
125+
*
126+
* @param dbFactory the resource to check transactions for. Must not be {@literal null}.
127+
* @return {@literal true} if the factory has an ongoing transaction.
128+
* @since 2.1.3
129+
*/
130+
public static boolean isTransactionActive(MongoDbFactory dbFactory) {
131+
132+
if (dbFactory.isTransactionActive()) {
133+
return true;
134+
}
135+
136+
MongoResourceHolder resourceHolder = (MongoResourceHolder) TransactionSynchronizationManager.getResource(dbFactory);
137+
return resourceHolder != null
138+
&& (resourceHolder.hasSession() && resourceHolder.getSession().hasActiveTransaction());
139+
}
140+
121141
@Nullable
122142
private static ClientSession doGetSession(MongoDbFactory dbFactory, SessionSynchronization sessionSynchronization) {
123143

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDbFactory.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,15 @@ default MongoDbFactory withSession(ClientSessionOptions options) {
108108
* @since 2.1
109109
*/
110110
MongoDbFactory withSession(ClientSession session);
111+
112+
/**
113+
* Returns if the given {@link MongoDbFactory} is bound to a {@link ClientSession} that has an
114+
* {@link ClientSession#hasActiveTransaction() active transaction}.
115+
*
116+
* @return {@literal true} if there's an active transaction, {@literal false} otherwise.
117+
* @since 2.1.3
118+
*/
119+
default boolean isTransactionActive() {
120+
return false;
121+
}
111122
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDbFactorySupport.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ public MongoDbFactory withSession(ClientSession session) {
233233
return delegate.withSession(session);
234234
}
235235

236+
/*
237+
* (non-Javadoc)
238+
* @see org.springframework.data.mongodb.MongoDbFactory#isTransactionActive()
239+
*/
240+
@Override
241+
public boolean isTransactionActive() {
242+
return session != null && session.hasActiveTransaction();
243+
}
244+
236245
private MongoDatabase proxyMongoDatabase(MongoDatabase database) {
237246
return createProxyInstance(session, database, MongoDatabase.class);
238247
}
@@ -241,7 +250,8 @@ private MongoDatabase proxyDatabase(com.mongodb.session.ClientSession session, M
241250
return createProxyInstance(session, database, MongoDatabase.class);
242251
}
243252

244-
private MongoCollection<?> proxyCollection(com.mongodb.session.ClientSession session, MongoCollection<?> collection) {
253+
private MongoCollection<?> proxyCollection(com.mongodb.session.ClientSession session,
254+
MongoCollection<?> collection) {
245255
return createProxyInstance(session, collection, MongoCollection.class);
246256
}
247257

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,16 @@ public long count(Query query, @Nullable Class<?> entityClass, String collection
11191119
Document document = queryMapper.getMappedObject(query.getQueryObject(),
11201120
Optional.ofNullable(entityClass).map(it -> mappingContext.getPersistentEntity(entityClass)));
11211121

1122-
return execute(collectionName, collection -> collection.count(document, options));
1122+
return doCount(collectionName, document, options);
1123+
}
1124+
1125+
protected long doCount(String collectionName, Document filter, CountOptions options) {
1126+
1127+
if (!MongoDatabaseUtils.isTransactionActive(getMongoDbFactory())) {
1128+
return execute(collectionName, collection -> collection.count(filter, options));
1129+
}
1130+
1131+
return execute(collectionName, collection -> collection.countDocuments(filter, options));
11231132
}
11241133

11251134
/*
@@ -3343,23 +3352,16 @@ public MongoDatabase getDb() {
33433352

33443353
/*
33453354
* (non-Javadoc)
3346-
* @see org.springframework.data.mongodb.core.MongoTemplate#count(org.springframework.data.mongodb.core.query.Query, java.lang.Class, java.lang.String)
3355+
* @see org.springframework.data.mongodb.core.MongoTemplate#doCount(java.lang.String, org.bson.Document, com.mongodb.client.model.CountOptions)
33473356
*/
33483357
@Override
3349-
@SuppressWarnings("unchecked")
3350-
public long count(Query query, @Nullable Class<?> entityClass, String collectionName) {
3358+
protected long doCount(String collectionName, Document filter, CountOptions options) {
33513359

33523360
if (!session.hasActiveTransaction()) {
3353-
return super.count(query, entityClass, collectionName);
3361+
return super.doCount(collectionName, filter, options);
33543362
}
33553363

3356-
CountOptions options = new CountOptions();
3357-
query.getCollation().map(Collation::toMongoCollation).ifPresent(options::collation);
3358-
3359-
Document document = delegate.queryMapper.getMappedObject(query.getQueryObject(),
3360-
Optional.ofNullable(entityClass).map(it -> delegate.mappingContext.getPersistentEntity(entityClass)));
3361-
3362-
return execute(collectionName, collection -> collection.countDocuments(document, options));
3364+
return execute(collectionName, collection -> collection.countDocuments(filter, options));
33633365
}
33643366
}
33653367
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public boolean existsById(ID id) {
137137
*/
138138
@Override
139139
public long count() {
140-
return mongoOperations.getCollection(entityInformation.getCollectionName()).count();
140+
return mongoOperations.count(new Query(), entityInformation.getCollectionName());
141141
}
142142

143143
/*

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/MongoDatabaseUtilsUnitTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,39 @@ public void verifyTransactionSynchronizationManagerState() {
7878
assertFalse(TransactionSynchronizationManager.isActualTransactionActive());
7979
}
8080

81+
@Test // DATAMONGO-2130
82+
public void isTransactionActiveShouldDetectTxViaFactory() {
83+
84+
when(dbFactory.isTransactionActive()).thenReturn(true);
85+
86+
assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isTrue();
87+
}
88+
89+
@Test // DATAMONGO-2130
90+
public void isTransactionActiveShouldReturnFalseIfNoTxActive() {
91+
92+
when(dbFactory.isTransactionActive()).thenReturn(false);
93+
94+
assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isFalse();
95+
}
96+
97+
@Test // DATAMONGO-2130
98+
public void isTransactionActiveShouldLookupTxForActiveTransactionSynchronizationViaTxManager() {
99+
100+
when(dbFactory.isTransactionActive()).thenReturn(false);
101+
102+
MongoTransactionManager txManager = new MongoTransactionManager(dbFactory);
103+
TransactionTemplate txTemplate = new TransactionTemplate(txManager);
104+
105+
txTemplate.execute(new TransactionCallbackWithoutResult() {
106+
107+
@Override
108+
protected void doInTransactionWithoutResult(TransactionStatus transactionStatus) {
109+
assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isTrue();
110+
}
111+
});
112+
}
113+
81114
@Test // DATAMONGO-1920
82115
public void shouldNotStartSessionWhenNoTransactionOngoing() {
83116

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/config/ServerAddressPropertyEditorUnitTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.hamcrest.Matchers.*;
1919
import static org.junit.Assert.*;
2020

21+
import java.io.IOException;
2122
import java.net.InetAddress;
2223
import java.net.UnknownHostException;
2324
import java.util.Arrays;
@@ -124,7 +125,8 @@ public void handleIPv6HostaddressLoopbackLongWithBrackets() throws UnknownHostEx
124125
* We can't tell whether the last part of the hostAddress represents a port or not.
125126
*/
126127
@Test // DATAMONGO-808
127-
public void shouldFailToHandleAmbiguousIPv6HostaddressLongWithoutPortAndWithoutBrackets() throws UnknownHostException {
128+
public void shouldFailToHandleAmbiguousIPv6HostaddressLongWithoutPortAndWithoutBrackets()
129+
throws UnknownHostException {
128130

129131
expectedException.expect(IllegalArgumentException.class);
130132

@@ -173,9 +175,9 @@ private void assertUnresolveableHostnames(String... hostnames) {
173175

174176
for (String hostname : hostnames) {
175177
try {
176-
InetAddress.getByName(hostname);
178+
InetAddress.getByName(hostname).isReachable(1500);
177179
Assert.fail("Supposedly unresolveable hostname '" + hostname + "' can be resolved.");
178-
} catch (UnknownHostException expected) {
180+
} catch (IOException expected) {
179181
// ok
180182
}
181183
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryTests.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
import org.junit.runner.RunWith;
3333
import org.springframework.beans.factory.annotation.Autowired;
3434
import org.springframework.data.domain.Example;
35-
import org.springframework.data.domain.ExampleMatcher.StringMatcher;
3635
import org.springframework.data.domain.Page;
3736
import org.springframework.data.domain.PageRequest;
37+
import org.springframework.data.domain.ExampleMatcher.*;
3838
import org.springframework.data.geo.Point;
39+
import org.springframework.data.mongodb.MongoTransactionManager;
3940
import org.springframework.data.mongodb.core.MongoTemplate;
4041
import org.springframework.data.mongodb.core.geo.GeoJsonPoint;
4142
import org.springframework.data.mongodb.core.mapping.Document;
@@ -47,6 +48,7 @@
4748
import org.springframework.test.context.ContextConfiguration;
4849
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
4950
import org.springframework.test.util.ReflectionTestUtils;
51+
import org.springframework.transaction.support.TransactionTemplate;
5052

5153
/**
5254
* @author A. B. M. Kowser
@@ -383,6 +385,28 @@ public void saveAllUsesEntityCollection() {
383385
assertThat(repository.findAll()).containsExactlyInAnyOrder(first, second);
384386
}
385387

388+
@Test // DATAMONGO-2130
389+
public void countShouldBePossibleInTransaction() {
390+
391+
MongoTransactionManager txmgr = new MongoTransactionManager(template.getMongoDbFactory());
392+
TransactionTemplate tt = new TransactionTemplate(txmgr);
393+
tt.afterPropertiesSet();
394+
395+
long countPreTx = repository.count();
396+
397+
long count = tt.execute(status -> {
398+
399+
Person sample = new Person();
400+
sample.setLastname("Matthews");
401+
402+
repository.save(sample);
403+
404+
return repository.count();
405+
});
406+
407+
assertThat(count).isEqualTo(countPreTx+1);
408+
}
409+
386410
private void assertThatAllReferencePersonsWereStoredCorrectly(Map<String, Person> references, List<Person> saved) {
387411

388412
for (Person person : saved) {

0 commit comments

Comments
 (0)