@@ -92,7 +92,8 @@ public class GraphQLJpaSchemaBuilder implements GraphQLSchemaBuilder {
92
92
93
93
private Map <Class <?>, GraphQLType > classCache = new HashMap <>();
94
94
private Map <EntityType <?>, GraphQLObjectType > entityCache = new HashMap <>();
95
- private Map <EmbeddableType <?>, GraphQLObjectType > embeddableCache = new HashMap <>();
95
+ private Map <EmbeddableType <?>, GraphQLObjectType > embeddableOutputCache = new HashMap <>();
96
+ private Map <EmbeddableType <?>, GraphQLInputObjectType > embeddableInputCache = new HashMap <>();
96
97
97
98
private static final Logger log = LoggerFactory .getLogger (GraphQLJpaSchemaBuilder .class );
98
99
@@ -289,13 +290,13 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
289
290
.field (GraphQLInputObjectField .newInputObjectField ()
290
291
.name (Criteria .EQ .name ())
291
292
.description ("Equals criteria" )
292
- .type (( GraphQLInputType ) getAttributeType (attribute ))
293
+ .type (getAttributeInputType (attribute ))
293
294
.build ()
294
295
)
295
296
.field (GraphQLInputObjectField .newInputObjectField ()
296
297
.name (Criteria .NE .name ())
297
298
.description ("Not Equals criteria" )
298
- .type (( GraphQLInputType ) getAttributeType (attribute ))
299
+ .type (getAttributeInputType (attribute ))
299
300
.build ()
300
301
);
301
302
@@ -304,25 +305,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
304
305
builder .field (GraphQLInputObjectField .newInputObjectField ()
305
306
.name (Criteria .LE .name ())
306
307
.description ("Less then or Equals criteria" )
307
- .type (( GraphQLInputType ) getAttributeType (attribute ))
308
+ .type (getAttributeInputType (attribute ))
308
309
.build ()
309
310
)
310
311
.field (GraphQLInputObjectField .newInputObjectField ()
311
312
.name (Criteria .GE .name ())
312
313
.description ("Greater or Equals criteria" )
313
- .type (( GraphQLInputType ) getAttributeType (attribute ))
314
+ .type (getAttributeInputType (attribute ))
314
315
.build ()
315
316
)
316
317
.field (GraphQLInputObjectField .newInputObjectField ()
317
318
.name (Criteria .GT .name ())
318
319
.description ("Greater Then criteria" )
319
- .type (( GraphQLInputType ) getAttributeType (attribute ))
320
+ .type (getAttributeInputType (attribute ))
320
321
.build ()
321
322
)
322
323
.field (GraphQLInputObjectField .newInputObjectField ()
323
324
.name (Criteria .LT .name ())
324
325
.description ("Less Then criteria" )
325
- .type (( GraphQLInputType ) getAttributeType (attribute ))
326
+ .type (getAttributeInputType (attribute ))
326
327
.build ()
327
328
);
328
329
}
@@ -331,25 +332,25 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
331
332
builder .field (GraphQLInputObjectField .newInputObjectField ()
332
333
.name (Criteria .LIKE .name ())
333
334
.description ("Like criteria" )
334
- .type (( GraphQLInputType ) getAttributeType (attribute ))
335
+ .type (getAttributeInputType (attribute ))
335
336
.build ()
336
337
)
337
338
.field (GraphQLInputObjectField .newInputObjectField ()
338
339
.name (Criteria .CASE .name ())
339
340
.description ("Case sensitive match criteria" )
340
- .type (( GraphQLInputType ) getAttributeType (attribute ))
341
+ .type (getAttributeInputType (attribute ))
341
342
.build ()
342
343
)
343
344
.field (GraphQLInputObjectField .newInputObjectField ()
344
345
.name (Criteria .STARTS .name ())
345
346
.description ("Starts with criteria" )
346
- .type (( GraphQLInputType ) getAttributeType (attribute ))
347
+ .type (getAttributeInputType (attribute ))
347
348
.build ()
348
349
)
349
350
.field (GraphQLInputObjectField .newInputObjectField ()
350
351
.name (Criteria .ENDS .name ())
351
352
.description ("Ends with criteria" )
352
- .type (( GraphQLInputType ) getAttributeType (attribute ))
353
+ .type (getAttributeInputType (attribute ))
353
354
.build ()
354
355
);
355
356
}
@@ -370,13 +371,13 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
370
371
.field (GraphQLInputObjectField .newInputObjectField ()
371
372
.name (Criteria .IN .name ())
372
373
.description ("In criteria" )
373
- .type (new GraphQLList (getAttributeType (attribute )))
374
+ .type (new GraphQLList (getAttributeInputType (attribute )))
374
375
.build ()
375
376
)
376
377
.field (GraphQLInputObjectField .newInputObjectField ()
377
378
.name (Criteria .NIN .name ())
378
379
.description ("Not In criteria" )
379
- .type (new GraphQLList (getAttributeType (attribute )))
380
+ .type (new GraphQLList (getAttributeInputType (attribute )))
380
381
.build ()
381
382
);
382
383
@@ -389,39 +390,52 @@ private GraphQLInputType getWhereAttributeType(Attribute<?,?> attribute) {
389
390
}
390
391
391
392
private GraphQLArgument getArgument (Attribute <?,?> attribute ) {
392
- GraphQLType type = getAttributeType (attribute );
393
+ GraphQLInputType type = getAttributeInputType (attribute );
393
394
String description = getSchemaDescription (attribute .getJavaMember ());
394
395
395
- if (type instanceof GraphQLInputType ) {
396
- return GraphQLArgument .newArgument ()
397
- .name (attribute .getName ())
398
- .type ((GraphQLInputType ) type )
399
- .description (description )
400
- .build ();
401
- }
402
-
403
- throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Input Argument" );
396
+ return GraphQLArgument .newArgument ()
397
+ .name (attribute .getName ())
398
+ .type ((GraphQLInputType ) type )
399
+ .description (description )
400
+ .build ();
404
401
}
405
402
406
- private GraphQLObjectType getEmbeddableType (EmbeddableType <?> embeddableType ) {
407
- if (embeddableCache .containsKey (embeddableType ))
408
- return embeddableCache .get (embeddableType );
409
-
410
- String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+"EmbeddableType" ;
411
-
412
- GraphQLObjectType objectType = GraphQLObjectType .newObject ()
413
- .name (embeddableTypeName )
414
- .description (getSchemaDescription ( embeddableType .getJavaType ()))
415
- .fields (embeddableType .getAttributes ().stream ()
416
- .filter (this ::isNotIgnored )
417
- .map (this ::getObjectField )
418
- .collect (Collectors .toList ())
419
- )
420
- .build ();
421
-
422
- embeddableCache .putIfAbsent (embeddableType , objectType );
403
+ private GraphQLType getEmbeddableType (EmbeddableType <?> embeddableType , boolean input ) {
404
+ if (input && embeddableInputCache .containsKey (embeddableType ))
405
+ return embeddableInputCache .get (embeddableType );
406
+
407
+ if (!input && embeddableOutputCache .containsKey (embeddableType ))
408
+ return embeddableOutputCache .get (embeddableType );
409
+ String embeddableTypeName = namingStrategy .singularize (embeddableType .getJavaType ().getSimpleName ())+ (input ? "Input" : "" ) +"EmbeddableType" ;
410
+ GraphQLType graphQLType =null ;
411
+ if (input ) {
412
+ graphQLType = GraphQLInputObjectType .newInputObject ()
413
+ .name (embeddableTypeName )
414
+ .description (getSchemaDescription (embeddableType .getJavaType ()))
415
+ .fields (embeddableType .getAttributes ().stream ()
416
+ .filter (this ::isNotIgnored )
417
+ .map (this ::getInputObjectField )
418
+ .collect (Collectors .toList ())
419
+ )
420
+ .build ();
421
+ } else {
422
+ graphQLType = GraphQLObjectType .newObject ()
423
+ .name (embeddableTypeName )
424
+ .description (getSchemaDescription (embeddableType .getJavaType ()))
425
+ .fields (embeddableType .getAttributes ().stream ()
426
+ .filter (this ::isNotIgnored )
427
+ .map (this ::getObjectField )
428
+ .collect (Collectors .toList ())
429
+ )
430
+ .build ();
431
+ }
432
+ if (input ) {
433
+ embeddableInputCache .putIfAbsent (embeddableType , (GraphQLInputObjectType ) graphQLType );
434
+ } else {
435
+ embeddableOutputCache .putIfAbsent (embeddableType , (GraphQLObjectType ) graphQLType );
436
+ }
423
437
424
- return objectType ;
438
+ return graphQLType ;
425
439
}
426
440
427
441
@@ -447,67 +461,92 @@ private GraphQLObjectType getObjectType(EntityType<?> entityType) {
447
461
448
462
@ SuppressWarnings ( { "rawtypes" , "unchecked" } )
449
463
private GraphQLFieldDefinition getObjectField (Attribute attribute ) {
450
- GraphQLType type = getAttributeType (attribute );
451
-
452
- if (type instanceof GraphQLOutputType ) {
453
- List <GraphQLArgument > arguments = new ArrayList <>();
454
- DataFetcher dataFetcher = PropertyDataFetcher .fetching (attribute .getName ());
455
-
456
- // Only add the orderBy argument for basic attribute types
457
- if (attribute instanceof SingularAttribute
458
- && attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ) {
459
- arguments .add (GraphQLArgument .newArgument ()
460
- .name (ORDER_BY_PARAM_NAME )
461
- .description ("Specifies field sort direction in the query results." )
462
- .type (orderByDirectionEnum )
463
- .build ()
464
- );
465
- }
466
-
467
- // Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
468
- if (attribute instanceof SingularAttribute
469
- && attribute .getPersistentAttributeType () != Attribute .PersistentAttributeType .BASIC ) {
470
- ManagedType foreignType = (ManagedType ) ((SingularAttribute ) attribute ).getType ();
471
-
472
- // TODO fix page count query
473
- arguments .add (getWhereArgument (foreignType ));
474
-
475
- } // Get Sub-Objects fields queries via DataFetcher
476
- else if (attribute instanceof PluralAttribute
477
- && (attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ONE_TO_MANY
478
- || attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .MANY_TO_MANY )) {
479
- EntityType declaringType = (EntityType ) ((PluralAttribute ) attribute ).getDeclaringType ();
480
- EntityType elementType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
481
-
482
- arguments .add (getWhereArgument (elementType ));
483
- dataFetcher = new GraphQLJpaOneToManyDataFetcher (entityManager , declaringType , (PluralAttribute ) attribute );
484
- }
464
+ GraphQLOutputType type = getAttributeOutputType (attribute );
465
+
466
+ List <GraphQLArgument > arguments = new ArrayList <>();
467
+ DataFetcher dataFetcher = PropertyDataFetcher .fetching (attribute .getName ());
468
+
469
+ // Only add the orderBy argument for basic attribute types
470
+ if (attribute instanceof SingularAttribute
471
+ && attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ) {
472
+ arguments .add (GraphQLArgument .newArgument ()
473
+ .name (ORDER_BY_PARAM_NAME )
474
+ .description ("Specifies field sort direction in the query results." )
475
+ .type (orderByDirectionEnum )
476
+ .build ()
477
+ );
478
+ }
485
479
486
- return GraphQLFieldDefinition .newFieldDefinition ()
487
- .name (attribute .getName ())
488
- .description (getSchemaDescription (attribute .getJavaMember ()))
489
- .type ((GraphQLOutputType ) type )
490
- .dataFetcher (dataFetcher )
491
- .argument (arguments )
492
- .build ();
480
+ // Get the fields that can be queried on (i.e. Simple Types, no Sub-Objects)
481
+ if (attribute instanceof SingularAttribute
482
+ && attribute .getPersistentAttributeType () != Attribute .PersistentAttributeType .BASIC ) {
483
+ ManagedType foreignType = (ManagedType ) ((SingularAttribute ) attribute ).getType ();
484
+
485
+ // TODO fix page count query
486
+ arguments .add (getWhereArgument (foreignType ));
487
+
488
+ } // Get Sub-Objects fields queries via DataFetcher
489
+ else if (attribute instanceof PluralAttribute
490
+ && (attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ONE_TO_MANY
491
+ || attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .MANY_TO_MANY )) {
492
+ EntityType declaringType = (EntityType ) ((PluralAttribute ) attribute ).getDeclaringType ();
493
+ EntityType elementType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
494
+
495
+ arguments .add (getWhereArgument (elementType ));
496
+ dataFetcher = new GraphQLJpaOneToManyDataFetcher (entityManager , declaringType , (PluralAttribute ) attribute );
493
497
}
494
498
495
- throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Output Argument" );
499
+ return GraphQLFieldDefinition .newFieldDefinition ()
500
+ .name (attribute .getName ())
501
+ .description (getSchemaDescription (attribute .getJavaMember ()))
502
+ .type (type )
503
+ .dataFetcher (dataFetcher )
504
+ .argument (arguments )
505
+ .build ();
506
+ }
507
+
508
+ @ SuppressWarnings ( { "rawtypes" , "unchecked" } )
509
+ private GraphQLInputObjectField getInputObjectField (Attribute attribute ) {
510
+ GraphQLInputType type = getAttributeInputType (attribute );
511
+
512
+ return GraphQLInputObjectField .newInputObjectField ()
513
+ .name (attribute .getName ())
514
+ .description (getSchemaDescription (attribute .getJavaMember ()))
515
+ .type (type )
516
+ .build ();
496
517
}
497
518
498
519
private Stream <Attribute <?,?>> findBasicAttributes (Collection <Attribute <?,?>> attributes ) {
499
520
return attributes .stream ().filter (it -> it .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC );
500
521
}
501
522
502
523
@ SuppressWarnings ( "rawtypes" )
503
- private GraphQLType getAttributeType (Attribute <?,?> attribute ) {
524
+ private GraphQLInputType getAttributeInputType (Attribute <?,?> attribute ) {
525
+ try {
526
+ return (GraphQLInputType ) getAttributeType (attribute , true );
527
+ } catch (ClassCastException e ){
528
+ throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Input Argument" );
529
+ }
530
+ }
531
+
532
+ @ SuppressWarnings ( "rawtypes" )
533
+ private GraphQLOutputType getAttributeOutputType (Attribute <?,?> attribute ) {
534
+ try {
535
+ return (GraphQLOutputType ) getAttributeType (attribute , false );
536
+ } catch (ClassCastException e ){
537
+ throw new IllegalArgumentException ("Attribute " + attribute + " cannot be mapped as an Output Argument" );
538
+ }
539
+ }
540
+
541
+ @ SuppressWarnings ( "rawtypes" )
542
+ private GraphQLType getAttributeType (Attribute <?,?> attribute , boolean input ) {
504
543
505
544
if (isBasic (attribute )) {
506
545
return getGraphQLTypeFromJavaType (attribute .getJavaType ());
507
546
}
508
547
else if (isEmbeddable (attribute )) {
509
548
EmbeddableType embeddableType = (EmbeddableType ) ((SingularAttribute ) attribute ).getType ();
510
- return getEmbeddableType (embeddableType );
549
+ return getEmbeddableType (embeddableType , input );
511
550
}
512
551
else if (isToMany (attribute )) {
513
552
EntityType foreignType = (EntityType ) ((PluralAttribute ) attribute ).getElementType ();
@@ -557,7 +596,8 @@ protected final boolean isToOne(Attribute<?,?> attribute) {
557
596
558
597
protected final boolean isValidInput (Attribute <?,?> attribute ) {
559
598
return attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .BASIC ||
560
- attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ELEMENT_COLLECTION ;
599
+ attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .ELEMENT_COLLECTION ||
600
+ attribute .getPersistentAttributeType () == Attribute .PersistentAttributeType .EMBEDDED ;
561
601
}
562
602
563
603
private String getSchemaDescription (Member member ) {
0 commit comments