diff --git a/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/customizers/ServerBaseUrlCustomizer.java b/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/customizers/ServerBaseUrlCustomizer.java index b209ba72e..47c5d3891 100644 --- a/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/customizers/ServerBaseUrlCustomizer.java +++ b/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/customizers/ServerBaseUrlCustomizer.java @@ -24,6 +24,8 @@ package org.springdoc.core.customizers; +import org.springframework.http.HttpRequest; + /** * The interface Server Base URL customiser. * @author skylar -stark @@ -35,7 +37,8 @@ public interface ServerBaseUrlCustomizer { * Customise. * * @param serverBaseUrl the serverBaseUrl. + * @param request the request. * @return the customised serverBaseUrl */ - String customize(String serverBaseUrl); + String customize(String serverBaseUrl, HttpRequest request); } diff --git a/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/service/OpenAPIService.java b/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/service/OpenAPIService.java index 4a3fe8ad8..bb69b1a6c 100644 --- a/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/service/OpenAPIService.java +++ b/springdoc-openapi-starter-common/src/main/java/org/springdoc/core/service/OpenAPIService.java @@ -81,6 +81,7 @@ import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.type.filter.AnnotationTypeFilter; +import org.springframework.http.HttpRequest; import org.springframework.stereotype.Controller; import org.springframework.util.CollectionUtils; import org.springframework.web.bind.annotation.ControllerAdvice; @@ -490,12 +491,12 @@ public Schema resolveProperties(Schema schema, Locale locale) { * * @param serverBaseUrl the server base url */ - public void setServerBaseUrl(String serverBaseUrl) { + public void setServerBaseUrl(String serverBaseUrl, HttpRequest httpRequest) { String customServerBaseUrl = serverBaseUrl; if (serverBaseUrlCustomizers.isPresent()) { for (ServerBaseUrlCustomizer customizer : serverBaseUrlCustomizers.get()) { - customServerBaseUrl = customizer.customize(customServerBaseUrl); + customServerBaseUrl = customizer.customize(customServerBaseUrl, httpRequest); } } diff --git a/springdoc-openapi-starter-common/src/test/java/org/springdoc/api/AbstractOpenApiResourceTest.java b/springdoc-openapi-starter-common/src/test/java/org/springdoc/api/AbstractOpenApiResourceTest.java index fdf75d7b0..f670da701 100644 --- a/springdoc-openapi-starter-common/src/test/java/org/springdoc/api/AbstractOpenApiResourceTest.java +++ b/springdoc-openapi-starter-common/src/test/java/org/springdoc/api/AbstractOpenApiResourceTest.java @@ -62,6 +62,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.mock.http.client.MockClientHttpRequest; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -190,7 +191,7 @@ void preLoadingModeShouldNotOverwriteServers() throws InterruptedException { doCallRealMethod().when(openAPIService).updateServers(any()); when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod(); doAnswer(new CallsRealMethods()).when(openAPIService).setServersPresent(true); - doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any()); + doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any()); doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any()); String customUrl = "https://custom.com"; @@ -212,7 +213,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi Thread.sleep(1_000); // emulate generating base url - openAPIService.setServerBaseUrl(generatedUrl); + openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest()); openAPIService.updateServers(openAPI); Locale locale = Locale.US; OpenAPI after = resource.getOpenApi(locale); @@ -224,7 +225,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi void serverBaseUrlCustomisersTest() throws InterruptedException { doCallRealMethod().when(openAPIService).updateServers(any()); when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod(); - doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any()); + doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any()); doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any()); SpringDocConfigProperties properties = new SpringDocConfigProperties(); @@ -247,37 +248,37 @@ springDocProviders, new SpringDocCustomizers(Optional.empty(),Optional.empty(),O // Test that setting generated URL works fine with no customizers present String generatedUrl = "https://generated-url.com/context-path"; - openAPIService.setServerBaseUrl(generatedUrl); + openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest()); openAPIService.updateServers(openAPI); OpenAPI after = resource.getOpenApi(locale); assertThat(after.getServers().get(0).getUrl(), is(generatedUrl)); // Test that adding a serverBaseUrlCustomizer has the desired effect - ServerBaseUrlCustomizer serverBaseUrlCustomizer = serverBaseUrl -> serverBaseUrl.replace("/context-path", ""); + ServerBaseUrlCustomizer serverBaseUrlCustomizer = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path", ""); List serverBaseUrlCustomizerList = new ArrayList<>(); serverBaseUrlCustomizerList.add(serverBaseUrlCustomizer); ReflectionTestUtils.setField(openAPIService, "serverBaseUrlCustomizers", Optional.of(serverBaseUrlCustomizerList)); - openAPIService.setServerBaseUrl(generatedUrl); + openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest()); openAPIService.updateServers(openAPI); after = resource.getOpenApi(locale); assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com")); // Test that serverBaseUrlCustomisers are performed in order generatedUrl = "https://generated-url.com/context-path/second-path"; - ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = serverBaseUrl -> serverBaseUrl.replace("/context-path/second-path", ""); + ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path/second-path", ""); serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser2); - openAPIService.setServerBaseUrl(generatedUrl); + openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest()); openAPIService.updateServers(openAPI); after = resource.getOpenApi(locale); assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com/second-path")); // Test that all serverBaseUrlCustomisers in the List are performed - ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = serverBaseUrl -> serverBaseUrl.replace("/second-path", ""); + ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = (serverBaseUrl, request) -> serverBaseUrl.replace("/second-path", ""); serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser3); - openAPIService.setServerBaseUrl(generatedUrl); + openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest()); openAPIService.updateServers(openAPI); after = resource.getOpenApi(locale); assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com")); diff --git a/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiActuatorResource.java b/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiActuatorResource.java index 75ad34354..3bee371a5 100644 --- a/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiActuatorResource.java +++ b/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiActuatorResource.java @@ -131,7 +131,7 @@ public Mono openapiYaml(ServerHttpRequest serverHttpRequest, Locale loca protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) { super.initOpenAPIBuilder(locale); URI uri = getActuatorURI(serverHttpRequest.getURI().getScheme(), serverHttpRequest.getURI().getHost()); - openAPIService.setServerBaseUrl(uri.toString()); + openAPIService.setServerBaseUrl(uri.toString(), serverHttpRequest); } @Override diff --git a/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiResource.java b/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiResource.java index acbefce90..0b29503d1 100644 --- a/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiResource.java +++ b/springdoc-openapi-starter-webflux-api/src/main/java/org/springdoc/webflux/api/OpenApiResource.java @@ -229,7 +229,7 @@ protected void getWebFluxRouterFunctionPaths(Locale locale, OpenAPI openAPI) { protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) { initOpenAPIBuilder(locale); String serverUrl = getServerUrl(serverHttpRequest, apiDocsUrl); - openAPIService.setServerBaseUrl(serverUrl); + openAPIService.setServerBaseUrl(serverUrl, serverHttpRequest); } /** diff --git a/springdoc-openapi-starter-webmvc-api/src/main/java/org/springdoc/webmvc/api/OpenApiResource.java b/springdoc-openapi-starter-webmvc-api/src/main/java/org/springdoc/webmvc/api/OpenApiResource.java index 4b5fd61a7..288496a80 100644 --- a/springdoc-openapi-starter-webmvc-api/src/main/java/org/springdoc/webmvc/api/OpenApiResource.java +++ b/springdoc-openapi-starter-webmvc-api/src/main/java/org/springdoc/webmvc/api/OpenApiResource.java @@ -55,6 +55,7 @@ import org.springframework.aop.support.AopUtils; import org.springframework.beans.factory.ObjectFactory; +import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.web.bind.annotation.RequestMethod; @@ -244,7 +245,8 @@ private Comparator byReversedRequestMappingInfos() { protected void calculateServerUrl(HttpServletRequest request, String apiDocsUrl, Locale locale) { super.initOpenAPIBuilder(locale); String calculatedUrl = getServerUrl(request, apiDocsUrl); - openAPIService.setServerBaseUrl(calculatedUrl); + ServletServerHttpRequest serverRequest = request != null ? new ServletServerHttpRequest(request) : null; + openAPIService.setServerBaseUrl(calculatedUrl, serverRequest); } /**