Skip to content

Commit 7b24496

Browse files
committed
DATAJPA-218 - Use domain type for untyped Examples.
1 parent bf4837e commit 7b24496

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/main/java/org/springframework/data/jpa/repository/support/SimpleJpaRepository.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.data.domain.PageImpl;
4545
import org.springframework.data.domain.Pageable;
4646
import org.springframework.data.domain.Sort;
47+
import org.springframework.data.domain.TypedExampleSpec;
4748
import org.springframework.data.jpa.convert.QueryByExamplePredicateBuilder;
4849
import org.springframework.data.jpa.domain.Specification;
4950
import org.springframework.data.jpa.provider.PersistenceProvider;
@@ -437,7 +438,7 @@ public List<T> findAll(Specification<T> spec, Sort sort) {
437438
@Override
438439
public <S extends T> S findOne(Example<S> example) {
439440
try {
440-
return getQuery(new ExampleSpecification<S>(example), example.getResultType(), (Sort) null).getSingleResult();
441+
return getQuery(new ExampleSpecification<S>(example), getResultType(example), (Sort) null).getSingleResult();
441442
} catch (NoResultException e) {
442443
return null;
443444
}
@@ -449,15 +450,15 @@ public <S extends T> S findOne(Example<S> example) {
449450
@SuppressWarnings("unchecked")
450451
@Override
451452
public <S extends T> long count(Example<S> example) {
452-
return executeCountQuery(getCountQuery(new ExampleSpecification<S>(example), example.getResultType()));
453+
return executeCountQuery(getCountQuery(new ExampleSpecification<S>(example), getResultType(example)));
453454
}
454455

455456
/* (non-Javadoc)
456457
* @see org.springframework.data.repository.query.QueryByExampleExecutor#exists(org.springframework.data.domain.Example)
457458
*/
458459
@Override
459460
public <S extends T> boolean exists(Example<S> example) {
460-
return !getQuery(new ExampleSpecification<S>(example), example.getResultType(), (Sort) null).getResultList()
461+
return !getQuery(new ExampleSpecification<S>(example), getResultType(example), (Sort) null).getResultList()
461462
.isEmpty();
462463
}
463464

@@ -467,7 +468,7 @@ public <S extends T> boolean exists(Example<S> example) {
467468
*/
468469
@Override
469470
public <S extends T> List<S> findAll(Example<S> example) {
470-
return getQuery(new ExampleSpecification<S>(example), example.getResultType(), (Sort) null).getResultList();
471+
return getQuery(new ExampleSpecification<S>(example), getResultType(example), (Sort) null).getResultList();
471472
}
472473

473474
/*
@@ -476,7 +477,7 @@ public <S extends T> List<S> findAll(Example<S> example) {
476477
*/
477478
@Override
478479
public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
479-
return getQuery(new ExampleSpecification<S>(example), example.getResultType(), sort).getResultList();
480+
return getQuery(new ExampleSpecification<S>(example), getResultType(example), sort).getResultList();
480481
}
481482

482483
/*
@@ -487,9 +488,9 @@ public <S extends T> List<S> findAll(Example<S> example, Sort sort) {
487488
public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
488489

489490
ExampleSpecification<S> spec = new ExampleSpecification<S>(example);
490-
TypedQuery<S> query = getQuery(new ExampleSpecification<S>(example), example.getResultType(), pageable);
491+
TypedQuery<S> query = getQuery(new ExampleSpecification<S>(example), getResultType(example), pageable);
491492
return pageable == null ? new PageImpl<S>(query.getResultList())
492-
: readPage(query, example.getResultType(), pageable, spec);
493+
: readPage(query, getResultType(example), pageable, spec);
493494
}
494495

495496
/*
@@ -763,6 +764,15 @@ private void applyQueryHints(Query query) {
763764
}
764765
}
765766

767+
768+
private <S extends T> Class<S> getResultType(Example<S> example) {
769+
770+
if(example.getExampleSpec() instanceof TypedExampleSpec<?>){
771+
return example.getResultType();
772+
}
773+
return (Class<S>) getDomainClass();
774+
}
775+
766776
/**
767777
* Executes a count query and transparently sums up all values returned.
768778
*

0 commit comments

Comments
 (0)