diff --git a/src/System.Web.Http.Cors/CorsHttpConfigurationExtensions.cs b/src/System.Web.Http.Cors/CorsHttpConfigurationExtensions.cs index 3bf6631b5..42190f97b 100644 --- a/src/System.Web.Http.Cors/CorsHttpConfigurationExtensions.cs +++ b/src/System.Web.Http.Cors/CorsHttpConfigurationExtensions.cs @@ -26,7 +26,17 @@ public static class CorsHttpConfigurationExtensions /// The . public static void EnableCors(this HttpConfiguration httpConfiguration) { - EnableCors(httpConfiguration, null); + EnableCors(httpConfiguration, null, false); + } + + /// + /// Enables the support for CORS. + /// + /// The . + /// Indicates whether upstream exceptions should be rethrown + public static void EnableCors(this HttpConfiguration httpConfiguration, bool rethrowExceptions) + { + EnableCors(httpConfiguration, null, rethrowExceptions); } /// @@ -34,8 +44,20 @@ public static void EnableCors(this HttpConfiguration httpConfiguration) /// /// The . /// The default . - /// httpConfiguration public static void EnableCors(this HttpConfiguration httpConfiguration, ICorsPolicyProvider defaultPolicyProvider) + { + EnableCors(httpConfiguration, defaultPolicyProvider, false); + } + + /// + /// Enables the support for CORS. + /// + /// The . + /// The default . + /// Indicates whether upstream exceptions should be rethrown + /// httpConfiguration + public static void EnableCors(this HttpConfiguration httpConfiguration, ICorsPolicyProvider defaultPolicyProvider, + bool rethrowExceptions) { if (httpConfiguration == null) { @@ -49,11 +71,11 @@ public static void EnableCors(this HttpConfiguration httpConfiguration, ICorsPol httpConfiguration.SetCorsPolicyProviderFactory(policyProviderFactory); } - AddCorsMessageHandler(httpConfiguration); + AddCorsMessageHandler(httpConfiguration, rethrowExceptions); } [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Caller owns the disposable object")] - private static void AddCorsMessageHandler(this HttpConfiguration httpConfiguration) + private static void AddCorsMessageHandler(this HttpConfiguration httpConfiguration, bool rethrowExceptions) { object corsEnabled; if (!httpConfiguration.Properties.TryGetValue(CorsEnabledKey, out corsEnabled)) @@ -64,7 +86,7 @@ private static void AddCorsMessageHandler(this HttpConfiguration httpConfigurati if (!config.Properties.TryGetValue(CorsEnabledKey, out corsEnabled)) { // Execute this in the Initializer to ensure that the CorsMessageHandler is added last. - config.MessageHandlers.Add(new CorsMessageHandler(config)); + config.MessageHandlers.Add(new CorsMessageHandler(config, rethrowExceptions)); ITraceWriter traceWriter = config.Services.GetTraceWriter(); diff --git a/src/System.Web.Http.Cors/CorsMessageHandler.cs b/src/System.Web.Http.Cors/CorsMessageHandler.cs index 5d683ca26..5308ee340 100644 --- a/src/System.Web.Http.Cors/CorsMessageHandler.cs +++ b/src/System.Web.Http.Cors/CorsMessageHandler.cs @@ -18,13 +18,24 @@ namespace System.Web.Http.Cors public class CorsMessageHandler : DelegatingHandler { private HttpConfiguration _httpConfiguration; + private bool _rethrowExceptions; /// /// Initializes a new instance of the class. /// /// The . /// httpConfiguration - public CorsMessageHandler(HttpConfiguration httpConfiguration) + public CorsMessageHandler(HttpConfiguration httpConfiguration) : this(httpConfiguration, false) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The . + /// Indicates whether upstream exceptions should be rethrown + /// httpConfiguration + public CorsMessageHandler(HttpConfiguration httpConfiguration, bool rethrowExceptions) { if (httpConfiguration == null) { @@ -32,6 +43,7 @@ public CorsMessageHandler(HttpConfiguration httpConfiguration) } _httpConfiguration = httpConfiguration; + _rethrowExceptions = rethrowExceptions; } /// @@ -60,6 +72,11 @@ protected async override Task SendAsync(HttpRequestMessage } catch (Exception exception) { + if (_rethrowExceptions) + { + throw; + } + return HandleException(request, exception); } } diff --git a/test/System.Web.Http.Cors.Test/Controllers/ThrowingController.cs b/test/System.Web.Http.Cors.Test/Controllers/ThrowingController.cs new file mode 100644 index 000000000..6879816c2 --- /dev/null +++ b/test/System.Web.Http.Cors.Test/Controllers/ThrowingController.cs @@ -0,0 +1,11 @@ +namespace System.Web.Http.Cors +{ + [EnableCors("*", "*", "*")] + public class ThrowingController : ApiController + { + public string Get() + { + throw new Exception(); + } + } +} diff --git a/test/System.Web.Http.Cors.Test/CorsMessageHandlerTest.cs b/test/System.Web.Http.Cors.Test/CorsMessageHandlerTest.cs index e3fe7cf49..c6ecfdf25 100644 --- a/test/System.Web.Http.Cors.Test/CorsMessageHandlerTest.cs +++ b/test/System.Web.Http.Cors.Test/CorsMessageHandlerTest.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using System.Web.Cors; +using System.Web.Http.ExceptionHandling; using System.Web.Http.Hosting; using Microsoft.TestCommon; @@ -180,6 +181,40 @@ public async Task SendAsync_HandlesExceptions_ThrownDuringPreflight() Assert.Equal(HttpStatusCode.MethodNotAllowed, response.StatusCode); } + [Fact] + public async Task SendAsync_Preflight_RethrowsExceptions_WhenRethrowFlagIsTrue() + { + HttpConfiguration config = new HttpConfiguration(); + config.Routes.MapHttpRoute("default", "{controller}"); + HttpServer server = new HttpServer(config); + CorsMessageHandler corsHandler = new CorsMessageHandler(config, true); + corsHandler.InnerHandler = server; + HttpMessageInvoker invoker = new HttpMessageInvoker(corsHandler); + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Options, "http://localhost/sample"); + request.SetConfiguration(config); + request.Headers.Add(CorsConstants.Origin, "http://localhost"); + request.Headers.Add(CorsConstants.AccessControlRequestMethod, "RandomMethod"); + + await Assert.ThrowsAsync(() => invoker.SendAsync(request, CancellationToken.None)); + } + + [Fact] + public async Task SendAsync_RethrowsExceptions_WhenRethrowFlagIsTrue() + { + HttpConfiguration config = new HttpConfiguration(); + config.Routes.MapHttpRoute("default", "{controller}"); + config.Services.Replace(typeof(IExceptionHandler), new PassthroughExceptionHandler()); + HttpServer server = new HttpServer(config); + CorsMessageHandler corsHandler = new CorsMessageHandler(config, true); + corsHandler.InnerHandler = server; + HttpMessageInvoker invoker = new HttpMessageInvoker(corsHandler); + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, "http://localhost/throwing"); + request.SetConfiguration(config); + request.Headers.Add(CorsConstants.Origin, "http://localhost"); + + await Assert.ThrowsAsync(() => invoker.SendAsync(request, CancellationToken.None)); + } + [Fact] public Task HandleCorsRequestAsync_NullConfig_Throws() { @@ -238,5 +273,13 @@ public Task HandleCorsPreflightRequestAsync_NullContext_Throws() () => corsHandler.HandleCorsPreflightRequestAsync(new HttpRequestMessage(), null, CancellationToken.None), "corsRequestContext"); } + + private class PassthroughExceptionHandler : IExceptionHandler + { + public Task HandleAsync(ExceptionHandlerContext context, CancellationToken cancellationToken) + { + throw context.Exception; + } + } } } \ No newline at end of file diff --git a/test/System.Web.Http.Cors.Test/System.Web.Http.Cors.Test.csproj b/test/System.Web.Http.Cors.Test/System.Web.Http.Cors.Test.csproj index ac5963fcd..9454da007 100644 --- a/test/System.Web.Http.Cors.Test/System.Web.Http.Cors.Test.csproj +++ b/test/System.Web.Http.Cors.Test/System.Web.Http.Cors.Test.csproj @@ -82,6 +82,7 @@ +