diff --git a/pom.xml b/pom.xml index ded4d85d02..d81c079a2d 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4918-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index 58c63dfc97..d7a46f6f11 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4918-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 37e68c6f78..90cfbc7d74 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4918-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java index bf29d25e6b..12843ce622 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java @@ -15,13 +15,21 @@ */ package org.springframework.data.mongodb.core.query; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; import org.bson.Document; + import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; /** + * {@link Document}-based {@link Update} variant. + * * @author Thomas Risberg * @author John Brisbin * @author Oliver Gierke @@ -33,74 +41,114 @@ public class BasicUpdate extends Update { private final Document updateObject; public BasicUpdate(String updateString) { - super(); - this.updateObject = Document.parse(updateString); + this(Document.parse(updateString)); } public BasicUpdate(Document updateObject) { - super(); this.updateObject = updateObject; } @Override public Update set(String key, @Nullable Object value) { - updateObject.put("$set", Collections.singletonMap(key, value)); + setOperationValue("$set", key, value); return this; } @Override public Update unset(String key) { - updateObject.put("$unset", Collections.singletonMap(key, 1)); + setOperationValue("$unset", key, 1); return this; } @Override public Update inc(String key, Number inc) { - updateObject.put("$inc", Collections.singletonMap(key, inc)); + setOperationValue("$inc", key, inc); return this; } @Override public Update push(String key, @Nullable Object value) { - updateObject.put("$push", Collections.singletonMap(key, value)); + setOperationValue("$push", key, value); return this; } @Override public Update addToSet(String key, @Nullable Object value) { - updateObject.put("$addToSet", Collections.singletonMap(key, value)); + setOperationValue("$addToSet", key, value); return this; } @Override public Update pop(String key, Position pos) { - updateObject.put("$pop", Collections.singletonMap(key, (pos == Position.FIRST ? -1 : 1))); + setOperationValue("$pop", key, (pos == Position.FIRST ? -1 : 1)); return this; } @Override public Update pull(String key, @Nullable Object value) { - updateObject.put("$pull", Collections.singletonMap(key, value)); + setOperationValue("$pull", key, value); return this; } @Override public Update pullAll(String key, Object[] values) { - Document keyValue = new Document(); - keyValue.put(key, Arrays.copyOf(values, values.length)); - updateObject.put("$pullAll", keyValue); + setOperationValue("$pullAll", key, List.of(values), (o, o2) -> { + + if (o instanceof List prev && o2 instanceof List currentValue) { + List merged = new ArrayList<>(prev.size() + currentValue.size()); + merged.addAll(prev); + merged.addAll(currentValue); + return merged; + } + + return o2; + }); return this; } @Override public Update rename(String oldName, String newName) { - updateObject.put("$rename", Collections.singletonMap(oldName, newName)); + setOperationValue("$rename", oldName, newName); return this; } + @Override + public boolean modifies(String key) { + return super.modifies(key) || Update.fromDocument(getUpdateObject()).modifies(key); + } + @Override public Document getUpdateObject() { return updateObject; } + void setOperationValue(String operator, String key, @Nullable Object value) { + setOperationValue(operator, key, value, (o, o2) -> o2); + } + + void setOperationValue(String operator, String key, @Nullable Object value, + BiFunction mergeFunction) { + + if (!updateObject.containsKey(operator)) { + updateObject.put(operator, Collections.singletonMap(key, value)); + } else { + Object o = updateObject.get(operator); + if (o instanceof Map existing) { + Map target = new LinkedHashMap<>(existing); + + if (target.containsKey(key)) { + target.put(key, mergeFunction.apply(target.get(key), value)); + } else { + target.put(key, value); + } + updateObject.put(operator, target); + } else { + throw new IllegalStateException( + "Cannot add ['%s' : { '%s' : ... }]. Operator already exists with value of type [%s] which is not suitable for appending" + .formatted(operator, key, + o != null ? ClassUtils.getShortName(o.getClass()) : "null")); + } + } + } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java index e6da45a785..32d98f5804 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java @@ -447,13 +447,11 @@ protected void addMultiFieldOperation(String operator, String key, @Nullable Obj if (existingValue == null) { keyValueMap = new Document(); this.modifierOps.put(operator, keyValueMap); + } else if (existingValue instanceof Document document) { + keyValueMap = document; } else { - if (existingValue instanceof Document document) { - keyValueMap = document; - } else { - throw new InvalidDataAccessApiUsageException( - "Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass()); - } + throw new InvalidDataAccessApiUsageException( + "Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass()); } keyValueMap.put(key, value); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUpdateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUpdateTests.java index 10a0202d93..4249506d77 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUpdateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUpdateTests.java @@ -41,6 +41,7 @@ import org.springframework.data.mongodb.core.aggregation.SetOperation; import org.springframework.data.mongodb.core.mapping.Document; import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.query.BasicUpdate; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.Update; @@ -326,6 +327,20 @@ void updateFirstWithSort(Class domainType, Sort sort, UpdateDefinition update "Science is real!"); } + @Test // GH-4918 + void updateShouldHonorVersionProvided() { + + Versioned source = template.insert(Versioned.class).one(new Versioned("id-1", "value-0")); + + Update update = new BasicUpdate("{ '$set' : { 'value' : 'changed' }, '$inc' : { 'version' : 10 } }"); + template.update(Versioned.class).matching(Query.query(Criteria.where("id").is(source.id))).apply(update).first(); + + assertThat( + collection(Versioned.class).find(new org.bson.Document("_id", source.id)).limit(1).into(new ArrayList<>())) + .containsExactly(new org.bson.Document("_id", source.id).append("version", 10L).append("value", "changed") + .append("_class", "org.springframework.data.mongodb.core.MongoTemplateUpdateTests$Versioned")); + } + private List all(Class type) { return collection(type).find(new org.bson.Document()).into(new ArrayList<>()); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java new file mode 100644 index 0000000000..dacc270230 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.query; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.test.util.Assertions.*; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Stream; + +import org.bson.Document; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.data.mongodb.core.query.Update.Position; + +/** + * Unit tests for {@link BasicUpdate}. + * + * @author Christoph Strobl + * @author Mark Paluch + */ +class BasicUpdateUnitTests { + + @Test // GH-4918 + void setOperationValueShouldAppendsOpsCorrectly() { + + BasicUpdate basicUpdate = new BasicUpdate("{}"); + basicUpdate.setOperationValue("$set", "key1", "alt"); + basicUpdate.setOperationValue("$set", "key2", "nps"); + basicUpdate.setOperationValue("$unset", "key3", "x"); + + assertThat(basicUpdate.getUpdateObject()) + .isEqualTo("{ '$set' : { 'key1' : 'alt', 'key2' : 'nps' }, '$unset' : { 'key3' : 'x' } }"); + } + + @Test // GH-4918 + void setOperationErrorsOnNonMapType() { + + BasicUpdate basicUpdate = new BasicUpdate("{ '$set' : 1 }"); + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(() -> basicUpdate.setOperationValue("$set", "k", "v")); + } + + @ParameterizedTest // GH-4918 + @CsvSource({ // + "{ }, k1, false", // + "{ '$set' : { 'k1' : 'v1' } }, k1, true", // + "{ '$set' : { 'k1' : 'v1' } }, k2, false", // + "{ '$set' : { 'k1.k2' : 'v1' } }, k1, false", // + "{ '$set' : { 'k1.k2' : 'v1' } }, k1.k2, true", // + "{ '$set' : { 'k1' : 'v1' } }, '', false", // + "{ '$inc' : { 'k1' : 1 } }, k1, true" }) + void modifiesLooksUpKeyCorrectly(String source, String key, boolean modified) { + + BasicUpdate basicUpdate = new BasicUpdate(source); + assertThat(basicUpdate.modifies(key)).isEqualTo(modified); + } + + @ParameterizedTest // GH-4918 + @MethodSource("updateOpArgs") + void updateOpsShouldNotOverrideExistingValues(String operator, Function updateFunction) { + + Document source = Document.parse("{ '%s' : { 'key-1' : 'value-1' } }".formatted(operator)); + Update update = updateFunction.apply(new BasicUpdate(source)); + + assertThat(update.getUpdateObject()).containsEntry("%s.key-1".formatted(operator), "value-1") + .containsKey("%s.key-2".formatted(operator)); + } + + @Test // GH-4918 + void shouldNotOverridePullAll() { + + Document source = Document.parse("{ '$pullAll' : { 'key-1' : ['value-1'] } }"); + Update update = new BasicUpdate(source).pullAll("key-1", new String[] { "value-2" }).pullAll("key-2", + new String[] { "value-3" }); + + assertThat(update.getUpdateObject()).containsEntry("$pullAll.key-1", Arrays.asList("value-1", "value-2")) + .containsEntry("$pullAll.key-2", List.of("value-3")); + } + + static Stream updateOpArgs() { + return Stream.of( // + Arguments.of("$set", (Function) update -> update.set("key-2", "value-2")), + Arguments.of("$unset", (Function) update -> update.unset("key-2")), + Arguments.of("$inc", (Function) update -> update.inc("key-2", 1)), + Arguments.of("$push", (Function) update -> update.push("key-2", "value-2")), + Arguments.of("$addToSet", (Function) update -> update.addToSet("key-2", "value-2")), + Arguments.of("$pop", (Function) update -> update.pop("key-2", Position.FIRST)), + Arguments.of("$pull", (Function) update -> update.pull("key-2", "value-2")), + Arguments.of("$pullAll", + (Function) update -> update.pullAll("key-2", new String[] { "value-2" })), + Arguments.of("$rename", (Function) update -> update.rename("key-2", "value-2"))); + }; +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java new file mode 100644 index 0000000000..f4e1e0282e --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java @@ -0,0 +1,168 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import static org.assertj.core.api.Assertions.*; + +import org.bson.Document; +import org.bson.types.ObjectId; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.ComponentScan.Filter; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; +import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.repository.config.EnableMongoRepositories; +import org.springframework.data.mongodb.test.util.Client; +import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.mongodb.test.util.MongoTestUtils; +import org.springframework.data.repository.CrudRepository; +import org.springframework.lang.Nullable; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import com.mongodb.client.MongoClient; + +/** + * Integration tests for Repositories using optimistic locking. + * + * @author Christoph Strobl + */ +@ExtendWith({ MongoClientExtension.class, SpringExtension.class }) +@ContextConfiguration +class VersionedPersonRepositoryIntegrationTests { + + static @Client MongoClient mongoClient; + + @Autowired VersionedPersonRepository versionedPersonRepository; + @Autowired MongoTemplate template; + + @Configuration + @EnableMongoRepositories(considerNestedRepositories = true, + includeFilters = @Filter(type = FilterType.ASSIGNABLE_TYPE, classes = VersionedPersonRepository.class)) + static class Config extends AbstractMongoClientConfiguration { + + @Override + protected String getDatabaseName() { + return "versioned-person-tests"; + } + + @Override + public MongoClient mongoClient() { + return mongoClient; + } + } + + @BeforeEach + void beforeEach() { + MongoTestUtils.flushCollection("versioned-person-tests", + template.getCollectionName(VersionedPersonWithCounter.class), mongoClient); + } + + @Test // GH-4918 + void updatesVersionedTypeCorrectly() { + + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); + + int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameByLastname(person.getLastname()); + + assertThat(updateCount).isOne(); + + Document document = template.execute(VersionedPersonWithCounter.class, collection -> { + return collection.find(new Document("_id", new ObjectId(person.getId()))).first(); + }); + + assertThat(document).containsEntry("firstname", "Duckling").containsEntry("version", 1L); + } + + @Test // GH-4918 + void updatesVersionedTypeCorrectlyWhenUpdateIsUsingInc() { + + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); + + int updateCount = versionedPersonRepository.findAndIncCounterByLastname(person.getLastname()); + + assertThat(updateCount).isOne(); + + Document document = template.execute(VersionedPersonWithCounter.class, collection -> { + return collection.find(new Document("_id", new ObjectId(person.getId()))).first(); + }); + + assertThat(document).containsEntry("lastname", "Duckling").containsEntry("version", 1L).containsEntry("counter", + 42); + } + + @Test // GH-4918 + void updatesVersionedTypeCorrectlyWhenUpdateCoversVersionBump() { + + VersionedPerson person = template.insert(VersionedPersonWithCounter.class) + .one(new VersionedPersonWithCounter("Donald", "Duckling")); + + int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameIncVersionByLastname(person.getLastname(), + 10); + + assertThat(updateCount).isOne(); + + Document document = template.execute(VersionedPersonWithCounter.class, collection -> { + return collection.find(new Document("_id", new ObjectId(person.getId()))).first(); + }); + + assertThat(document).containsEntry("firstname", "Duckling").containsEntry("version", 10L); + } + + interface VersionedPersonRepository extends CrudRepository { + + @Update("{ '$set': { 'firstname' : ?0 } }") + int findAndSetFirstnameToLastnameByLastname(String lastname); + + @Update("{ '$inc': { 'counter' : 42 } }") + int findAndIncCounterByLastname(String lastname); + + @Update(""" + { + '$set': { 'firstname' : ?0 }, + '$inc': { 'version' : ?1 } + }""") + int findAndSetFirstnameToLastnameIncVersionByLastname(String lastname, int incVersion); + + } + + @org.springframework.data.mongodb.core.mapping.Document("versioned-person") + static class VersionedPersonWithCounter extends VersionedPerson { + + int counter; + + public VersionedPersonWithCounter(String firstname, @Nullable String lastname) { + super(firstname, lastname); + } + + public int getCounter() { + return counter; + } + + public void setCounter(int counter) { + this.counter = counter; + } + + } + +}