Skip to content

Commit 3f72baa

Browse files
committed
Fix event marshalling
Use the correct classes for the class -> marshaller mapper for new style generated event streams.
1 parent 4e281b0 commit 3f72baa

File tree

5 files changed

+182
-10
lines changed

5 files changed

+182
-10
lines changed

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import software.amazon.awssdk.codegen.emitters.GeneratorTaskParams;
5353
import software.amazon.awssdk.codegen.model.config.customization.UtilitiesMethod;
5454
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
55+
import software.amazon.awssdk.codegen.model.intermediate.MemberModel;
5556
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
5657
import software.amazon.awssdk.codegen.model.intermediate.ShapeModel;
5758
import software.amazon.awssdk.codegen.model.service.AuthType;
@@ -60,6 +61,7 @@
6061
import software.amazon.awssdk.codegen.poet.StaticImport;
6162
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
6263
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
64+
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
6365
import software.amazon.awssdk.core.RequestOverrideConfiguration;
6466
import software.amazon.awssdk.core.async.SdkPublisher;
6567
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
@@ -358,16 +360,17 @@ private CodeBlock eventToByteBufferPublisher(OperationModel opModel) {
358360
}
359361

360362
private CodeBlock createEventStreamTaggedUnionJsonMarshaller(ShapeModel eventStreamShape) {
363+
EventStreamSpecHelper specHelper = new EventStreamSpecHelper(eventStreamShape, model);
364+
361365
CodeBlock.Builder builder = CodeBlock.builder().add("$1T eventMarshaller = $1T.builder()",
362366
EventStreamTaggedUnionJsonMarshaller.class);
363367

364-
List<String> eventNames = EventStreamUtils.getEventMembers(eventStreamShape)
365-
.map(m -> m.getShape().getShapeName())
368+
List<MemberModel> eventMembers = EventStreamUtils.getEventMembers(eventStreamShape)
366369
.collect(Collectors.toList());
367370

368-
eventNames.forEach(event -> builder.add(".putMarshaller($T.class, new $T(protocolFactory))",
369-
poetExtensions.getModelClass(event),
370-
poetExtensions.getTransformClass(event + "Marshaller")));
371+
eventMembers.forEach(event -> builder.add(".putMarshaller($T.class, new $T(protocolFactory))",
372+
specHelper.eventClassName(event),
373+
poetExtensions.getTransformClass(event.getShape() + "Marshaller")));
371374

372375
builder.add(".build();");
373376
return builder.build();

codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-async-client-class.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,8 @@
6868
import software.amazon.awssdk.services.json.model.EventStreamOperationWithOnlyOutputResponseHandler;
6969
import software.amazon.awssdk.services.json.model.GetWithoutRequiredMembersRequest;
7070
import software.amazon.awssdk.services.json.model.GetWithoutRequiredMembersResponse;
71-
import software.amazon.awssdk.services.json.model.InputEvent;
7271
import software.amazon.awssdk.services.json.model.InputEventStream;
7372
import software.amazon.awssdk.services.json.model.InputEventStreamTwo;
74-
import software.amazon.awssdk.services.json.model.InputEventTwo;
7573
import software.amazon.awssdk.services.json.model.InvalidInputException;
7674
import software.amazon.awssdk.services.json.model.JsonException;
7775
import software.amazon.awssdk.services.json.model.JsonRequest;
@@ -87,6 +85,9 @@
8785
import software.amazon.awssdk.services.json.model.StreamingInputOutputOperationResponse;
8886
import software.amazon.awssdk.services.json.model.StreamingOutputOperationRequest;
8987
import software.amazon.awssdk.services.json.model.StreamingOutputOperationResponse;
88+
import software.amazon.awssdk.services.json.model.inputeventstream.DefaultInputEvent;
89+
import software.amazon.awssdk.services.json.model.inputeventstreamtwo.DefaultInputEventOne;
90+
import software.amazon.awssdk.services.json.model.inputeventstreamtwo.DefaultInputEventTwo;
9091
import software.amazon.awssdk.services.json.paginators.PaginatedOperationWithResultKeyPublisher;
9192
import software.amazon.awssdk.services.json.paginators.PaginatedOperationWithoutResultKeyPublisher;
9293
import software.amazon.awssdk.services.json.transform.APostOperationRequestMarshaller;
@@ -311,7 +312,7 @@ public CompletableFuture<Void> eventStreamOperation(EventStreamOperationRequest
311312
HttpResponseHandler<AwsServiceException> errorResponseHandler = createErrorResponseHandler(protocolFactory,
312313
operationMetadata);
313314
EventStreamTaggedUnionJsonMarshaller eventMarshaller = EventStreamTaggedUnionJsonMarshaller.builder()
314-
.putMarshaller(InputEvent.class, new InputEventMarshaller(protocolFactory)).build();
315+
.putMarshaller(DefaultInputEvent.class, new InputEventMarshaller(protocolFactory)).build();
315316
SdkPublisher<InputEventStream> eventPublisher = SdkPublisher.adapt(requestStream);
316317
Publisher<ByteBuffer> adapted = eventPublisher.map(event -> eventMarshaller.marshall(event)).map(
317318
AwsClientHandlerUtils::encodeEventStreamRequestToByteBuffer);
@@ -398,8 +399,8 @@ public CompletableFuture<EventStreamOperationWithOnlyInputResponse> eventStreamO
398399
HttpResponseHandler<AwsServiceException> errorResponseHandler = createErrorResponseHandler(protocolFactory,
399400
operationMetadata);
400401
EventStreamTaggedUnionJsonMarshaller eventMarshaller = EventStreamTaggedUnionJsonMarshaller.builder()
401-
.putMarshaller(InputEvent.class, new InputEventMarshaller(protocolFactory))
402-
.putMarshaller(InputEventTwo.class, new InputEventTwoMarshaller(protocolFactory)).build();
402+
.putMarshaller(DefaultInputEventOne.class, new InputEventMarshaller(protocolFactory))
403+
.putMarshaller(DefaultInputEventTwo.class, new InputEventTwoMarshaller(protocolFactory)).build();
403404
SdkPublisher<InputEventStreamTwo> eventPublisher = SdkPublisher.adapt(requestStream);
404405
Publisher<ByteBuffer> adapted = eventPublisher.map(event -> eventMarshaller.marshall(event)).map(
405406
AwsClientHandlerUtils::encodeEventStreamRequestToByteBuffer);

test/codegen-generated-classes-test/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@
190190
<artifactId>rxjava</artifactId>
191191
<scope>test</scope>
192192
</dependency>
193+
<dependency>
194+
<groupId>software.amazon.eventstream</groupId>
195+
<artifactId>eventstream</artifactId>
196+
<scope>test</scope>
197+
</dependency>
193198
</dependencies>
194199

195200
<build>

test/codegen-generated-classes-test/src/main/resources/codegen-resources/eventstreams/service-2.json

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@
5757
"members": {
5858
"InputEvent": {
5959
"shape": "InputEvent"
60+
},
61+
"InputEventB": {
62+
"shape": "InputEvent"
63+
},
64+
"InputEventTwo": {
65+
"shape": "InputEventTwo"
6066
}
6167
},
6268
"eventstream": true
@@ -71,6 +77,17 @@
7177
},
7278
"event": true
7379
},
80+
"InputEventTwo": {
81+
"type": "structure",
82+
"members": {
83+
"ExplicitPayloadMember": {
84+
"shape":"ExplicitPayloadMember",
85+
"eventpayload":true
86+
}
87+
},
88+
"event": true
89+
},
90+
7491
"ExplicitPayloadMember":{"type":"blob"},
7592
"EventStream": {
7693
"type": "structure",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.eventstreams;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.mockito.Matchers.any;
20+
import static org.mockito.Mockito.when;
21+
import io.reactivex.Flowable;
22+
import java.nio.ByteBuffer;
23+
import java.util.ArrayList;
24+
import java.util.List;
25+
import java.util.concurrent.CompletableFuture;
26+
import java.util.stream.Collectors;
27+
import java.util.stream.Stream;
28+
import org.junit.Before;
29+
import org.junit.Test;
30+
import org.junit.runner.RunWith;
31+
import org.mockito.Mock;
32+
import org.mockito.invocation.InvocationOnMock;
33+
import org.mockito.runners.MockitoJUnitRunner;
34+
import org.reactivestreams.Subscriber;
35+
import org.reactivestreams.Subscription;
36+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
37+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
38+
import software.amazon.awssdk.http.SdkHttpResponse;
39+
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
40+
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
41+
import software.amazon.awssdk.http.async.SdkHttpContentPublisher;
42+
import software.amazon.awssdk.services.eventstreamrestjson.EventStreamRestJsonAsyncClient;
43+
import software.amazon.awssdk.services.eventstreamrestjson.model.EventStream;
44+
import software.amazon.awssdk.services.eventstreamrestjson.model.EventStreamOperationRequest;
45+
import software.amazon.awssdk.services.eventstreamrestjson.model.EventStreamOperationResponseHandler;
46+
import software.amazon.awssdk.services.eventstreamrestjson.model.InputEventStream;
47+
import software.amazon.eventstream.Message;
48+
import software.amazon.eventstream.MessageDecoder;
49+
50+
@RunWith(MockitoJUnitRunner.class)
51+
public class EventMarshallingTest {
52+
@Mock
53+
public SdkAsyncHttpClient mockHttpClient;
54+
55+
private EventStreamRestJsonAsyncClient client;
56+
57+
private List<Message> marshalledEvents;
58+
59+
private MessageDecoder chunkDecoder;
60+
private MessageDecoder eventDecoder;
61+
62+
@Before
63+
public void setup() {
64+
when(mockHttpClient.execute(any(AsyncExecuteRequest.class))).thenAnswer(this::mockExecute);
65+
client = EventStreamRestJsonAsyncClient.builder()
66+
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
67+
.httpClient(mockHttpClient)
68+
.build();
69+
70+
marshalledEvents = new ArrayList<>();
71+
72+
chunkDecoder = new MessageDecoder();
73+
eventDecoder = new MessageDecoder();
74+
}
75+
76+
@Test
77+
public void testMarshalling_setsCorrectEventType() {
78+
List<InputEventStream> inputEvents = Stream.of(
79+
InputEventStream.inputEventBuilder().build(),
80+
InputEventStream.inputEventBBuilder().build(),
81+
InputEventStream.inputEventTwoBuilder().build()
82+
).collect(Collectors.toList());
83+
84+
Flowable<InputEventStream> inputStream = Flowable.fromIterable(inputEvents);
85+
86+
client.eventStreamOperation(EventStreamOperationRequest.builder().build(), inputStream, EventStreamOperationResponseHandler.builder()
87+
.subscriber(() -> new Subscriber<EventStream>() {
88+
@Override
89+
public void onSubscribe(Subscription subscription) {
90+
91+
}
92+
93+
@Override
94+
public void onNext(EventStream eventStream) {
95+
96+
}
97+
98+
@Override
99+
public void onError(Throwable throwable) {
100+
101+
}
102+
103+
@Override
104+
public void onComplete() {
105+
106+
}
107+
})
108+
.build()).join();
109+
110+
List<String> expectedTypes = Stream.of(
111+
"InputEvent",
112+
"InputEventB",
113+
"InputEventTwo"
114+
).collect(Collectors.toList());;
115+
116+
assertThat(marshalledEvents).hasSize(inputEvents.size());
117+
118+
for (int i = 0; i < marshalledEvents.size(); ++i) {
119+
Message marshalledEvent = marshalledEvents.get(i);
120+
String expectedType = expectedTypes.get(i);
121+
assertThat(marshalledEvent.getHeaders().get(":event-type").getString())
122+
.isEqualTo(expectedType);
123+
}
124+
}
125+
126+
private CompletableFuture<Void> mockExecute(InvocationOnMock invocation) {
127+
AsyncExecuteRequest request = invocation.getArgumentAt(0, AsyncExecuteRequest.class);
128+
SdkHttpContentPublisher content = request.requestContentPublisher();
129+
List<ByteBuffer> chunks = Flowable.fromPublisher(content).toList().blockingGet();
130+
131+
for (ByteBuffer c : chunks) {
132+
chunkDecoder.feed(c);
133+
}
134+
135+
for (Message m : chunkDecoder.getDecodedMessages()) {
136+
eventDecoder.feed(m.getPayload());
137+
}
138+
139+
marshalledEvents.addAll(eventDecoder.getDecodedMessages());
140+
141+
request.responseHandler().onHeaders(SdkHttpResponse.builder().statusCode(200).build());
142+
request.responseHandler().onStream(Flowable.empty());
143+
144+
return CompletableFuture.completedFuture(null);
145+
}
146+
}

0 commit comments

Comments
 (0)