diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fd10d13b..83dc88da 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -52,5 +52,6 @@ jobs: php-version: '8.0' - name: Run Script run: | + composer install composer global require phpstan/phpstan ~/.composer/vendor/bin/phpstan analyse diff --git a/README.md b/README.md index 65e8fc18..42e8b6db 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,45 @@ $jwks = ['keys' => []]; JWT::decode($payload, JWK::parseKeySet($jwks)); ``` +Using Cached Key Sets +--------------------- + +The `CachedKeySet` class can be used to fetch and cache JWKS (JSON Web Key Sets) from a public URI. +This has the following advantages: + +1. The results are cached for performance. +2. If an unrecognized key is requested, the cache is refreshed, to accomodate for key rotation. +3. If rate limiting is enabled, the JWKS URI will not make more than 10 requests a second. + +```php +use Firebase\JWT\CachedKeySet; +use Firebase\JWT\JWT; + +// The URI for the JWKS you wish to cache the results from +$jwksUri = 'https://www.gstatic.com/iap/verify/public_key-jwk'; + +// Create an HTTP client (can be any PSR-7 compatible HTTP client) +$httpClient = new GuzzleHttp\Client(); + +// Create an HTTP request factory (can be any PSR-17 compatible HTTP request factory) +$httpFactory = new GuzzleHttp\Psr\HttpFactory(); + +// Create a cache item pool (can be any PSR-6 compatible cache item pool) +$cacheItemPool = Phpfastcache\CacheManager::getInstance('files'); + +$keySet = new CachedKeySet( + $jwksUri, + $httpClient, + $httpFactory, + $cacheItemPool, + null, // $expiresAfter int seconds to set the JWKS to expire + true // $rateLimit true to enable rate limit of 10 RPS on lookup of invalid keys +); + +$jwt = 'eyJhbGci...'; // Some JWT signed by a key from the $jwkUri above +$decoded = JWT::decode($jwt, $keySet); +``` + Miscellaneous ------------- diff --git a/composer.json b/composer.json index 5ef2ea2d..2a3cb2df 100644 --- a/composer.json +++ b/composer.json @@ -31,6 +31,11 @@ } }, "require-dev": { - "phpunit/phpunit": "^7.5||9.5" + "guzzlehttp/guzzle": "^6.5||^7.4", + "phpspec/prophecy-phpunit": "^1.1", + "phpunit/phpunit": "^7.5||^9.5", + "psr/cache": "^1.0||^2.0", + "psr/http-client": "^1.0", + "psr/http-factory": "^1.0" } } diff --git a/phpunit.xml.dist b/phpunit.xml.dist index 092a662c..31195a91 100644 --- a/phpunit.xml.dist +++ b/phpunit.xml.dist @@ -8,7 +8,7 @@ convertWarningsToExceptions="true" processIsolation="false" stopOnFailure="false" - bootstrap="tests/bootstrap.php" + bootstrap="vendor/autoload.php" > diff --git a/src/CachedKeySet.php b/src/CachedKeySet.php new file mode 100644 index 00000000..077dceb0 --- /dev/null +++ b/src/CachedKeySet.php @@ -0,0 +1,225 @@ + + */ +class CachedKeySet implements ArrayAccess +{ + /** + * @var string + */ + private $jwksUri; + /** + * @var ClientInterface + */ + private $httpClient; + /** + * @var RequestFactoryInterface + */ + private $httpFactory; + /** + * @var CacheItemPoolInterface + */ + private $cache; + /** + * @var ?int + */ + private $expiresAfter; + /** + * @var ?CacheItemInterface + */ + private $cacheItem; + /** + * @var array + */ + private $keySet; + /** + * @var string + */ + private $cacheKey; + /** + * @var string + */ + private $cacheKeyPrefix = 'jwks'; + /** + * @var int + */ + private $maxKeyLength = 64; + /** + * @var bool + */ + private $rateLimit; + /** + * @var string + */ + private $rateLimitCacheKey; + /** + * @var int + */ + private $maxCallsPerMinute = 10; + + public function __construct( + string $jwksUri, + ClientInterface $httpClient, + RequestFactoryInterface $httpFactory, + CacheItemPoolInterface $cache, + int $expiresAfter = null, + bool $rateLimit = false + ) { + $this->jwksUri = $jwksUri; + $this->httpClient = $httpClient; + $this->httpFactory = $httpFactory; + $this->cache = $cache; + $this->expiresAfter = $expiresAfter; + $this->rateLimit = $rateLimit; + $this->setCacheKeys(); + } + + /** + * @param string $keyId + * @return Key + */ + public function offsetGet($keyId): Key + { + if (!$this->keyIdExists($keyId)) { + throw new OutOfBoundsException('Key ID not found'); + } + return $this->keySet[$keyId]; + } + + /** + * @param string $keyId + * @return bool + */ + public function offsetExists($keyId): bool + { + return $this->keyIdExists($keyId); + } + + /** + * @param string $offset + * @param Key $value + */ + public function offsetSet($offset, $value): void + { + throw new LogicException('Method not implemented'); + } + + /** + * @param string $offset + */ + public function offsetUnset($offset): void + { + throw new LogicException('Method not implemented'); + } + + private function keyIdExists(string $keyId): bool + { + $keySetToCache = null; + if (null === $this->keySet) { + $item = $this->getCacheItem(); + // Try to load keys from cache + if ($item->isHit()) { + // item found! Return it + $this->keySet = $item->get(); + } + } + + if (!isset($this->keySet[$keyId])) { + if ($this->rateLimitExceeded()) { + return false; + } + $request = $this->httpFactory->createRequest('get', $this->jwksUri); + $jwksResponse = $this->httpClient->sendRequest($request); + $jwks = json_decode((string) $jwksResponse->getBody(), true); + $this->keySet = $keySetToCache = JWK::parseKeySet($jwks); + + if (!isset($this->keySet[$keyId])) { + return false; + } + } + + if ($keySetToCache) { + $item = $this->getCacheItem(); + $item->set($keySetToCache); + if ($this->expiresAfter) { + $item->expiresAfter($this->expiresAfter); + } + $this->cache->save($item); + } + + return true; + } + + private function rateLimitExceeded(): bool + { + if (!$this->rateLimit) { + return false; + } + + $cacheItem = $this->cache->getItem($this->rateLimitCacheKey); + if (!$cacheItem->isHit()) { + $cacheItem->expiresAfter(1); // # of calls are cached each minute + } + + $callsPerMinute = (int) $cacheItem->get(); + if (++$callsPerMinute > $this->maxCallsPerMinute) { + return true; + } + $cacheItem->set($callsPerMinute); + $this->cache->save($cacheItem); + return false; + } + + private function getCacheItem(): CacheItemInterface + { + if (\is_null($this->cacheItem)) { + $this->cacheItem = $this->cache->getItem($this->cacheKey); + } + + return $this->cacheItem; + } + + private function setCacheKeys(): void + { + if (empty($this->jwksUri)) { + throw new RuntimeException('JWKS URI is empty'); + } + + // ensure we do not have illegal characters + $key = preg_replace('|[^a-zA-Z0-9_\.!]|', '', $this->jwksUri); + + // add prefix + $key = $this->cacheKeyPrefix . $key; + + // Hash keys if they exceed $maxKeyLength of 64 + if (\strlen($key) > $this->maxKeyLength) { + $key = substr(hash('sha256', $key), 0, $this->maxKeyLength); + } + + $this->cacheKey = $key; + + if ($this->rateLimit) { + // add prefix + $rateLimitKey = $this->cacheKeyPrefix . 'ratelimit' . $key; + + // Hash keys if they exceed $maxKeyLength of 64 + if (\strlen($rateLimitKey) > $this->maxKeyLength) { + $rateLimitKey = substr(hash('sha256', $rateLimitKey), 0, $this->maxKeyLength); + } + + $this->rateLimitCacheKey = $rateLimitKey; + } + } +} diff --git a/src/JWT.php b/src/JWT.php index 98a503e4..9011292f 100644 --- a/src/JWT.php +++ b/src/JWT.php @@ -2,6 +2,7 @@ namespace Firebase\JWT; +use ArrayAccess; use DateTime; use DomainException; use Exception; @@ -9,7 +10,6 @@ use OpenSSLAsymmetricKey; use OpenSSLCertificate; use stdClass; -use TypeError; use UnexpectedValueException; /** @@ -68,7 +68,7 @@ class JWT * Decodes a JWT string into a PHP object. * * @param string $jwt The JWT - * @param Key|array $keyOrKeyArray The Key or associative array of key IDs (kid) to Key objects. + * @param Key|array $keyOrKeyArray The Key or associative array of key IDs (kid) to Key objects. * If the algorithm used is asymmetric, this is the public key * Each Key object contains an algorithm and matching key. * Supported algorithms are 'ES384','ES256', 'HS256', 'HS384', @@ -409,7 +409,7 @@ public static function urlsafeB64Encode(string $input): string /** * Determine if an algorithm has been provided for each Key * - * @param Key|array $keyOrKeyArray + * @param Key|ArrayAccess|array $keyOrKeyArray * @param string|null $kid * * @throws UnexpectedValueException @@ -424,15 +424,12 @@ private static function getKey( return $keyOrKeyArray; } - foreach ($keyOrKeyArray as $keyId => $key) { - if (!$key instanceof Key) { - throw new TypeError( - '$keyOrKeyArray must be an instance of Firebase\JWT\Key key or an ' - . 'array of Firebase\JWT\Key keys' - ); - } + if ($keyOrKeyArray instanceof CachedKeySet) { + // Skip "isset" check, as this will automatically refresh if not set + return $keyOrKeyArray[$kid]; } - if (!isset($kid)) { + + if (empty($kid)) { throw new UnexpectedValueException('"kid" empty, unable to lookup correct key'); } if (!isset($keyOrKeyArray[$kid])) { diff --git a/tests/CachedKeySetTest.php b/tests/CachedKeySetTest.php new file mode 100644 index 00000000..22e1de5f --- /dev/null +++ b/tests/CachedKeySetTest.php @@ -0,0 +1,481 @@ +expectException(RuntimeException::class); + $this->expectExceptionMessage('JWKS URI is empty'); + + $cachedKeySet = new CachedKeySet( + '', + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $this->prophesize(CacheItemPoolInterface::class)->reveal() + ); + + $cachedKeySet['foo']; + } + + public function testOffsetSetThrowsException() + { + $this->expectException(LogicException::class); + $this->expectExceptionMessage('Method not implemented'); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $this->prophesize(CacheItemPoolInterface::class)->reveal() + ); + + $cachedKeySet['foo'] = 'bar'; + } + + public function testOffsetUnsetThrowsException() + { + $this->expectException(LogicException::class); + $this->expectExceptionMessage('Method not implemented'); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $this->prophesize(CacheItemPoolInterface::class)->reveal() + ); + + unset($cachedKeySet['foo']); + } + + public function testOutOfBoundsThrowsException() + { + $this->expectException(OutOfBoundsException::class); + $this->expectExceptionMessage('Key ID not found'); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks1), + $this->getMockHttpFactory(), + $this->getMockEmptyCache() + ); + + // keyID doesn't exist + $cachedKeySet['bar']; + } + + public function testWithExistingKeyId() + { + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks1), + $this->getMockHttpFactory(), + $this->getMockEmptyCache() + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertEquals('foo', $cachedKeySet['foo']->getAlgorithm()); + } + + public function testKeyIdIsCached() + { + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->willReturn(true); + $cacheItem->get() + ->willReturn(JWK::parseKeySet(json_decode($this->testJwks1, true))); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->willReturn(true); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $cache->reveal() + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertEquals('foo', $cachedKeySet['foo']->getAlgorithm()); + } + + public function testCachedKeyIdRefresh() + { + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->shouldBeCalledOnce() + ->willReturn(true); + $cacheItem->get() + ->shouldBeCalledOnce() + ->willReturn(JWK::parseKeySet(json_decode($this->testJwks1, true))); + $cacheItem->set(Argument::any()) + ->shouldBeCalledOnce() + ->will(function () { + return $this; + }); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->shouldBeCalledOnce() + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->shouldBeCalledOnce() + ->willReturn(true); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks2), // updated JWK + $this->getMockHttpFactory(), + $cache->reveal() + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertEquals('foo', $cachedKeySet['foo']->getAlgorithm()); + + $this->assertInstanceOf(Key::class, $cachedKeySet['bar']); + $this->assertEquals('bar', $cachedKeySet['bar']->getAlgorithm()); + } + + public function testCacheItemWithExpiresAfter() + { + $expiresAfter = 10; + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->shouldBeCalledOnce() + ->willReturn(false); + $cacheItem->set(Argument::any()) + ->shouldBeCalledOnce() + ->will(function () { + return $this; + }); + $cacheItem->expiresAfter($expiresAfter) + ->shouldBeCalledOnce() + ->will(function () { + return $this; + }); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->shouldBeCalledOnce() + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->shouldBeCalledOnce(); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks1), + $this->getMockHttpFactory(), + $cache->reveal(), + $expiresAfter + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertEquals('foo', $cachedKeySet['foo']->getAlgorithm()); + } + + public function testJwtVerify() + { + $privKey1 = file_get_contents(__DIR__ . '/data/rsa1-private.pem'); + $payload = array('sub' => 'foo', 'exp' => strtotime('+10 seconds')); + $msg = JWT::encode($payload, $privKey1, 'RS256', 'jwk1'); + + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->willReturn(true); + $cacheItem->get() + ->willReturn(JWK::parseKeySet( + json_decode(file_get_contents(__DIR__ . '/data/rsa-jwkset.json'), true) + )); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->willReturn($cacheItem->reveal()); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $cache->reveal() + ); + + $result = JWT::decode($msg, $cachedKeySet); + + $this->assertEquals('foo', $result->sub); + } + + public function testRateLimit() + { + // We request the key 11 times, HTTP should only be called 10 times + $shouldBeCalledTimes = 10; + + // Instantiate the cached key set + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks1, $shouldBeCalledTimes), + $factory = $this->getMockHttpFactory($shouldBeCalledTimes), + new TestMemoryCacheItemPool(), + 10, // expires after seconds + true // enable rate limiting + ); + + $invalidKid = 'invalidkey'; + for ($i = 0; $i < 10; $i++) { + $this->assertFalse(isset($cachedKeySet[$invalidKid])); + } + // The 11th time does not call HTTP + $this->assertFalse(isset($cachedKeySet[$invalidKid])); + } + + /** + * @dataProvider provideFullIntegration + */ + public function testFullIntegration(string $jwkUri): void + { + if (!class_exists(\GuzzleHttp\Psr7\HttpFactory::class)) { + self::markTestSkipped('Guzzle 7 only'); + } + // Create cache and http objects + $cache = new TestMemoryCacheItemPool(); + $http = new \GuzzleHttp\Client(); + $factory = new \GuzzleHttp\Psr7\HttpFactory(); + + // Determine "kid" dynamically, because these constantly change + $response = $http->get($jwkUri); + $json = (string) $response->getBody(); + $keys = json_decode($json, true); + $kid = $keys['keys'][0]['kid'] ?? null; + $this->assertNotNull($kid); + + // Instantiate the cached key set + $cachedKeySet = new CachedKeySet( + $jwkUri, + $http, + $factory, + $cache + ); + + $this->assertArrayHasKey($kid, $cachedKeySet); + $key = $cachedKeySet[$kid]; + $this->assertInstanceOf(Key::class, $key); + $this->assertEquals($keys['keys'][0]['alg'], $key->getAlgorithm()); + } + + public function provideFullIntegration() + { + return [ + [$this->googleRsaUri], + // [$this->googleEcUri, 'LYyP2g'] + ]; + } + + private function getMockHttpClient($testJwks, int $timesCalled = 1) + { + $body = $this->prophesize('Psr\Http\Message\StreamInterface'); + $body->__toString() + ->shouldBeCalledTimes($timesCalled) + ->willReturn($testJwks); + + $response = $this->prophesize('Psr\Http\Message\ResponseInterface'); + $response->getBody() + ->shouldBeCalledTimes($timesCalled) + ->willReturn($body->reveal()); + + $http = $this->prophesize(ClientInterface::class); + $http->sendRequest(Argument::any()) + ->shouldBeCalledTimes($timesCalled) + ->willReturn($response->reveal()); + + return $http->reveal(); + } + + private function getMockHttpFactory(int $timesCalled = 1) + { + $request = $this->prophesize('Psr\Http\Message\RequestInterface'); + $factory = $this->prophesize(RequestFactoryInterface::class); + $factory->createRequest('get', $this->testJwksUri) + ->shouldBeCalledTimes($timesCalled) + ->willReturn($request->reveal()); + + return $factory->reveal(); + } + + private function getMockEmptyCache() + { + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->shouldBeCalledOnce() + ->willReturn(false); + $cacheItem->set(Argument::any()) + ->will(function () { + return $this; + }); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->shouldBeCalledOnce() + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->willReturn(true); + + return $cache->reveal(); + } +} + +/** + * A cache item pool + */ +final class TestMemoryCacheItemPool implements CacheItemPoolInterface +{ + private $items; + private $deferredItems; + + public function getItem($key): CacheItemInterface + { + return current($this->getItems([$key])); + } + + public function getItems(array $keys = []): iterable + { + $items = []; + + foreach ($keys as $key) { + $items[$key] = $this->hasItem($key) ? clone $this->items[$key] : new TestMemoryCacheItem($key); + } + + return $items; + } + + public function hasItem($key): bool + { + return isset($this->items[$key]) && $this->items[$key]->isHit(); + } + + public function clear(): bool + { + $this->items = []; + $this->deferredItems = []; + + return true; + } + + public function deleteItem($key): bool + { + return $this->deleteItems([$key]); + } + + public function deleteItems(array $keys): bool + { + foreach ($keys as $key) { + unset($this->items[$key]); + } + + return true; + } + + public function save(CacheItemInterface $item): bool + { + $this->items[$item->getKey()] = $item; + + return true; + } + + public function saveDeferred(CacheItemInterface $item): bool + { + $this->deferredItems[$item->getKey()] = $item; + + return true; + } + + public function commit(): bool + { + foreach ($this->deferredItems as $item) { + $this->save($item); + } + + $this->deferredItems = []; + + return true; + } +} + +/** + * A cache item. + */ +final class TestMemoryCacheItem implements CacheItemInterface +{ + private $key; + private $value; + private $expiration; + private $isHit = false; + + public function __construct(string $key) + { + $this->key = $key; + } + + public function getKey(): string + { + return $this->key; + } + + public function get() + { + return $this->isHit() ? $this->value : null; + } + + public function isHit(): bool + { + if (!$this->isHit) { + return false; + } + + if ($this->expiration === null) { + return true; + } + + return $this->currentTime()->getTimestamp() < $this->expiration->getTimestamp(); + } + + public function set($value) + { + $this->isHit = true; + $this->value = $value; + + return $this; + } + + public function expiresAt($expiration) + { + $this->expiration = $expiration; + return $this; + } + + public function expiresAfter($time) + { + $this->expiration = $this->currentTime()->add(new \DateInterval("PT{$time}S")); + return $this; + } + + protected function currentTime() + { + return new \DateTime('now', new \DateTimeZone('UTC')); + } +} diff --git a/tests/bootstrap.php b/tests/bootstrap.php deleted file mode 100644 index 385b6706..00000000 --- a/tests/bootstrap.php +++ /dev/null @@ -1,14 +0,0 @@ -