Skip to content

Commit 2ad1ff0

Browse files
Enhance validation error handling in ValidationEndpointFilterFactory using IProblemDetailsService (#62066)
Co-authored-by: Safia Abdalla <safia@safia.rocks>
1 parent 49dc7ec commit 2ad1ff0

File tree

2 files changed

+206
-2
lines changed

2 files changed

+206
-2
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#pragma warning disable ASP0029 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
2+
3+
// Licensed to the .NET Foundation under one or more agreements.
4+
// The .NET Foundation licenses this file to you under the MIT license.
5+
6+
using System.ComponentModel.DataAnnotations;
7+
using System.Net.Mime;
8+
using System.Text.Json;
9+
using Microsoft.AspNetCore.Builder;
10+
using Microsoft.AspNetCore.InternalTesting;
11+
using Microsoft.AspNetCore.Mvc;
12+
using Microsoft.AspNetCore.Routing;
13+
using Microsoft.Extensions.DependencyInjection;
14+
15+
namespace Microsoft.AspNetCore.Http.Extensions.Tests;
16+
17+
public class ValidationEndpointFilterFactoryTests : LoggedTest
18+
{
19+
[Fact]
20+
public async Task GetHttpValidationProblemDetailsWhenProblemDetailsServiceNotRegistered()
21+
{
22+
var services = new ServiceCollection();
23+
services.AddValidation();
24+
var serviceProvider = services.BuildServiceProvider();
25+
26+
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(serviceProvider));
27+
28+
// Act - Create one endpoint with validation
29+
builder.MapGet("validation-test", ([Range(5, 10)] int param) => "Validation enabled here.");
30+
31+
// Build the endpoints
32+
var dataSource = Assert.Single(builder.DataSources);
33+
var endpoints = dataSource.Endpoints;
34+
35+
// Get filter factories from endpoint
36+
var endpoint = endpoints[0];
37+
38+
var context = new DefaultHttpContext
39+
{
40+
RequestServices = serviceProvider
41+
};
42+
43+
context.Request.Method = "GET";
44+
context.Request.QueryString = new QueryString("?param=15");
45+
using var ms = new MemoryStream();
46+
context.Response.Body = ms;
47+
48+
await endpoint.RequestDelegate(context);
49+
50+
// Assert
51+
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
52+
Assert.StartsWith(MediaTypeNames.Application.Json, context.Response.ContentType, StringComparison.OrdinalIgnoreCase);
53+
54+
ms.Seek(0, SeekOrigin.Begin);
55+
var problemDetails = await JsonSerializer.DeserializeAsync<ProblemDetails>(ms, JsonSerializerOptions.Web);
56+
57+
Assert.Equal("One or more validation errors occurred.", problemDetails.Title);
58+
59+
// Check that ProblemDetails contains the errors object with 1 validation error
60+
Assert.True(problemDetails.Extensions.TryGetValue("errors", out var errorsObj));
61+
var errors = Assert.IsType<JsonElement>(errorsObj);
62+
Assert.True(errors.EnumerateObject().Count() == 1);
63+
}
64+
65+
[Fact]
66+
public async Task UseProblemDetailsServiceWhenAddedInServiceCollection()
67+
{
68+
var services = new ServiceCollection();
69+
services.AddValidation();
70+
services.AddProblemDetails();
71+
var serviceProvider = services.BuildServiceProvider();
72+
73+
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(serviceProvider));
74+
75+
// Act - Create one endpoint with validation
76+
builder.MapGet("validation-test", ([Range(5, 10)] int param) => "Validation enabled here.");
77+
78+
// Build the endpoints
79+
var dataSource = Assert.Single(builder.DataSources);
80+
var endpoints = dataSource.Endpoints;
81+
82+
// Get filter factories from endpoint
83+
var endpoint = endpoints[0];
84+
85+
var context = new DefaultHttpContext
86+
{
87+
RequestServices = serviceProvider
88+
};
89+
90+
context.Request.Method = "GET";
91+
context.Request.QueryString = new QueryString("?param=15");
92+
using var ms = new MemoryStream();
93+
context.Response.Body = ms;
94+
95+
await endpoint.RequestDelegate(context);
96+
97+
// Assert
98+
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
99+
Assert.StartsWith(MediaTypeNames.Application.ProblemJson, context.Response.ContentType, StringComparison.OrdinalIgnoreCase);
100+
101+
ms.Seek(0, SeekOrigin.Begin);
102+
var problemDetails = await JsonSerializer.DeserializeAsync<ProblemDetails>(ms, JsonSerializerOptions.Web);
103+
104+
// Check if the response is an actual ProblemDetails object
105+
Assert.Equal("https://tools.ietf.org/html/rfc9110#section-15.5.1", problemDetails.Type);
106+
Assert.Equal("One or more validation errors occurred.", problemDetails.Title);
107+
Assert.Equal(StatusCodes.Status400BadRequest, problemDetails.Status);
108+
109+
// Check that ProblemDetails contains the errors object with 1 validation error
110+
Assert.True(problemDetails.Extensions.TryGetValue("errors", out var errorsObj));
111+
var errors = Assert.IsType<JsonElement>(errorsObj);
112+
Assert.True(errors.EnumerateObject().Count() == 1);
113+
}
114+
115+
[Fact]
116+
public async Task UseProblemDetailsServiceWithCallbackWhenAddedInServiceCollection()
117+
{
118+
var services = new ServiceCollection();
119+
services.AddValidation();
120+
121+
services.AddProblemDetails(options =>
122+
{
123+
options.CustomizeProblemDetails = context =>
124+
{
125+
context.ProblemDetails.Extensions.Add("timestamp", DateTimeOffset.Now);
126+
};
127+
});
128+
129+
var serviceProvider = services.BuildServiceProvider();
130+
131+
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(serviceProvider));
132+
133+
// Act - Create one endpoint with validation
134+
builder.MapGet("validation-test", ([Range(5, 10)] int param) => "Validation enabled here.");
135+
136+
// Build the endpoints
137+
var dataSource = Assert.Single(builder.DataSources);
138+
var endpoints = dataSource.Endpoints;
139+
140+
// Get filter factories from endpoint
141+
var endpoint = endpoints[0];
142+
143+
var context = new DefaultHttpContext
144+
{
145+
RequestServices = serviceProvider
146+
};
147+
148+
context.Request.Method = "GET";
149+
context.Request.QueryString = new QueryString("?param=15");
150+
using var ms = new MemoryStream();
151+
context.Response.Body = ms;
152+
153+
await endpoint.RequestDelegate(context);
154+
155+
// Assert
156+
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
157+
Assert.StartsWith(MediaTypeNames.Application.ProblemJson, context.Response.ContentType, StringComparison.OrdinalIgnoreCase);
158+
159+
ms.Seek(0, SeekOrigin.Begin);
160+
var problemDetails = await JsonSerializer.DeserializeAsync<ProblemDetails>(ms, JsonSerializerOptions.Web);
161+
162+
// Check if the response is an actual ProblemDetails object
163+
Assert.Equal("https://tools.ietf.org/html/rfc9110#section-15.5.1", problemDetails.Type);
164+
Assert.Equal("One or more validation errors occurred.", problemDetails.Title);
165+
Assert.Equal(StatusCodes.Status400BadRequest, problemDetails.Status);
166+
167+
// Check that ProblemDetails contains the errors object with 1 validation error
168+
Assert.True(problemDetails.Extensions.TryGetValue("errors", out var errorsObj));
169+
var errors = Assert.IsType<JsonElement>(errorsObj);
170+
Assert.True(errors.EnumerateObject().Count() == 1);
171+
172+
// Check that ProblemDetails customizations are applied in the response
173+
Assert.True(problemDetails.Extensions.ContainsKey("timestamp"));
174+
}
175+
176+
private class DefaultEndpointRouteBuilder(IApplicationBuilder applicationBuilder) : IEndpointRouteBuilder
177+
{
178+
private IApplicationBuilder ApplicationBuilder { get; } = applicationBuilder ?? throw new ArgumentNullException(nameof(applicationBuilder));
179+
public IApplicationBuilder CreateApplicationBuilder() => ApplicationBuilder.New();
180+
public ICollection<EndpointDataSource> DataSources { get; } = [];
181+
public IServiceProvider ServiceProvider => ApplicationBuilder.ApplicationServices;
182+
}
183+
}

