diff --git a/modules/encrypt-browser/src/encrypt.ts b/modules/encrypt-browser/src/encrypt.ts index 39e51f3b5..366fbc42a 100644 --- a/modules/encrypt-browser/src/encrypt.ts +++ b/modules/encrypt-browser/src/encrypt.ts @@ -113,6 +113,7 @@ export async function encrypt ( /* The final frame has a variable length. * The value needs to be known, but should only be calculated once. * So I calculate how much of a frame I should have at the end. + * This value will NEVER be larger than the frameLength. */ const finalFrameLength = frameLength - ((numberOfFrames * frameLength) - plaintextLength) const bodyContent = [] diff --git a/modules/encrypt-node/src/framed_encrypt_stream.ts b/modules/encrypt-node/src/framed_encrypt_stream.ts index 5a039413d..3fac3d002 100644 --- a/modules/encrypt-node/src/framed_encrypt_stream.ts +++ b/modules/encrypt-node/src/framed_encrypt_stream.ts @@ -164,7 +164,7 @@ export function getFramedEncryptStream (getCipher: GetCipher, messageHeader: Mes /* Push the authTag onto the end. Yes, I am abusing the name. */ cipherContent.push(cipher.getAuthTag()) - needs(frameSize === frameLength || isFinalFrame, 'Malformed frame') + needs(frameSize === frameLength || (isFinalFrame && frameLength >= frameSize), 'Malformed frame') for (const cipherText of cipherContent) { if (!this.push(cipherText)) { @@ -190,11 +190,18 @@ type EncryptFrameInput = { export function getEncryptFrame (input: EncryptFrameInput): EncryptFrame { const { pendingFrame, messageHeader, getCipher, isFinalFrame } = input const { sequenceNumber, contentLength, content } = pendingFrame - const frameIv = serialize.frameIv(messageHeader.headerIvLength, sequenceNumber) + const { frameLength, contentType, messageId, headerIvLength } = messageHeader + /* Precondition: The content length MUST correlate with the frameLength. + * In the case of a regular frame, + * the content length MUST strictly equal the frame length. + * In the case of the final frame, + * it MUST NOT be larger than the frame length. + */ + needs(frameLength === contentLength || (isFinalFrame && frameLength >= contentLength), `Malformed frame length and content length: ${JSON.stringify({ frameLength, contentLength, isFinalFrame })}`) + const frameIv = serialize.frameIv(headerIvLength, sequenceNumber) const bodyHeader = Buffer.from(isFinalFrame ? finalFrameHeader(sequenceNumber, frameIv, contentLength) : frameHeader(sequenceNumber, frameIv)) - const { contentType, messageId } = messageHeader const contentString = aadUtility.messageAADContentString({ contentType, isFinalFrame }) const { buffer, byteOffset, byteLength } = aadUtility.messageAAD(messageId, contentString, sequenceNumber, contentLength) const cipher = getCipher(frameIv) diff --git a/modules/encrypt-node/test/framed_encrypt_stream.test.ts b/modules/encrypt-node/test/framed_encrypt_stream.test.ts index 85ac6f142..7d5ed1004 100644 --- a/modules/encrypt-node/test/framed_encrypt_stream.test.ts +++ b/modules/encrypt-node/test/framed_encrypt_stream.test.ts @@ -17,7 +17,7 @@ import * as chai from 'chai' import chaiAsPromised from 'chai-as-promised' -import { getFramedEncryptStream } from '../src/framed_encrypt_stream' +import { getFramedEncryptStream, getEncryptFrame } from '../src/framed_encrypt_stream' chai.use(chaiAsPromised) const { expect } = chai @@ -60,3 +60,90 @@ describe('getFramedEncryptStream', () => { expect(called).to.equal(true) }) }) + +describe('getEncryptFrame', () => { + it('can return an EncryptFrame', () => { + const input = { + pendingFrame: { + content: [Buffer.from([1, 2, 3, 4, 5])], + contentLength: 5, + sequenceNumber: 1 + }, + isFinalFrame: false, + getCipher: () => ({ setAAD: () => {} }) as any, + messageHeader: { + frameLength: 5, + contentType: 2, + messageId: Buffer.from([]), + headerIvLength: 12 as 12, + version: 1, + type: 12, + suiteId: 1, + encryptionContext: {}, + encryptedDataKeys: [] + } + } + const test1 = getEncryptFrame(input) + expect(test1.content).to.equal(input.pendingFrame.content) + expect(test1.isFinalFrame).to.equal(input.isFinalFrame) + + // Just a quick flip to make sure... + input.isFinalFrame = true + const test2 = getEncryptFrame(input) + expect(test2.content).to.equal(input.pendingFrame.content) + expect(test2.isFinalFrame).to.equal(input.isFinalFrame) + }) + + it('Precondition: The content length MUST correlate with the frameLength.', () => { + const inputFinalFrameToLarge = { + pendingFrame: { + content: [Buffer.from([1, 2, 3, 4, 5, 6])], + // This exceeds the frameLength below + contentLength: 6, + sequenceNumber: 1 + }, + isFinalFrame: true, + getCipher: () => ({ setAAD: () => {} }) as any, + messageHeader: { + frameLength: 5, + contentType: 2, + messageId: Buffer.from([]), + headerIvLength: 12 as 12, + version: 1, + type: 12, + suiteId: 1, + encryptionContext: {}, + encryptedDataKeys: [] + } + } + + expect(() => getEncryptFrame(inputFinalFrameToLarge)).to.throw('Malformed frame length and content length:') + + const inputFrame = { + pendingFrame: { + content: [Buffer.from([1, 2, 3, 4, 5])], + contentLength: 5, + sequenceNumber: 1 + }, + isFinalFrame: false, + getCipher: () => ({ setAAD: () => {} }) as any, + messageHeader: { + frameLength: 5, + contentType: 2, + messageId: Buffer.from([]), + headerIvLength: 12 as 12, + version: 1, + type: 12, + suiteId: 1, + encryptionContext: {}, + encryptedDataKeys: [] + } + } + + // Make sure that it must be equal as long as we are here... + inputFrame.pendingFrame.contentLength = 4 + expect(() => getEncryptFrame(inputFrame)).to.throw('Malformed frame length and content length:') + inputFrame.pendingFrame.contentLength = 6 + expect(() => getEncryptFrame(inputFrame)).to.throw('Malformed frame length and content length:') + }) +}) diff --git a/modules/serialize/src/decode_body_header.ts b/modules/serialize/src/decode_body_header.ts index aaa450fa9..2a46982e1 100644 --- a/modules/serialize/src/decode_body_header.ts +++ b/modules/serialize/src/decode_body_header.ts @@ -155,6 +155,8 @@ export function decodeFinalFrameBodyHeader (buffer: Uint8Array, headerInfo: Head needs(sequenceNumber > 0, 'Malformed sequenceNumber.') const iv = buffer.slice(readPos += 4, readPos += ivLength) const contentLength = dataView.getUint32(readPos) + /* Postcondition: The final frame MUST NOT exceed the frameLength. */ + needs(headerInfo.messageHeader.frameLength >= contentLength, 'Final frame length exceeds frame length.') return { sequenceNumber, iv, diff --git a/modules/serialize/test/decode_body_header.test.ts b/modules/serialize/test/decode_body_header.test.ts index eeb4fed7c..68ce5acea 100644 --- a/modules/serialize/test/decode_body_header.test.ts +++ b/modules/serialize/test/decode_body_header.test.ts @@ -101,7 +101,7 @@ describe('decodeFrameBodyHeader', () => { it('return final frame header', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -205,7 +205,7 @@ describe('decodeFrameBodyHeader', () => { const buffer = concatBuffers(new Uint8Array(10), fixtures.finalFrameHeader()) const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -296,7 +296,7 @@ describe('decodeFinalFrameBodyHeader', () => { it('return final frame header from readPos', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -318,12 +318,36 @@ describe('decodeFinalFrameBodyHeader', () => { expect(test.tagLength).to.eql(16) expect(test.isFinalFrame).to.eql(true) expect(test.contentType).to.eql(ContentType.FRAMED_DATA) + expect(test.contentLength).to.eql(999) + }) + + it('The final frame can be 0 length.', () => { + const headerInfo = { + messageHeader: { + frameLength: 999, + contentType: ContentType.FRAMED_DATA + }, + algorithmSuite: { + ivLength: 12, + tagLength: 16 + } + } as any + const buffer = fixtures.finalFrameHeaderZeroBytes() + + const test = decodeFinalFrameBodyHeader(buffer, headerInfo, 0) + if (!test) throw new Error('failure') + expect(test.sequenceNumber).to.eql(1) + expect(test.iv).to.eql(fixtures.basicFrameIV()) + expect(test.tagLength).to.eql(16) + expect(test.isFinalFrame).to.eql(true) + expect(test.contentType).to.eql(ContentType.FRAMED_DATA) + expect(test.contentLength).to.eql(0) }) it('Precondition: The contentType must be FRAMED_DATA to be a Final Frame.', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: 'not FRAMED_DATA' }, algorithmSuite: { @@ -338,7 +362,7 @@ describe('decodeFinalFrameBodyHeader', () => { it('Precondition: decodeFinalFrameBodyHeader readPos must be within the byte length of the buffer given.', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -355,7 +379,7 @@ describe('decodeFinalFrameBodyHeader', () => { it('Postcondition: sequenceEnd must be SEQUENCE_NUMBER_END.', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -371,7 +395,7 @@ describe('decodeFinalFrameBodyHeader', () => { it('Postcondition: decodeFinalFrameBodyHeader sequenceNumber must be greater than 0.', () => { const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -388,7 +412,7 @@ describe('decodeFinalFrameBodyHeader', () => { const frameHeader = fixtures.finalFrameHeader() const headerInfo = { messageHeader: { - frameLength: 99, + frameLength: 999, contentType: ContentType.FRAMED_DATA }, algorithmSuite: { @@ -402,6 +426,24 @@ describe('decodeFinalFrameBodyHeader', () => { expect(test).to.eql(false) } }) + + it('Postcondition: The final frame MUST NOT exceed the frameLength.', () => { + const headerInfo = { + messageHeader: { + // The content length in this final frame is 999 + // So I set the frame length to less than this + frameLength: 99, + contentType: ContentType.FRAMED_DATA + }, + algorithmSuite: { + ivLength: 12, + tagLength: 16 + } + } as any + const buffer = fixtures.finalFrameHeader() + + expect(() => decodeFinalFrameBodyHeader(buffer, headerInfo, 0)).to.throw('Final frame length exceeds frame length.') + }) }) describe('decodeNonFrameBodyHeader', () => { diff --git a/modules/serialize/test/fixtures.ts b/modules/serialize/test/fixtures.ts index 1430f5fb6..adccd3caf 100644 --- a/modules/serialize/test/fixtures.ts +++ b/modules/serialize/test/fixtures.ts @@ -61,6 +61,10 @@ export function finalFrameHeader () { return new Uint8Array([255, 255, 255, 255, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 3, 231]) } +export function finalFrameHeaderZeroBytes () { + return new Uint8Array([255, 255, 255, 255, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]) +} + export function invalidSequenceEndFinalFrameHeader () { return new Uint8Array([0, 255, 255, 255, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 3, 231]) }