From 84c6d6c86d0619f43673a0bb7e028dd2d444d947 Mon Sep 17 00:00:00 2001 From: George Fu Date: Wed, 21 May 2025 15:32:26 -0400 Subject: [PATCH 1/2] feat(core): add rpcv2 cbor runtime protocol --- .changeset/pretty-cows-call.md | 5 + packages/core/package.json | 1 + .../core/src/submodules/cbor/CborCodec.ts | 173 +++++++++++ .../cbor/SmithyRpcV2CborProtocol.spec.ts | 276 ++++++++++++++++++ .../cbor/SmithyRpcV2CborProtocol.ts | 116 ++++++++ packages/core/src/submodules/cbor/index.ts | 4 +- .../protocols/HttpBindingProtocol.spec.ts | 159 ++++++++++ .../protocols/HttpBindingProtocol.ts | 237 +++++++++++++++ .../src/submodules/protocols/HttpProtocol.ts | 240 +++++++++++++++ .../src/submodules/protocols/RpcProtocol.ts | 108 +++++++ .../core/src/submodules/protocols/index.ts | 7 + .../serde/FromStringShapeDeserializer.ts | 84 ++++++ .../HttpInterceptingShapeDeserializer.ts | 59 ++++ .../serde/HttpInterceptingShapeSerializer.ts | 50 ++++ .../serde/ToStringShapeSerializer.ts | 99 +++++++ .../serde/determineTimestampFormat.ts | 40 +++ .../serde/copyDocumentWithTransform.ts | 61 ++++ packages/core/src/submodules/serde/index.ts | 7 +- yarn.lock | 1 + 19 files changed, 1723 insertions(+), 4 deletions(-) create mode 100644 .changeset/pretty-cows-call.md create mode 100644 packages/core/src/submodules/cbor/CborCodec.ts create mode 100644 packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.spec.ts create mode 100644 packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.ts create mode 100644 packages/core/src/submodules/protocols/HttpBindingProtocol.spec.ts create mode 100644 packages/core/src/submodules/protocols/HttpBindingProtocol.ts create mode 100644 packages/core/src/submodules/protocols/HttpProtocol.ts create mode 100644 packages/core/src/submodules/protocols/RpcProtocol.ts create mode 100644 packages/core/src/submodules/protocols/serde/FromStringShapeDeserializer.ts create mode 100644 packages/core/src/submodules/protocols/serde/HttpInterceptingShapeDeserializer.ts create mode 100644 packages/core/src/submodules/protocols/serde/HttpInterceptingShapeSerializer.ts create mode 100644 packages/core/src/submodules/protocols/serde/ToStringShapeSerializer.ts create mode 100644 packages/core/src/submodules/protocols/serde/determineTimestampFormat.ts create mode 100644 packages/core/src/submodules/serde/copyDocumentWithTransform.ts diff --git a/.changeset/pretty-cows-call.md b/.changeset/pretty-cows-call.md new file mode 100644 index 00000000000..e4818105ba7 --- /dev/null +++ b/.changeset/pretty-cows-call.md @@ -0,0 +1,5 @@ +--- +"@smithy/core": minor +--- + +add cbor protocol (alpha) diff --git a/packages/core/package.json b/packages/core/package.json index 235633a1347..e88f61ce105 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -72,6 +72,7 @@ "@smithy/middleware-serde": "workspace:^", "@smithy/protocol-http": "workspace:^", "@smithy/types": "workspace:^", + "@smithy/util-base64": "workspace:^", "@smithy/util-body-length-browser": "workspace:^", "@smithy/util-middleware": "workspace:^", "@smithy/util-stream": "workspace:^", diff --git a/packages/core/src/submodules/cbor/CborCodec.ts b/packages/core/src/submodules/cbor/CborCodec.ts new file mode 100644 index 00000000000..e4ab36f57ac --- /dev/null +++ b/packages/core/src/submodules/cbor/CborCodec.ts @@ -0,0 +1,173 @@ +import { NormalizedSchema } from "@smithy/core/schema"; +import { copyDocumentWithTransform, parseEpochTimestamp } from "@smithy/core/serde"; +import { Codec, Schema, SchemaRef, SerdeFunctions, ShapeDeserializer, ShapeSerializer } from "@smithy/types"; + +import { cbor } from "./cbor"; +import { dateToTag } from "./parseCborBody"; + +/** + * @alpha + */ +export class CborCodec implements Codec { + private serdeContext?: SerdeFunctions; + + public createSerializer(): CborShapeSerializer { + const serializer = new CborShapeSerializer(); + serializer.setSerdeContext(this.serdeContext!); + return serializer; + } + + public createDeserializer(): CborShapeDeserializer { + const deserializer = new CborShapeDeserializer(); + deserializer.setSerdeContext(this.serdeContext!); + return deserializer; + } + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.serdeContext = serdeContext; + } +} + +/** + * @alpha + */ +export class CborShapeSerializer implements ShapeSerializer { + private serdeContext?: SerdeFunctions; + private value: unknown; + + public setSerdeContext(serdeContext: SerdeFunctions) { + this.serdeContext = serdeContext; + } + + public write(schema: Schema, value: unknown): void { + this.value = copyDocumentWithTransform(value, schema, (_: any, schemaRef: SchemaRef) => { + if (_ instanceof Date) { + return dateToTag(_); + } + if (_ instanceof Uint8Array) { + return _; + } + + const ns = NormalizedSchema.of(schemaRef); + const sparse = !!ns.getMergedTraits().sparse; + + if (Array.isArray(_)) { + if (!sparse) { + return _.filter((item) => item != null); + } + } else if (_ && typeof _ === "object") { + if (!sparse || ns.isStructSchema()) { + for (const [k, v] of Object.entries(_)) { + if (v == null) { + delete _[k]; + } + } + return _; + } + } + + return _; + }); + } + + public flush(): Uint8Array { + const buffer = cbor.serialize(this.value); + this.value = undefined; + return buffer as Uint8Array; + } +} + +/** + * @alpha + */ +export class CborShapeDeserializer implements ShapeDeserializer { + private serdeContext?: SerdeFunctions; + + public setSerdeContext(serdeContext: SerdeFunctions) { + this.serdeContext = serdeContext; + } + + public read(schema: Schema, bytes: Uint8Array): any { + const data: any = cbor.deserialize(bytes); + return this.readValue(schema, data); + } + + private readValue(_schema: Schema, value: any): any { + const ns = NormalizedSchema.of(_schema); + const schema = ns.getSchema(); + + if (typeof schema === "number") { + if (ns.isTimestampSchema()) { + // format is ignored. + return parseEpochTimestamp(value); + } + if (ns.isBlobSchema()) { + return value; + } + } + + if ( + typeof value === "undefined" || + typeof value === "boolean" || + typeof value === "number" || + typeof value === "string" || + typeof value === "bigint" || + typeof value === "symbol" + ) { + return value; + } else if (typeof value === "function" || typeof value === "object") { + if (value === null) { + return null; + } + if ("byteLength" in (value as Uint8Array)) { + return value; + } + if (value instanceof Date) { + return value; + } + if (ns.isDocumentSchema()) { + return value; + } + + if (ns.isListSchema()) { + const newArray = []; + const memberSchema = ns.getValueSchema(); + const sparse = ns.isListSchema() && !!ns.getMergedTraits().sparse; + + for (const item of value) { + newArray.push(this.readValue(memberSchema, item)); + if (!sparse && newArray[newArray.length - 1] == null) { + newArray.pop(); + } + } + return newArray; + } + + const newObject = {} as any; + + if (ns.isMapSchema()) { + const sparse = ns.getMergedTraits().sparse; + const targetSchema = ns.getValueSchema(); + + for (const key of Object.keys(value)) { + newObject[key] = this.readValue(targetSchema, value[key]); + + if (newObject[key] == null && !sparse) { + delete newObject[key]; + } + } + } else if (ns.isStructSchema()) { + for (const key of Object.keys(value)) { + const targetSchema = ns.getMemberSchema(key); + if (targetSchema === undefined) { + continue; + } + newObject[key] = this.readValue(targetSchema, value[key]); + } + } + return newObject; + } else { + return value; + } + } +} diff --git a/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.spec.ts b/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.spec.ts new file mode 100644 index 00000000000..ca6bb407f5f --- /dev/null +++ b/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.spec.ts @@ -0,0 +1,276 @@ +import { list, map, SCHEMA, struct } from "@smithy/core/schema"; +import { HttpRequest, HttpResponse } from "@smithy/protocol-http"; +import { SchemaRef } from "@smithy/types"; +import { describe, expect, test as it } from "vitest"; + +import { cbor } from "./cbor"; +import { dateToTag } from "./parseCborBody"; +import { SmithyRpcV2CborProtocol } from "./SmithyRpcV2CborProtocol"; + +describe(SmithyRpcV2CborProtocol.name, () => { + const bytes = (arr: number[]) => Buffer.from(arr); + + describe("serialization", () => { + const testCases: Array<{ + name: string; + schema: SchemaRef; + input: any; + expected: { + request: any; + body: any; + }; + }> = [ + { + name: "document with timestamp and blob", + schema: struct( + "", + "MyExtendedDocument", + {}, + ["timestamp", "blob"], + [ + [SCHEMA.TIMESTAMP_DEFAULT, 0], + [SCHEMA.BLOB, 0], + ] + ), + input: { + bool: true, + int: 5, + float: -3.001, + timestamp: new Date(1_000_000), + blob: bytes([97, 98, 99, 100]), + }, + expected: { + request: {}, + body: { + timestamp: dateToTag(new Date(1_000_000)), + blob: bytes([97, 98, 99, 100]), + }, + }, + }, + { + name: "do not write to header or query", + schema: struct( + "", + "MyExtendedDocument", + {}, + ["bool", "timestamp", "blob", "prefixHeaders", "searchParams"], + [ + [SCHEMA.BOOLEAN, { httpQuery: "bool" }], + [SCHEMA.TIMESTAMP_DEFAULT, { httpHeader: "timestamp" }], + [SCHEMA.BLOB, { httpHeader: "blob" }], + [SCHEMA.MAP_MODIFIER | SCHEMA.STRING, { httpPrefixHeaders: "anti-" }], + [SCHEMA.MAP_MODIFIER | SCHEMA.STRING, { httpQueryParams: 1 }], + ] + ), + input: { + bool: true, + timestamp: new Date(1_000_000), + blob: bytes([97, 98, 99, 100]), + prefixHeaders: { + pasto: "cheese dodecahedron", + clockwise: "left", + }, + searchParams: { + a: 1, + b: 2, + }, + }, + expected: { + request: { + headers: {}, + query: {}, + }, + body: { + bool: true, + timestamp: dateToTag(new Date(1_000_000)), + blob: bytes([97, 98, 99, 100]), + prefixHeaders: { + pasto: "cheese dodecahedron", + clockwise: "left", + }, + searchParams: { + a: 1, + b: 2, + }, + }, + }, + }, + { + name: "sparse list and map", + schema: struct( + "", + "MyShape", + 0, + ["mySparseList", "myRegularList", "mySparseMap", "myRegularMap"], + [ + [() => list("", "MyList", { sparse: 1 }, SCHEMA.NUMERIC), {}], + [() => list("", "MyList", {}, SCHEMA.NUMERIC), {}], + [() => map("", "MyMap", { sparse: 1 }, SCHEMA.STRING, SCHEMA.NUMERIC), {}], + [() => map("", "MyMap", {}, SCHEMA.STRING, SCHEMA.NUMERIC), {}], + ] + ), + input: { + mySparseList: [null, 1, null, 2, null], + myRegularList: [null, 1, null, 2, null], + mySparseMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + myRegularMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + }, + expected: { + request: {}, + body: { + mySparseList: [null, 1, null, 2, null], + myRegularList: [1, 2], + mySparseMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + myRegularMap: { + 1: 1, + 3: 3, + }, + }, + }, + }, + ]; + + for (const testCase of testCases) { + it(`should serialize HTTP Requests: ${testCase.name}`, async () => { + const protocol = new SmithyRpcV2CborProtocol({ defaultNamespace: "" }); + const httpRequest = await protocol.serializeRequest( + { + name: "dummy", + input: testCase.schema, + output: void 0, + traits: {}, + }, + testCase.input, + { + async endpoint() { + return { + protocol: "https:", + hostname: "example.com", + path: "/", + }; + }, + } as any + ); + + const body = httpRequest.body; + httpRequest.body = void 0; + + expect(httpRequest).toEqual( + new HttpRequest({ + protocol: "https:", + hostname: "example.com", + method: "POST", + path: "/service/undefined/operation/undefined", + ...testCase.expected.request, + headers: { + accept: "application/cbor", + "content-type": "application/cbor", + "smithy-protocol": "rpc-v2-cbor", + "content-length": String(body.byteLength), + ...testCase.expected.request.headers, + }, + }) + ); + + expect(cbor.deserialize(body)).toEqual(testCase.expected.body); + }); + } + }); + + describe("deserialization", () => { + const testCases = [ + { + name: "sparse list and map", + schema: struct( + "", + "MyShape", + 0, + ["mySparseList", "myRegularList", "mySparseMap", "myRegularMap"], + [ + [() => list("", "MyList", { sparse: 1 }, SCHEMA.NUMERIC), {}], + [() => list("", "MyList", {}, SCHEMA.NUMERIC), {}], + [() => map("", "MyMap", { sparse: 1 }, SCHEMA.STRING, SCHEMA.NUMERIC), {}], + [() => map("", "MyMap", {}, SCHEMA.STRING, SCHEMA.NUMERIC), {}], + ] + ), + mockOutput: { + mySparseList: [null, 1, null, 2, null], + myRegularList: [null, 1, null, 2, null], + mySparseMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + myRegularMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + }, + expected: { + output: { + mySparseList: [null, 1, null, 2, null], + myRegularList: [1, 2], + mySparseMap: { + 0: null, + 1: 1, + 2: null, + 3: 3, + 4: null, + }, + myRegularMap: { + 1: 1, + 3: 3, + }, + }, + }, + }, + ]; + + for (const testCase of testCases) { + it(`should deserialize HTTP Responses: ${testCase.name}`, async () => { + const protocol = new SmithyRpcV2CborProtocol({ + defaultNamespace: "", + }); + const output = await protocol.deserializeResponse( + { + name: "dummy", + input: void 0, + output: testCase.schema, + traits: {}, + }, + {} as any, + new HttpResponse({ + statusCode: 200, + body: cbor.serialize(testCase.mockOutput), + }) + ); + + delete (output as Partial).$metadata; + expect(output).toEqual(testCase.expected.output); + }); + } + }); +}); diff --git a/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.ts b/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.ts new file mode 100644 index 00000000000..ab7b03c521e --- /dev/null +++ b/packages/core/src/submodules/cbor/SmithyRpcV2CborProtocol.ts @@ -0,0 +1,116 @@ +import { RpcProtocol } from "@smithy/core/protocols"; +import { deref, ErrorSchema, OperationSchema, TypeRegistry } from "@smithy/core/schema"; +import type { + EndpointBearer, + HandlerExecutionContext, + HttpRequest as IHttpRequest, + HttpResponse as IHttpResponse, + MetadataBearer, + ResponseMetadata, + SerdeFunctions, +} from "@smithy/types"; +import { getSmithyContext } from "@smithy/util-middleware"; + +import { CborCodec } from "./CborCodec"; +import { loadSmithyRpcV2CborErrorCode } from "./parseCborBody"; + +/** + * Client protocol for Smithy RPCv2 CBOR. + * + * @alpha + */ +export class SmithyRpcV2CborProtocol extends RpcProtocol { + private codec = new CborCodec(); + protected serializer = this.codec.createSerializer(); + protected deserializer = this.codec.createDeserializer(); + + public constructor({ defaultNamespace }: { defaultNamespace: string }) { + super({ defaultNamespace }); + } + + public getShapeId(): string { + return "smithy.protocols#rpcv2Cbor"; + } + + public getPayloadCodec(): CborCodec { + return this.codec; + } + + public async serializeRequest( + operationSchema: OperationSchema, + input: Input, + context: HandlerExecutionContext & SerdeFunctions & EndpointBearer + ): Promise { + const request = await super.serializeRequest(operationSchema, input, context); + Object.assign(request.headers, { + "content-type": "application/cbor", + "smithy-protocol": "rpc-v2-cbor", + accept: "application/cbor", + }); + if (deref(operationSchema.input) === "unit") { + delete request.body; + delete request.headers["content-type"]; + } else { + if (!request.body) { + this.serializer.write(15, {}); + request.body = this.serializer.flush(); + } + try { + request.headers["content-length"] = String((request.body as Uint8Array).byteLength); + } catch (e) {} + } + const { service, operation } = getSmithyContext(context) as { + service: string; + operation: string; + }; + const path = `/service/${service}/operation/${operation}`; + if (request.path.endsWith("/")) { + request.path += path.slice(1); + } else { + request.path += path; + } + return request; + } + + public async deserializeResponse( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse + ): Promise { + return super.deserializeResponse(operationSchema, context, response); + } + + protected async handleError( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse, + dataObject: any, + metadata: ResponseMetadata + ): Promise { + const error = loadSmithyRpcV2CborErrorCode(response, dataObject) ?? "Unknown"; + + let namespace = this.options.defaultNamespace; + if (error.includes("#")) { + [namespace] = error.split("#"); + } + + const registry = TypeRegistry.for(namespace); + const errorSchema: ErrorSchema = registry.getSchema(error) as ErrorSchema; + + if (!errorSchema) { + // TODO(schema) throw client base exception using the dataObject. + throw new Error("schema not found for " + error); + } + + const message = dataObject.message ?? dataObject.Message ?? "Unknown"; + const exception = new errorSchema.ctor(message); + Object.assign(exception, { + $metadata: metadata, + $response: response, + message, + ...dataObject, + }); + + throw exception; + } +} diff --git a/packages/core/src/submodules/cbor/index.ts b/packages/core/src/submodules/cbor/index.ts index 0910d274e31..c53524e3a48 100644 --- a/packages/core/src/submodules/cbor/index.ts +++ b/packages/core/src/submodules/cbor/index.ts @@ -1,3 +1,5 @@ export { cbor } from "./cbor"; +export { tag, tagSymbol } from "./cbor-types"; export * from "./parseCborBody"; -export { tagSymbol, tag } from "./cbor-types"; +export * from "./SmithyRpcV2CborProtocol"; +export * from "./CborCodec"; diff --git a/packages/core/src/submodules/protocols/HttpBindingProtocol.spec.ts b/packages/core/src/submodules/protocols/HttpBindingProtocol.spec.ts new file mode 100644 index 00000000000..3a815c00414 --- /dev/null +++ b/packages/core/src/submodules/protocols/HttpBindingProtocol.spec.ts @@ -0,0 +1,159 @@ +import { op, SCHEMA, struct } from "@smithy/core/schema"; +import { HttpResponse } from "@smithy/protocol-http"; +import { + Codec, + CodecSettings, + HandlerExecutionContext, + HttpResponse as IHttpResponse, + OperationSchema, + ResponseMetadata, + ShapeDeserializer, + ShapeSerializer, +} from "@smithy/types"; +import { parseUrl } from "@smithy/url-parser/src"; +import { describe, expect, test as it } from "vitest"; + +import { HttpBindingProtocol } from "./HttpBindingProtocol"; +import { FromStringShapeDeserializer } from "./serde/FromStringShapeDeserializer"; +import { ToStringShapeSerializer } from "./serde/ToStringShapeSerializer"; + +describe(HttpBindingProtocol.name, () => { + class StringRestProtocol extends HttpBindingProtocol { + protected serializer: ShapeSerializer; + protected deserializer: ShapeDeserializer; + + public constructor() { + super({ + defaultNamespace: "", + }); + const settings: CodecSettings = { + timestampFormat: { + useTrait: true, + default: SCHEMA.TIMESTAMP_EPOCH_SECONDS, + }, + httpBindings: true, + }; + this.serializer = new ToStringShapeSerializer(settings); + this.deserializer = new FromStringShapeDeserializer(settings); + } + + public getShapeId(): string { + throw new Error("Method not implemented."); + } + public getPayloadCodec(): Codec { + throw new Error("Method not implemented."); + } + protected handleError( + operationSchema: OperationSchema, + context: HandlerExecutionContext, + response: IHttpResponse, + dataObject: any, + metadata: ResponseMetadata + ): Promise { + void [operationSchema, context, response, dataObject, metadata]; + throw new Error("Method not implemented."); + } + } + + it("should deserialize timestamp list with unescaped commas", async () => { + const response = new HttpResponse({ + statusCode: 200, + headers: { + "x-timestamplist": "Mon, 16 Dec 2019 23:48:18 GMT, Mon, 16 Dec 2019 23:48:18 GMT", + }, + }); + + const protocol = new StringRestProtocol(); + const output = await protocol.deserializeResponse( + op( + "", + "", + 0, + "unit", + struct( + "", + "", + 0, + ["timestampList"], + [ + [ + SCHEMA.LIST_MODIFIER | SCHEMA.TIMESTAMP_DEFAULT, + { + httpHeader: "x-timestamplist", + }, + ], + ] + ) + ), + {} as any, + response + ); + delete output.$metadata; + expect(output).toEqual({ + timestampList: [new Date("2019-12-16T23:48:18.000Z"), new Date("2019-12-16T23:48:18.000Z")], + }); + }); + + it("should deserialize all headers when httpPrefixHeaders value is empty string", async () => { + const response = new HttpResponse({ + statusCode: 200, + headers: { + "x-tents": "tents", + hello: "Hello", + }, + }); + + const protocol = new StringRestProtocol(); + const output = await protocol.deserializeResponse( + op( + "", + "", + 0, + "unit", + struct( + "", + "", + 0, + ["httpPrefixHeaders"], + [ + [ + SCHEMA.MAP_MODIFIER | SCHEMA.STRING, + { + httpPrefixHeaders: "", + }, + ], + ] + ) + ), + {} as any, + response + ); + delete output.$metadata; + expect(output).toEqual({ + httpPrefixHeaders: { + "x-tents": "tents", + hello: "Hello", + }, + }); + }); + + it("should serialize custom paths in context-provided endpoint", async () => { + const protocol = new StringRestProtocol(); + const request = await protocol.serializeRequest( + op( + "", + "", + { + http: ["GET", "/Operation", 200], + }, + "unit", + "unit" + ), + {}, + { + endpoint: async () => parseUrl("https://localhost/custom"), + } as any + ); + expect(request.path).toEqual("/custom/Operation"); + }); +}); diff --git a/packages/core/src/submodules/protocols/HttpBindingProtocol.ts b/packages/core/src/submodules/protocols/HttpBindingProtocol.ts new file mode 100644 index 00000000000..14f49e7ba0c --- /dev/null +++ b/packages/core/src/submodules/protocols/HttpBindingProtocol.ts @@ -0,0 +1,237 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { HttpRequest } from "@smithy/protocol-http"; +import { + Endpoint, + EndpointBearer, + HandlerExecutionContext, + HttpRequest as IHttpRequest, + HttpResponse as IHttpResponse, + MetadataBearer, + OperationSchema, + SerdeFunctions, +} from "@smithy/types"; + +import { collectBody } from "./collect-stream-body"; +import { extendedEncodeURIComponent } from "./extended-encode-uri-component"; +import { HttpProtocol } from "./HttpProtocol"; + +/** + * Base for HTTP-binding protocols. Downstream examples + * include AWS REST JSON and AWS REST XML. + * + * @alpha + */ +export abstract class HttpBindingProtocol extends HttpProtocol { + public async serializeRequest( + operationSchema: OperationSchema, + input: Input, + context: HandlerExecutionContext & SerdeFunctions & EndpointBearer + ): Promise { + const serializer = this.serializer; + const query = {} as Record; + const headers = {} as Record; + const endpoint: Endpoint = await context.endpoint(); + + const ns = NormalizedSchema.of(operationSchema?.input); + const schema = ns.getSchema(); + + let hasNonHttpBindingMember = false; + let payload: any; + + const request = new HttpRequest({ + protocol: "", + hostname: "", + port: undefined, + path: "", + fragment: undefined, + query: query, + headers: headers, + body: undefined, + }); + + if (endpoint) { + this.updateServiceEndpoint(request, endpoint); + this.setHostPrefix(request, operationSchema, input); + const opTraits = NormalizedSchema.translateTraits(operationSchema.traits); + if (opTraits.http) { + request.method = opTraits.http[0]; + const [path, search] = opTraits.http[1].split("?"); + if (request.path == "/") { + request.path = path; + } else { + request.path += path; + } + const traitSearchParams = new URLSearchParams(search ?? ""); + Object.assign(query, Object.fromEntries(traitSearchParams)); + } + } + + const _input: any = { + ...input, + }; + + for (const memberName of Object.keys(_input)) { + const memberNs = ns.getMemberSchema(memberName); + if (memberNs === undefined) { + continue; + } + const memberTraits = memberNs.getMergedTraits(); + const inputMember = (_input as any)[memberName] as any; + + if (memberTraits.httpPayload) { + const isStreaming = memberNs.isStreaming(); + if (isStreaming) { + const isEventStream = memberNs.isStructSchema(); + if (isEventStream) { + // todo(schema) + throw new Error("serialization of event streams is not yet implemented"); + } else { + // streaming blob body + payload = inputMember; + } + } else { + // structural/document body + serializer.write(memberNs, inputMember); + payload = serializer.flush(); + } + } else if (memberTraits.httpLabel) { + serializer.write(memberNs, inputMember); + const replacement = serializer.flush() as string; + if (request.path.includes(`{${memberName}+}`)) { + request.path = request.path.replace( + `{${memberName}+}`, + replacement.split("/").map(extendedEncodeURIComponent).join("/") + ); + } else if (request.path.includes(`{${memberName}}`)) { + request.path = request.path.replace(`{${memberName}}`, extendedEncodeURIComponent(replacement)); + } + delete _input[memberName]; + } else if (memberTraits.httpHeader) { + serializer.write(memberNs, inputMember); + headers[memberTraits.httpHeader.toLowerCase() as string] = String(serializer.flush()); + delete _input[memberName]; + } else if (typeof memberTraits.httpPrefixHeaders === "string") { + for (const [key, val] of Object.entries(inputMember)) { + const amalgam = memberTraits.httpPrefixHeaders + key; + serializer.write([memberNs.getValueSchema(), { httpHeader: amalgam }], val); + headers[amalgam.toLowerCase()] = serializer.flush() as string; + } + delete _input[memberName]; + } else if (memberTraits.httpQuery || memberTraits.httpQueryParams) { + this.serializeQuery(memberNs, inputMember, query); + delete _input[memberName]; + } else { + hasNonHttpBindingMember = true; + } + } + + if (hasNonHttpBindingMember && input) { + serializer.write(schema, _input); + payload = serializer.flush() as Uint8Array; + } + + request.headers = headers; + request.query = query; + request.body = payload; + + return request; + } + + protected serializeQuery(ns: NormalizedSchema, data: any, query: HttpRequest["query"]) { + const serializer = this.serializer; + const traits = ns.getMergedTraits(); + + if (traits.httpQueryParams) { + for (const [key, val] of Object.entries(data)) { + if (!(key in query)) { + this.serializeQuery( + NormalizedSchema.of([ + ns.getValueSchema(), + { + // We pass on the traits to the sub-schema + // because we are still in the process of serializing the map itself. + ...traits, + httpQuery: key, + httpQueryParams: undefined, + }, + ]), + val, + query + ); + } + } + return; + } + + if (ns.isListSchema()) { + const sparse = !!ns.getMergedTraits().sparse; + const buffer = []; + for (const item of data) { + // We pass on the traits to the sub-schema + // because we are still in the process of serializing the list itself. + serializer.write([ns.getValueSchema(), traits], item); + const serializable = serializer.flush() as string; + if (sparse || serializable !== undefined) { + buffer.push(serializable); + } + } + query[traits.httpQuery as string] = buffer; + } else { + serializer.write([ns, traits], data); + query[traits.httpQuery as string] = serializer.flush() as string; + } + } + + public async deserializeResponse( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse + ): Promise { + const deserializer = this.deserializer; + const ns = NormalizedSchema.of(operationSchema.output); + + const dataObject: any = {}; + + if (response.statusCode >= 300) { + const bytes: Uint8Array = await collectBody(response.body, context); + if (bytes.byteLength > 0) { + Object.assign(dataObject, await deserializer.read(SCHEMA.DOCUMENT, bytes)); + } + await this.handleError(operationSchema, context, response, dataObject, this.deserializeMetadata(response)); + throw new Error("@smithy/core/protocols - HTTP Protocol error handler failed to throw."); + } + + for (const header in response.headers) { + const value = response.headers[header]; + delete response.headers[header]; + response.headers[header.toLowerCase()] = value; + } + + const headerBindings = new Set( + Object.values(ns.getMemberSchemas()) + .map((schema) => { + return schema.getMergedTraits().httpHeader; + }) + .filter(Boolean) as string[] + ); + + const nonHttpBindingMembers = await this.deserializeHttpMessage(ns, context, response, headerBindings, dataObject); + + if (nonHttpBindingMembers.length) { + const bytes: Uint8Array = await collectBody(response.body, context); + if (bytes.byteLength > 0) { + const dataFromBody = await deserializer.read(ns, bytes); + for (const member of nonHttpBindingMembers) { + dataObject[member] = dataFromBody[member]; + } + } + } + + const output: Output = { + $metadata: this.deserializeMetadata(response), + ...dataObject, + }; + + return output; + } +} diff --git a/packages/core/src/submodules/protocols/HttpProtocol.ts b/packages/core/src/submodules/protocols/HttpProtocol.ts new file mode 100644 index 00000000000..c8ae7307592 --- /dev/null +++ b/packages/core/src/submodules/protocols/HttpProtocol.ts @@ -0,0 +1,240 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { splitEvery, splitHeader } from "@smithy/core/serde"; +import { HttpRequest, HttpResponse } from "@smithy/protocol-http"; +import { + ClientProtocol, + Codec, + Endpoint, + EndpointBearer, + EndpointV2, + EventStreamSerdeContext, + HandlerExecutionContext, + HttpRequest as IHttpRequest, + HttpResponse as IHttpResponse, + MetadataBearer, + OperationSchema, + ResponseMetadata, + Schema, + SerdeFunctions, + ShapeDeserializer, + ShapeSerializer, +} from "@smithy/types"; +import { sdkStreamMixin } from "@smithy/util-stream"; + +import { collectBody } from "./collect-stream-body"; + +/** + * Abstract base for HTTP-based client protocols. + * + * @alpha + */ +export abstract class HttpProtocol implements ClientProtocol { + protected abstract serializer: ShapeSerializer; + protected abstract deserializer: ShapeDeserializer; + protected serdeContext?: SerdeFunctions; + + protected constructor( + public readonly options: { + defaultNamespace: string; + } + ) {} + + public abstract getShapeId(): string; + + public abstract getPayloadCodec(): Codec; + + public getRequestType(): new (...args: any[]) => IHttpRequest { + return HttpRequest; + } + + public getResponseType(): new (...args: any[]) => IHttpResponse { + return HttpResponse; + } + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.serdeContext = serdeContext; + this.serializer.setSerdeContext(serdeContext); + this.deserializer.setSerdeContext(serdeContext); + if (this.getPayloadCodec()) { + this.getPayloadCodec().setSerdeContext(serdeContext); + } + } + + public abstract serializeRequest( + operationSchema: OperationSchema, + input: Input, + context: HandlerExecutionContext & SerdeFunctions & EndpointBearer + ): Promise; + + public updateServiceEndpoint(request: IHttpRequest, endpoint: EndpointV2 | Endpoint) { + if ("url" in endpoint) { + request.protocol = endpoint.url.protocol; + request.hostname = endpoint.url.hostname; + request.port = endpoint.url.port ? Number(endpoint.url.port) : undefined; + request.path = endpoint.url.pathname; + request.fragment = endpoint.url.hash || void 0; + request.username = endpoint.url.username || void 0; + request.password = endpoint.url.password || void 0; + for (const [k, v] of endpoint.url.searchParams.entries()) { + if (!request.query) { + request.query = {}; + } + request.query[k] = v; + } + return request; + } else { + request.protocol = endpoint.protocol; + request.hostname = endpoint.hostname; + request.port = endpoint.port ? Number(endpoint.port) : undefined; + request.path = endpoint.path; + request.query = { + ...endpoint.query, + }; + return request; + } + } + + public abstract deserializeResponse( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse + ): Promise; + + protected setHostPrefix( + request: IHttpRequest, + operationSchema: OperationSchema, + input: Input + ): void { + const operationNs = NormalizedSchema.of(operationSchema); + const inputNs = NormalizedSchema.of(operationSchema.input); + if (operationNs.getMergedTraits().endpoint) { + let hostPrefix = operationNs.getMergedTraits().endpoint?.[0]; + if (typeof hostPrefix === "string") { + const hostLabelInputs = Object.entries(inputNs.getMemberSchemas()).filter( + ([, member]) => member.getMergedTraits().hostLabel + ); + for (const [name] of hostLabelInputs) { + const replacement = input[name as keyof typeof input]; + if (typeof replacement !== "string") { + throw new Error(`@smithy/core/schema - ${name} in input must be a string as hostLabel.`); + } + hostPrefix = hostPrefix.replace(`{${name}}`, replacement); + } + request.hostname = hostPrefix + request.hostname; + } + } + } + + protected abstract handleError( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse, + dataObject: any, + metadata: ResponseMetadata + ): Promise; + + protected deserializeMetadata(output: IHttpResponse): ResponseMetadata { + return { + httpStatusCode: output.statusCode, + requestId: + output.headers["x-amzn-requestid"] ?? output.headers["x-amzn-request-id"] ?? output.headers["x-amz-request-id"], + extendedRequestId: output.headers["x-amz-id-2"], + cfId: output.headers["x-amz-cf-id"], + }; + } + + protected async deserializeHttpMessage( + schema: Schema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse, + headerBindings: Set, + dataObject: any + ): Promise { + const deserializer = this.deserializer; + const ns = NormalizedSchema.of(schema); + const nonHttpBindingMembers = [] as string[]; + + for (const [memberName, memberSchema] of Object.entries(ns.getMemberSchemas())) { + const memberTraits = memberSchema.getMemberTraits(); + + if (memberTraits.httpPayload) { + const isStreaming = memberSchema.isStreaming(); + if (isStreaming) { + const isEventStream = memberSchema.isStructSchema(); + if (isEventStream) { + // streaming event stream (union) + const context = this.serdeContext as unknown as EventStreamSerdeContext; + if (!context.eventStreamMarshaller) { + throw new Error("@smithy/core - HttpProtocol: eventStreamMarshaller missing in serdeContext."); + } + const memberSchemas = memberSchema.getMemberSchemas(); + dataObject[memberName] = context.eventStreamMarshaller.deserialize(response.body, async (event) => { + const unionMember = + Object.keys(event).find((key) => { + return key !== "__type"; + }) ?? ""; + if (unionMember in memberSchemas) { + const eventStreamSchema = memberSchemas[unionMember]; + return { + [unionMember]: await deserializer.read(eventStreamSchema, event[unionMember].body), + }; + } else { + // this union convention is ignored by the event stream marshaller. + return { + $unknown: event, + }; + } + }); + } else { + // streaming blob body + dataObject[memberName] = sdkStreamMixin(response.body); + } + } else if (response.body) { + const bytes: Uint8Array = await collectBody(response.body, context as SerdeFunctions); + if (bytes.byteLength > 0) { + dataObject[memberName] = await deserializer.read(memberSchema, bytes); + } + } + } else if (memberTraits.httpHeader) { + const key = String(memberTraits.httpHeader).toLowerCase(); + const value = response.headers[key]; + if (null != value) { + if (memberSchema.isListSchema()) { + const headerListValueSchema = memberSchema.getValueSchema(); + let sections: string[]; + if ( + headerListValueSchema.isTimestampSchema() && + headerListValueSchema.getSchema() === SCHEMA.TIMESTAMP_DEFAULT + ) { + sections = splitEvery(value, ",", 2); + } else { + sections = splitHeader(value); + } + const list = []; + for (const section of sections) { + list.push(await deserializer.read([headerListValueSchema, { httpHeader: key }], section.trim())); + } + dataObject[memberName] = list; + } else { + dataObject[memberName] = await deserializer.read(memberSchema, value); + } + } + } else if (memberTraits.httpPrefixHeaders !== undefined) { + dataObject[memberName] = {}; + for (const [header, value] of Object.entries(response.headers)) { + if (!headerBindings.has(header) && header.startsWith(memberTraits.httpPrefixHeaders)) { + dataObject[memberName][header.slice(memberTraits.httpPrefixHeaders.length)] = await deserializer.read( + [memberSchema.getValueSchema(), { httpHeader: header }], + value + ); + } + } + } else if (memberTraits.httpResponseCode) { + dataObject[memberName] = response.statusCode; + } else { + nonHttpBindingMembers.push(memberName); + } + } + return nonHttpBindingMembers; + } +} diff --git a/packages/core/src/submodules/protocols/RpcProtocol.ts b/packages/core/src/submodules/protocols/RpcProtocol.ts new file mode 100644 index 00000000000..8470bcf4ccf --- /dev/null +++ b/packages/core/src/submodules/protocols/RpcProtocol.ts @@ -0,0 +1,108 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { HttpRequest } from "@smithy/protocol-http"; +import { + Endpoint, + EndpointBearer, + HandlerExecutionContext, + HttpRequest as IHttpRequest, + HttpResponse as IHttpResponse, + MetadataBearer, + OperationSchema, + SerdeFunctions, +} from "@smithy/types"; + +import { collectBody } from "./collect-stream-body"; +import { HttpProtocol } from "./HttpProtocol"; + +/** + * Abstract base for RPC-over-HTTP protocols. + * + * @alpha + */ +export abstract class RpcProtocol extends HttpProtocol { + public async serializeRequest( + operationSchema: OperationSchema, + input: Input, + context: HandlerExecutionContext & SerdeFunctions & EndpointBearer + ): Promise { + const serializer = this.serializer; + const query = {} as Record; + const headers = {} as Record; + const endpoint: Endpoint = await context.endpoint(); + + const ns = NormalizedSchema.of(operationSchema?.input); + const schema = ns.getSchema(); + + let payload: any; + + const request = new HttpRequest({ + protocol: "", + hostname: "", + port: undefined, + path: "/", + fragment: undefined, + query: query, + headers: headers, + body: undefined, + }); + + if (endpoint) { + this.updateServiceEndpoint(request, endpoint); + this.setHostPrefix(request, operationSchema, input); + } + + const _input: any = { + ...input, + }; + + if (input) { + serializer.write(schema, _input); + payload = serializer.flush() as Uint8Array; + } + + request.headers = headers; + request.query = query; + request.body = payload; + request.method = "POST"; + + return request; + } + + public async deserializeResponse( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: IHttpResponse + ): Promise { + const deserializer = this.deserializer; + const ns = NormalizedSchema.of(operationSchema.output); + + const dataObject: any = {}; + + if (response.statusCode >= 300) { + const bytes: Uint8Array = await collectBody(response.body, context as SerdeFunctions); + if (bytes.byteLength > 0) { + Object.assign(dataObject, await deserializer.read(SCHEMA.DOCUMENT, bytes)); + } + await this.handleError(operationSchema, context, response, dataObject, this.deserializeMetadata(response)); + throw new Error("@smithy/core/protocols - RPC Protocol error handler failed to throw."); + } + + for (const header in response.headers) { + const value = response.headers[header]; + delete response.headers[header]; + response.headers[header.toLowerCase()] = value; + } + + const bytes: Uint8Array = await collectBody(response.body, context as SerdeFunctions); + if (bytes.byteLength > 0) { + Object.assign(dataObject, await deserializer.read(ns, bytes)); + } + + const output: Output = { + $metadata: this.deserializeMetadata(response), + ...dataObject, + }; + + return output; + } +} diff --git a/packages/core/src/submodules/protocols/index.ts b/packages/core/src/submodules/protocols/index.ts index a5de22f1a4a..33a857e25ea 100644 --- a/packages/core/src/submodules/protocols/index.ts +++ b/packages/core/src/submodules/protocols/index.ts @@ -1,4 +1,11 @@ export * from "./collect-stream-body"; export * from "./extended-encode-uri-component"; +export * from "./HttpBindingProtocol"; +export * from "./RpcProtocol"; export * from "./requestBuilder"; export * from "./resolve-path"; +export * from "./serde/FromStringShapeDeserializer"; +export * from "./serde/HttpInterceptingShapeDeserializer"; +export * from "./serde/HttpInterceptingShapeSerializer"; +export * from "./serde/ToStringShapeSerializer"; +export * from "./serde/determineTimestampFormat"; diff --git a/packages/core/src/submodules/protocols/serde/FromStringShapeDeserializer.ts b/packages/core/src/submodules/protocols/serde/FromStringShapeDeserializer.ts new file mode 100644 index 00000000000..c29133232b0 --- /dev/null +++ b/packages/core/src/submodules/protocols/serde/FromStringShapeDeserializer.ts @@ -0,0 +1,84 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { + LazyJsonString, + NumericValue, + parseEpochTimestamp, + parseRfc3339DateTimeWithOffset, + parseRfc7231DateTime, + splitHeader, +} from "@smithy/core/serde"; +import { CodecSettings, Schema, SerdeFunctions, ShapeDeserializer } from "@smithy/types"; +import { fromBase64 } from "@smithy/util-base64"; +import { toUtf8 } from "@smithy/util-utf8"; + +import { determineTimestampFormat } from "./determineTimestampFormat"; + +/** + * This deserializer reads strings. + * + * @alpha + */ +export class FromStringShapeDeserializer implements ShapeDeserializer { + private serdeContext: SerdeFunctions | undefined; + + public constructor(private settings: CodecSettings) {} + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.serdeContext = serdeContext; + } + + public read(_schema: Schema, data: string): any { + const ns = NormalizedSchema.of(_schema); + if (ns.isListSchema()) { + return splitHeader(data).map((item) => this.read(ns.getValueSchema(), item)); + } + if (ns.isBlobSchema()) { + return (this.serdeContext?.base64Decoder ?? fromBase64)(data); + } + if (ns.isTimestampSchema()) { + const format = determineTimestampFormat(ns, this.settings); + switch (format) { + case SCHEMA.TIMESTAMP_DATE_TIME: + return parseRfc3339DateTimeWithOffset(data); + case SCHEMA.TIMESTAMP_HTTP_DATE: + return parseRfc7231DateTime(data); + case SCHEMA.TIMESTAMP_EPOCH_SECONDS: + return parseEpochTimestamp(data); + default: + console.warn("Missing timestamp format, parsing value with Date constructor:", data); + return new Date(data as string | number); + } + } + + if (ns.isStringSchema()) { + const mediaType = ns.getMergedTraits().mediaType; + let intermediateValue: string | LazyJsonString = data; + if (mediaType) { + if (ns.getMergedTraits().httpHeader) { + intermediateValue = this.base64ToUtf8(intermediateValue); + } + const isJson = mediaType === "application/json" || mediaType.endsWith("+json"); + if (isJson) { + intermediateValue = LazyJsonString.from(intermediateValue); + } + return intermediateValue; + } + } + + switch (true) { + case ns.isNumericSchema(): + return Number(data); + case ns.isBigIntegerSchema(): + return BigInt(data); + case ns.isBigDecimalSchema(): + return new NumericValue(data, "bigDecimal"); + case ns.isBooleanSchema(): + return String(data).toLowerCase() === "true"; + } + return data; + } + + private base64ToUtf8(base64String: string): any { + return (this.serdeContext?.utf8Encoder ?? toUtf8)((this.serdeContext?.base64Decoder ?? fromBase64)(base64String)); + } +} diff --git a/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeDeserializer.ts b/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeDeserializer.ts new file mode 100644 index 00000000000..96e45b69e9a --- /dev/null +++ b/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeDeserializer.ts @@ -0,0 +1,59 @@ +import { NormalizedSchema } from "@smithy/core/schema"; +import { CodecSettings, Schema, SerdeFunctions, ShapeDeserializer } from "@smithy/types"; +import { fromUtf8, toUtf8 } from "@smithy/util-utf8"; + +import { FromStringShapeDeserializer } from "./FromStringShapeDeserializer"; + +/** + * This deserializer is a dispatcher that decides whether to use a string deserializer + * or a codec deserializer based on HTTP traits. + * + * For example, in a JSON HTTP message, the deserialization of a field will differ depending on whether + * it is bound to the HTTP header (string) or body (JSON). + * + * @alpha + */ +export class HttpInterceptingShapeDeserializer> + implements ShapeDeserializer +{ + private stringDeserializer: FromStringShapeDeserializer; + private serdeContext: SerdeFunctions | undefined; + + public constructor( + private codecDeserializer: CodecShapeDeserializer, + codecSettings: CodecSettings + ) { + this.stringDeserializer = new FromStringShapeDeserializer(codecSettings); + } + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.stringDeserializer.setSerdeContext(serdeContext); + this.codecDeserializer.setSerdeContext(serdeContext); + this.serdeContext = serdeContext; + } + + public read(schema: Schema, data: string | Uint8Array): any | Promise { + const ns = NormalizedSchema.of(schema); + const traits = ns.getMergedTraits(); + const toString = this.serdeContext?.utf8Encoder ?? toUtf8; + + if (traits.httpHeader || traits.httpResponseCode) { + return this.stringDeserializer.read(ns, toString(data)); + } + if (traits.httpPayload) { + if (ns.isBlobSchema()) { + const toBytes = this.serdeContext?.utf8Decoder ?? fromUtf8; + if (typeof data === "string") { + return toBytes(data); + } + return data; + } else if (ns.isStringSchema()) { + if ("byteLength" in (data as Uint8Array)) { + return toString(data); + } + return data; + } + } + return this.codecDeserializer.read(ns, data); + } +} diff --git a/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeSerializer.ts b/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeSerializer.ts new file mode 100644 index 00000000000..85135cffaf0 --- /dev/null +++ b/packages/core/src/submodules/protocols/serde/HttpInterceptingShapeSerializer.ts @@ -0,0 +1,50 @@ +import { NormalizedSchema } from "@smithy/core/schema"; +import { CodecSettings, Schema as ISchema, SerdeFunctions, ShapeSerializer } from "@smithy/types"; + +import { ToStringShapeSerializer } from "./ToStringShapeSerializer"; + +/** + * This serializer decides whether to dispatch to a string serializer or a codec serializer + * depending on HTTP binding traits within the given schema. + * + * For example, a JavaScript array is serialized differently when being written + * to a REST JSON HTTP header (comma-delimited string) and a REST JSON HTTP body (JSON array). + * + * @alpha + */ +export class HttpInterceptingShapeSerializer> + implements ShapeSerializer +{ + private buffer: string | undefined; + + public constructor( + private codecSerializer: CodecShapeSerializer, + codecSettings: CodecSettings, + private stringSerializer = new ToStringShapeSerializer(codecSettings) + ) {} + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.codecSerializer.setSerdeContext(serdeContext); + this.stringSerializer.setSerdeContext(serdeContext); + } + + public write(schema: ISchema, value: unknown): void { + const ns = NormalizedSchema.of(schema); + const traits = ns.getMergedTraits(); + if (traits.httpHeader || traits.httpLabel || traits.httpQuery) { + this.stringSerializer.write(ns, value); + this.buffer = this.stringSerializer.flush(); + return; + } + return this.codecSerializer.write(ns, value); + } + + public flush(): string | Uint8Array { + if (this.buffer !== undefined) { + const buffer = this.buffer; + this.buffer = undefined; + return buffer; + } + return this.codecSerializer.flush(); + } +} diff --git a/packages/core/src/submodules/protocols/serde/ToStringShapeSerializer.ts b/packages/core/src/submodules/protocols/serde/ToStringShapeSerializer.ts new file mode 100644 index 00000000000..0d513ea5e23 --- /dev/null +++ b/packages/core/src/submodules/protocols/serde/ToStringShapeSerializer.ts @@ -0,0 +1,99 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { dateToUtcString, LazyJsonString, quoteHeader } from "@smithy/core/serde"; +import { CodecSettings, Schema, SerdeFunctions, ShapeSerializer } from "@smithy/types"; +import { toBase64 } from "@smithy/util-base64"; + +import { determineTimestampFormat } from "./determineTimestampFormat"; + +/** + * Serializes a shape to string. + * + * @alpha + */ +export class ToStringShapeSerializer implements ShapeSerializer { + private stringBuffer = ""; + private serdeContext: SerdeFunctions | undefined = undefined; + + public constructor(private settings: CodecSettings) {} + + public setSerdeContext(serdeContext: SerdeFunctions): void { + this.serdeContext = serdeContext; + } + + public write(schema: Schema, value: unknown): void { + const ns = NormalizedSchema.of(schema); + switch (typeof value) { + case "object": + if (value === null) { + this.stringBuffer = "null"; + return; + } + if (ns.isTimestampSchema()) { + if (!(value instanceof Date)) { + throw new Error( + `@smithy/core/protocols - received non-Date value ${value} when schema expected Date in ${ns.getName(true)}` + ); + } + const format = determineTimestampFormat(ns, this.settings); + switch (format) { + case SCHEMA.TIMESTAMP_DATE_TIME: + this.stringBuffer = value.toISOString().replace(".000Z", "Z"); + break; + case SCHEMA.TIMESTAMP_HTTP_DATE: + this.stringBuffer = dateToUtcString(value); + break; + case SCHEMA.TIMESTAMP_EPOCH_SECONDS: + this.stringBuffer = String(value.getTime() / 1000); + break; + default: + console.warn("Missing timestamp format, using epoch seconds", value); + this.stringBuffer = String(value.getTime() / 1000); + } + return; + } + if (ns.isBlobSchema() && "byteLength" in (value as Uint8Array)) { + this.stringBuffer = (this.serdeContext?.base64Encoder ?? toBase64)(value as Uint8Array); + return; + } + if (ns.isListSchema() && Array.isArray(value)) { + let buffer = ""; + for (const item of value) { + this.write([ns.getValueSchema(), ns.getMergedTraits()], item); + const headerItem = this.flush(); + const serialized = ns.getValueSchema().isTimestampSchema() ? headerItem : quoteHeader(headerItem); + if (buffer !== "") { + buffer += ", "; + } + buffer += serialized; + } + this.stringBuffer = buffer; + return; + } + this.stringBuffer = JSON.stringify(value, null, 2); + break; + case "string": + const mediaType = ns.getMergedTraits().mediaType; + let intermediateValue: string | LazyJsonString = value; + if (mediaType) { + const isJson = mediaType === "application/json" || mediaType.endsWith("+json"); + if (isJson) { + intermediateValue = LazyJsonString.from(intermediateValue); + } + if (ns.getMergedTraits().httpHeader) { + this.stringBuffer = (this.serdeContext?.base64Encoder ?? toBase64)(intermediateValue.toString()); + return; + } + } + this.stringBuffer = value; + break; + default: + this.stringBuffer = String(value); + } + } + + public flush(): string { + const buffer = this.stringBuffer; + this.stringBuffer = ""; + return buffer; + } +} diff --git a/packages/core/src/submodules/protocols/serde/determineTimestampFormat.ts b/packages/core/src/submodules/protocols/serde/determineTimestampFormat.ts new file mode 100644 index 00000000000..74a64036f21 --- /dev/null +++ b/packages/core/src/submodules/protocols/serde/determineTimestampFormat.ts @@ -0,0 +1,40 @@ +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import { + type TimestampDateTimeSchema, + type TimestampEpochSecondsSchema, + type TimestampHttpDateSchema, + CodecSettings, +} from "@smithy/types"; + +/** + * Assuming the schema is a timestamp type, the function resolves the format using + * either the timestamp's own traits, or the default timestamp format from the CodecSettings. + * + * @internal + */ +export function determineTimestampFormat( + ns: NormalizedSchema, + settings: CodecSettings +): TimestampDateTimeSchema | TimestampHttpDateSchema | TimestampEpochSecondsSchema { + if (settings.timestampFormat.useTrait) { + if ( + ns.isTimestampSchema() && + (ns.getSchema() === SCHEMA.TIMESTAMP_DATE_TIME || + ns.getSchema() === SCHEMA.TIMESTAMP_HTTP_DATE || + ns.getSchema() === SCHEMA.TIMESTAMP_EPOCH_SECONDS) + ) { + return ns.getSchema() as TimestampDateTimeSchema | TimestampHttpDateSchema | TimestampEpochSecondsSchema; + } + } + + const { httpLabel, httpPrefixHeaders, httpHeader, httpQuery } = ns.getMergedTraits(); + const bindingFormat = settings.httpBindings + ? typeof httpPrefixHeaders === "string" || Boolean(httpHeader) + ? SCHEMA.TIMESTAMP_HTTP_DATE + : Boolean(httpQuery) || Boolean(httpLabel) + ? SCHEMA.TIMESTAMP_DATE_TIME + : undefined + : undefined; + + return bindingFormat ?? settings.timestampFormat.default; +} diff --git a/packages/core/src/submodules/serde/copyDocumentWithTransform.ts b/packages/core/src/submodules/serde/copyDocumentWithTransform.ts new file mode 100644 index 00000000000..d74a2a7291b --- /dev/null +++ b/packages/core/src/submodules/serde/copyDocumentWithTransform.ts @@ -0,0 +1,61 @@ +import { NormalizedSchema } from "@smithy/core/schema"; +import { SchemaRef } from "@smithy/types"; + +/** + * @internal + */ +export const copyDocumentWithTransform = ( + source: any, + schemaRef: SchemaRef, + transform: (_: any, schemaRef: SchemaRef) => any = (_) => _ +): any => { + const ns = NormalizedSchema.of(schemaRef); + switch (typeof source) { + case "undefined": + case "boolean": + case "number": + case "string": + case "bigint": + case "symbol": + return transform(source, ns); + case "function": + case "object": + if (source === null) { + return transform(null, ns); + } + if (Array.isArray(source)) { + const newArray = new Array(source.length); + let i = 0; + for (const item of source) { + newArray[i++] = copyDocumentWithTransform(item, ns.getValueSchema(), transform); + } + return transform(newArray, ns); + } + if ("byteLength" in (source as Uint8Array)) { + const newBytes = new Uint8Array(source.byteLength); + newBytes.set(source, 0); + return transform(newBytes, ns); + } + if (source instanceof Date) { + return transform(source, ns); + } + const newObject = {} as any; + if (ns.isMapSchema()) { + for (const key of Object.keys(source)) { + newObject[key] = copyDocumentWithTransform(source[key], ns.getValueSchema(), transform); + } + } else if (ns.isStructSchema()) { + for (const [key, memberSchema] of Object.entries(ns.getMemberSchemas())) { + newObject[key] = copyDocumentWithTransform(source[key], memberSchema, transform); + } + } else if (ns.isDocumentSchema()) { + for (const key of Object.keys(source)) { + newObject[key] = copyDocumentWithTransform(source[key], ns.getValueSchema(), transform); + } + } + + return transform(newObject, ns); + default: + return transform(source, ns); + } +}; diff --git a/packages/core/src/submodules/serde/index.ts b/packages/core/src/submodules/serde/index.ts index e0d4e038320..ae7219a89d2 100644 --- a/packages/core/src/submodules/serde/index.ts +++ b/packages/core/src/submodules/serde/index.ts @@ -1,7 +1,8 @@ -export * from "./parse-utils"; +export * from "./copyDocumentWithTransform"; export * from "./date-utils"; +export * from "./lazy-json"; +export * from "./parse-utils"; export * from "./quote-header"; +export * from "./split-every"; export * from "./split-header"; export * from "./value/NumericValue"; -export * from "./lazy-json"; -export * from "./split-every"; diff --git a/yarn.lock b/yarn.lock index 70eccb3bd59..6a1afc68a7b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2449,6 +2449,7 @@ __metadata: "@smithy/middleware-serde": "workspace:^" "@smithy/protocol-http": "workspace:^" "@smithy/types": "workspace:^" + "@smithy/util-base64": "workspace:^" "@smithy/util-body-length-browser": "workspace:^" "@smithy/util-middleware": "workspace:^" "@smithy/util-stream": "workspace:^" From 7dfd2931c1232cd8f89a76e6c36a9db6ef239f73 Mon Sep 17 00:00:00 2001 From: George Fu Date: Wed, 21 May 2025 16:58:03 -0400 Subject: [PATCH 2/2] add structIterator generator --- .changeset/two-berries-repeat.md | 5 +++++ .../core/src/submodules/cbor/CborCodec.ts | 18 ++++++++-------- .../src/submodules/protocols/HttpProtocol.ts | 4 ++-- .../schema/schemas/NormalizedSchema.spec.ts | 14 +++++++++++++ .../schema/schemas/NormalizedSchema.ts | 21 +++++++++++++++++++ .../serde/copyDocumentWithTransform.ts | 2 +- packages/smithy-client/src/command.ts | 16 ++++++++++++++ 7 files changed, 68 insertions(+), 12 deletions(-) create mode 100644 .changeset/two-berries-repeat.md diff --git a/.changeset/two-berries-repeat.md b/.changeset/two-berries-repeat.md new file mode 100644 index 00000000000..cb9026cca9b --- /dev/null +++ b/.changeset/two-berries-repeat.md @@ -0,0 +1,5 @@ +--- +"@smithy/smithy-client": minor +--- + +add schema property to Command class diff --git a/packages/core/src/submodules/cbor/CborCodec.ts b/packages/core/src/submodules/cbor/CborCodec.ts index e4ab36f57ac..4624a4520f8 100644 --- a/packages/core/src/submodules/cbor/CborCodec.ts +++ b/packages/core/src/submodules/cbor/CborCodec.ts @@ -51,14 +51,18 @@ export class CborShapeSerializer implements ShapeSerializer { const ns = NormalizedSchema.of(schemaRef); const sparse = !!ns.getMergedTraits().sparse; - if (Array.isArray(_)) { + if (ns.isListSchema() && Array.isArray(_)) { if (!sparse) { return _.filter((item) => item != null); } } else if (_ && typeof _ === "object") { - if (!sparse || ns.isStructSchema()) { + const members = ns.getMemberSchemas(); + const isStruct = ns.isStructSchema(); + if (!sparse || isStruct) { for (const [k, v] of Object.entries(_)) { - if (v == null) { + const filteredOutByNonSparse = !sparse && v == null; + const filteredOutByUnrecognizedMember = isStruct && !(k in members); + if (filteredOutByNonSparse || filteredOutByUnrecognizedMember) { delete _[k]; } } @@ -157,12 +161,8 @@ export class CborShapeDeserializer implements ShapeDeserializer { } } } else if (ns.isStructSchema()) { - for (const key of Object.keys(value)) { - const targetSchema = ns.getMemberSchema(key); - if (targetSchema === undefined) { - continue; - } - newObject[key] = this.readValue(targetSchema, value[key]); + for (const [key, memberSchema] of ns.structIterator()) { + newObject[key] = this.readValue(memberSchema, value[key]); } } return newObject; diff --git a/packages/core/src/submodules/protocols/HttpProtocol.ts b/packages/core/src/submodules/protocols/HttpProtocol.ts index c8ae7307592..c8e754890e2 100644 --- a/packages/core/src/submodules/protocols/HttpProtocol.ts +++ b/packages/core/src/submodules/protocols/HttpProtocol.ts @@ -110,7 +110,7 @@ export abstract class HttpProtocol implements ClientProtocol member.getMergedTraits().hostLabel ); for (const [name] of hostLabelInputs) { @@ -154,7 +154,7 @@ export abstract class HttpProtocol implements ClientProtocol { }); }); + describe("iteration", () => { + it("iterates over member schemas", () => { + const iteration = Array.from(ns.structIterator()); + const entries = Object.entries(ns.getMemberSchemas()); + for (let i = 0; i < iteration.length; i++) { + const [name, schema] = iteration[i]; + const [entryName, entrySchema] = entries[i]; + expect(name).toBe(entryName); + expect(schema.getMemberName()).toEqual(entrySchema.getMemberName()); + expect(schema.getMergedTraits()).toEqual(entrySchema.getMergedTraits()); + } + }); + }); + describe("traits", () => { const member: MemberSchema = [sim("ack", "SimpleString", 0, { idempotencyToken: 1 }), 0b0000_0001]; const ns = NormalizedSchema.of(member, "member_name"); diff --git a/packages/core/src/submodules/schema/schemas/NormalizedSchema.ts b/packages/core/src/submodules/schema/schemas/NormalizedSchema.ts index 8e0f688af78..fb65e8a0a65 100644 --- a/packages/core/src/submodules/schema/schemas/NormalizedSchema.ts +++ b/packages/core/src/submodules/schema/schemas/NormalizedSchema.ts @@ -371,6 +371,11 @@ export class NormalizedSchema implements INormalizedSchema { } /** + * This can be used for checking the members as a hashmap. + * Prefer the structIterator method for iteration. + * + * This does NOT return list and map members, it is only for structures. + * * @returns a map of member names to member schemas (normalized). */ public getMemberSchemas(): Record { @@ -389,6 +394,22 @@ export class NormalizedSchema implements INormalizedSchema { return {}; } + /** + * Allows iteration over members of a structure schema. + * Each yield is a pair of the member name and member schema. + * + * This avoids the overhead of calling Object.entries(ns.getMemberSchemas()). + */ + public *structIterator(): Generator<[string, NormalizedSchema], undefined, undefined> { + if (!this.isStructSchema()) { + throw new Error("@smithy/core/schema - cannot acquire structIterator on non-struct schema."); + } + const struct = this.getSchema() as StructureSchema; + for (let i = 0; i < struct.memberNames.length; ++i) { + yield [struct.memberNames[i], NormalizedSchema.memberFrom([struct.memberList[i], 0], struct.memberNames[i])]; + } + } + /** * @returns a last-resort human-readable name for the schema if it has no other identifiers. */ diff --git a/packages/core/src/submodules/serde/copyDocumentWithTransform.ts b/packages/core/src/submodules/serde/copyDocumentWithTransform.ts index d74a2a7291b..f5146a231d8 100644 --- a/packages/core/src/submodules/serde/copyDocumentWithTransform.ts +++ b/packages/core/src/submodules/serde/copyDocumentWithTransform.ts @@ -45,7 +45,7 @@ export const copyDocumentWithTransform = ( newObject[key] = copyDocumentWithTransform(source[key], ns.getValueSchema(), transform); } } else if (ns.isStructSchema()) { - for (const [key, memberSchema] of Object.entries(ns.getMemberSchemas())) { + for (const [key, memberSchema] of ns.structIterator()) { newObject[key] = copyDocumentWithTransform(source[key], memberSchema, transform); } } else if (ns.isDocumentSchema()) { diff --git a/packages/smithy-client/src/command.ts b/packages/smithy-client/src/command.ts index b100c8233e1..e486e41f613 100644 --- a/packages/smithy-client/src/command.ts +++ b/packages/smithy-client/src/command.ts @@ -11,6 +11,8 @@ import type { Logger, MetadataBearer, MiddlewareStack as IMiddlewareStack, + Mutable, + OperationSchema, OptionalParameter, Pluggable, RequestHandler, @@ -31,6 +33,7 @@ export abstract class Command< { public abstract input: Input; public readonly middlewareStack: IMiddlewareStack = constructStack(); + public readonly schema?: OperationSchema; /** * Factory for Command ClassBuilder. @@ -131,6 +134,8 @@ class ClassBuilder< private _outputFilterSensitiveLog = (_: any) => _; private _serializer: (input: I, context: SerdeContext | any) => Promise = null as any; private _deserializer: (output: IHttpResponse, context: SerdeContext | any) => Promise = null as any; + private _operationSchema?: OperationSchema; + /** * Optional init callback. */ @@ -212,6 +217,16 @@ class ClassBuilder< this._deserializer = deserializer; return this; } + + /** + * Sets input/output schema for the operation. + */ + public sc(operation: OperationSchema): ClassBuilder { + this._operationSchema = operation; + this._smithyContext.operationSchema = operation; + return this; + } + /** * @returns a Command class with the classBuilder properties. */ @@ -241,6 +256,7 @@ class ClassBuilder< super(); this.input = input ?? ({} as unknown as I); closure._init(this); + (this as Mutable).schema = closure._operationSchema; } /**