src/Http/Routing/src/ValidationEndpointFilterFactory.cs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
using System.ComponentModel.DataAnnotations;
77
using System.Linq;
8+
using System.Net.Mime;
89
using System.Reflection;
10+
using Microsoft.AspNetCore.Http.HttpResults;
911
using Microsoft.AspNetCore.Http.Metadata;
1012
using Microsoft.Extensions.DependencyInjection;
1113
using Microsoft.Extensions.Options;
@@ -92,8 +94,27 @@ public static EndpointFilterDelegate Create(EndpointFilterFactoryContext context
9294
if (validateContext is { ValidationErrors.Count: > 0 })
9395
{
9496
context.HttpContext.Response.StatusCode = StatusCodes.Status400BadRequest;
95-
context.HttpContext.Response.ContentType = "application/problem+json";
96-
return await ValueTask.FromResult(new HttpValidationProblemDetails(validateContext.ValidationErrors));
97+
98+
var problemDetails = new HttpValidationProblemDetails(validateContext.ValidationErrors);
99+
100+
var problemDetailsService = context.HttpContext.RequestServices.GetService<IProblemDetailsService>();
101+
if (problemDetailsService is not null)
102+
{
103+
if (await problemDetailsService.TryWriteAsync(new()
104+
{
105+
HttpContext = context.HttpContext,
106+
ProblemDetails = problemDetails
107+
}))
108+
{
109+
// We need to prevent further execution, because the actual
110+
// ProblemDetails response has already been written by ProblemDetailsService.
111+
return EmptyHttpResult.Instance;
112+
}
113+
}
114+
115+
// Fallback to the default implementation.
116+
context.HttpContext.Response.ContentType = MediaTypeNames.Application.ProblemJson;
117+
return problemDetails;
97118
}
98119

99120
return await next(context);

0 commit comments

Comments
 (0)