diff --git a/Directory.Build.targets b/Directory.Build.targets index 181a4b228..ad9b19cce 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -43,5 +43,10 @@ + + + + + diff --git a/LSP.sln b/LSP.sln index 0543be637..3cc10f0f9 100644 --- a/LSP.sln +++ b/LSP.sln @@ -44,11 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Protocol", "src\Protocol\Pr EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Server", "src\Server\Server.csproj", "{E540868F-438E-4F7F-BBB7-010D6CB18A57}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{7E4A7675-45F3-4636-88AC-2C158C1A140D}" - ProjectSection(SolutionItems) = preProject - cake.config = cake.config - EndProjectSection -EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Dap.Protocol", "src\Dap.Protocol\Dap.Protocol.csproj", "{F2C9D555-118E-442B-A953-9A7B58A53F33}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Dap.Server", "src\Dap.Server\Dap.Server.csproj", "{E1A9123B-A236-4240-8C82-A61BD85C3BF4}" @@ -73,6 +68,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Testing", "src\Testing\Test EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "JsonRpc.Testing", "src\JsonRpc.Testing\JsonRpc.Testing.csproj", "{202BA1AB-25DA-44ED-B962-FD82FCC74543}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "JsonRpc.Generators", "src\JsonRpc.Generators\JsonRpc.Generators.csproj", "{DE259174-73DC-4532-B641-AD218971EE29}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Generation.Tests", "test\Generation.Tests\Generation.Tests.csproj", "{671FFF78-BDD2-4389-B29C-BFD183DA9120}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -293,6 +292,30 @@ Global {202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x64.Build.0 = Release|Any CPU {202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x86.ActiveCfg = Release|Any CPU {202BA1AB-25DA-44ED-B962-FD82FCC74543}.Release|x86.Build.0 = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x64.ActiveCfg = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x64.Build.0 = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x86.ActiveCfg = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Debug|x86.Build.0 = Debug|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|Any CPU.Build.0 = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|x64.ActiveCfg = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|x64.Build.0 = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|x86.ActiveCfg = Release|Any CPU + {DE259174-73DC-4532-B641-AD218971EE29}.Release|x86.Build.0 = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|Any CPU.Build.0 = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x64.ActiveCfg = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x64.Build.0 = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x86.ActiveCfg = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Debug|x86.Build.0 = Debug|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|Any CPU.ActiveCfg = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|Any CPU.Build.0 = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x64.ActiveCfg = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x64.Build.0 = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x86.ActiveCfg = Release|Any CPU + {671FFF78-BDD2-4389-B29C-BFD183DA9120}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -316,6 +339,8 @@ Global {91919C54-3638-4A3C-963A-327D78368EE3} = {D764E024-3D3F-4112-B932-2DB722A1BACC} {A1EC39EE-AA1F-4EC9-9939-28C3532585C9} = {D764E024-3D3F-4112-B932-2DB722A1BACC} {202BA1AB-25DA-44ED-B962-FD82FCC74543} = {D764E024-3D3F-4112-B932-2DB722A1BACC} + {DE259174-73DC-4532-B641-AD218971EE29} = {D764E024-3D3F-4112-B932-2DB722A1BACC} + {671FFF78-BDD2-4389-B29C-BFD183DA9120} = {2F323ED5-EBF8-45E1-B9D3-C014561B3DDA} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {D38DD0EC-D095-4BCD-B8AF-2D788AF3B9AE} diff --git a/src/Dap.Protocol/Dap.Protocol.csproj b/src/Dap.Protocol/Dap.Protocol.csproj index 9cb4da29c..d25616d5f 100644 --- a/src/Dap.Protocol/Dap.Protocol.csproj +++ b/src/Dap.Protocol/Dap.Protocol.csproj @@ -10,6 +10,8 @@ + + <_Parameter1>OmniSharp.Extensions.LanguageServer, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f diff --git a/src/JsonRpc.Generation/JsonRpc.Generation.csproj b/src/JsonRpc.Generation/JsonRpc.Generation.csproj new file mode 100644 index 000000000..c053a21df --- /dev/null +++ b/src/JsonRpc.Generation/JsonRpc.Generation.csproj @@ -0,0 +1,15 @@ + + + + + netstandard1.0 + OmniSharp.Extensions.JsonRpc.Generation + OmniSharp.Extensions.JsonRpc.Generation + + + + + + + + diff --git a/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs b/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs new file mode 100644 index 000000000..f186184fd --- /dev/null +++ b/src/JsonRpc.Generators/GenerateHandlerMethodsGenerator.cs @@ -0,0 +1,457 @@ +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using CodeGeneration.Roslyn; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static OmniSharp.Extensions.JsonRpc.Generators.Helpers; + +namespace OmniSharp.Extensions.JsonRpc.Generators +{ + public class GenerateHandlerMethodsGenerator : IRichCodeGenerator + { + private readonly AttributeData _attributeData; + + public GenerateHandlerMethodsGenerator(AttributeData attributeData) + { + _attributeData = attributeData; + } + + public Task GenerateRichAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken) + { + if (!(context.ProcessingNode is InterfaceDeclarationSyntax handlerInterface)) + { + return Task.FromResult(new RichGenerationResult()); + } + + var methods = new List(); + var additionalUsings = new HashSet() { + "System", + "System.Collections.Generic", + "System.Threading", + "System.Threading.Tasks", + "MediatR", + "Microsoft.Extensions.DependencyInjection", + }; + var symbol = context.SemanticModel.GetDeclaredSymbol(handlerInterface); + + var className = GetExtensionClassName(symbol); + + foreach (var registry in GetRegistries(_attributeData, handlerInterface, symbol, context, progress, additionalUsings)) + { + if (IsNotification(symbol)) + { + var requestType = GetRequestType(symbol); + methods.AddRange(HandleNotifications(handlerInterface, symbol, requestType, registry, additionalUsings)); + } + else if (IsRequest(symbol)) + { + var requestType = GetRequestType(symbol); + var responseType = GetResponseType(symbol); + methods.AddRange(HandleRequest(handlerInterface, symbol, requestType, responseType, registry, additionalUsings)); + } + } + + var existingUsings = context.CompilationUnitUsings + .Join(additionalUsings, z => z.Name.ToFullString(), z => z, (a, b) => b) + ; + + var newUsings = additionalUsings + .Except(existingUsings) + .Select(z => UsingDirective(IdentifierName(z))) + ; + + var attributes = List(new[] { + AttributeList(SeparatedList(new[] { + Attribute(ParseName("System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")), + Attribute(ParseName("System.Runtime.CompilerServices.CompilerGeneratedAttribute")), + })) + }); + + return Task.FromResult(new RichGenerationResult() { + Usings = List(newUsings), + Members = List(new[] { + NamespaceDeclaration(ParseName(symbol.ContainingNamespace.ToDisplayString())) + + .WithMembers(List(new MemberDeclarationSyntax[] { + ClassDeclaration(className) + .WithAttributeLists(attributes) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.PartialKeyword) + )) + .WithMembers(List(methods)) + .NormalizeWhitespace() + })) + }) + }); + } + + IEnumerable HandleNotifications( + InterfaceDeclarationSyntax handlerInterface, + INamedTypeSymbol interfaceType, + INamedTypeSymbol requestType, + NameSyntax registryType, + HashSet additionalUsings) + { + var methodName = GetOnMethodName(interfaceType, _attributeData); + + var parameters = ParameterList(SeparatedList(new[] { + Parameter(Identifier("registry")) + .WithType(registryType) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))) + })); + + var capability = GetCapability(interfaceType); + var registrationOptions = GetRegistrationOptions(interfaceType); + if (capability != null) additionalUsings.Add(capability.ContainingNamespace.ToDisplayString()); + if (registrationOptions != null) additionalUsings.Add(registrationOptions.ContainingNamespace.ToDisplayString()); + + if (registrationOptions == null) + { + var method = MethodDeclaration(registryType, methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithExpressionBody(GetNotificationHandlerExpression(GetMethodName(handlerInterface))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + MemberDeclarationSyntax MakeAction(TypeSyntax syntax) + { + return method + .WithParameterList(parameters.AddParameters(Parameter(Identifier("handler")) + .WithType(syntax))) + .NormalizeWhitespace(); + } + + yield return MakeAction(CreateAction(false, requestType)); + yield return MakeAction(CreateAsyncAction(false, requestType)); + yield return MakeAction(CreateAction(true, requestType)); + yield return MakeAction(CreateAsyncAction(true, requestType)); + if (capability != null) + { + yield return MakeAction(CreateAction(requestType, capability)); + yield return MakeAction(CreateAsyncAction(requestType, capability)); + yield return MakeAction(CreateAction(requestType, capability)); + yield return MakeAction(CreateAsyncAction(requestType, capability)); + } + } + else + { + var method = MethodDeclaration(registryType, methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithBody(GetNotificationRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions)); + + var registrationParameter = Parameter(Identifier("registrationOptions")) + .WithType(IdentifierName(registrationOptions.Name)); + + MemberDeclarationSyntax MakeAction(TypeSyntax syntax) + { + return method + .WithParameterList(parameters.WithParameters(SeparatedList(parameters.Parameters.Concat( + new[] {Parameter(Identifier("handler")).WithType(syntax), registrationParameter})))) + .NormalizeWhitespace(); + } + + yield return MakeAction(CreateAction(false, requestType)); + yield return MakeAction(CreateAsyncAction(false, requestType)); + yield return MakeAction(CreateAction(true, requestType)); + yield return MakeAction(CreateAsyncAction(true, requestType)); + if (capability != null) + { + method = method.WithBody( + GetNotificationRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions, capability)); + yield return MakeAction(CreateAction(requestType, capability)); + yield return MakeAction(CreateAsyncAction(requestType, capability)); + } + } + } + + IEnumerable HandleRequest( + InterfaceDeclarationSyntax handlerInterface, + INamedTypeSymbol interfaceType, + INamedTypeSymbol requestType, + INamedTypeSymbol responseType, + NameSyntax registryType, + HashSet additionalUsings) + { + var methodName = GetOnMethodName(interfaceType, _attributeData); + + var capability = GetCapability(interfaceType); + var registrationOptions = GetRegistrationOptions(interfaceType); + var partialItems = GetPartialItems(requestType); + var partialItem = GetPartialItem(requestType); + if (capability != null) additionalUsings.Add(capability.ContainingNamespace.ToDisplayString()); + if (registrationOptions != null) additionalUsings.Add(registrationOptions.ContainingNamespace.ToDisplayString()); + if (partialItems != null) additionalUsings.Add(partialItems.ContainingNamespace.ToDisplayString()); + if (partialItem != null) additionalUsings.Add(partialItem.ContainingNamespace.ToDisplayString()); + + var parameters = ParameterList(SeparatedList(new[] { + Parameter(Identifier("registry")) + .WithType(registryType) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))) + })); + + + if (registrationOptions == null) + { + var method = MethodDeclaration(registryType, methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithExpressionBody(GetRequestHandlerExpression(GetMethodName(handlerInterface))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + MemberDeclarationSyntax MakeAction(TypeSyntax syntax) + { + return method + .WithParameterList(parameters.AddParameters(Parameter(Identifier("handler")) + .WithType(syntax))) + .NormalizeWhitespace(); + } + + yield return MakeAction(CreateAsyncFunc(responseType, false, requestType)); + yield return MakeAction(CreateAsyncFunc(responseType, true, requestType)); + if (partialItems != null) + { + var partialTypeSyntax = ResolveTypeName(partialItems); + var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax}))); + + method = method.WithExpressionBody(GetPartialResultsHandlerExpression(GetMethodName(handlerInterface), requestType, partialTypeSyntax, responseType)); + + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, true)); + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, false)); + if (capability != null) + { + method = method.WithExpressionBody(GetPartialResultsCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, + partialTypeSyntax, capability)); + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, capability)); + } + } + + if (partialItem != null) + { + var partialTypeSyntax = ResolveTypeName(partialItem); + + method = method.WithExpressionBody(GetPartialResultHandlerExpression(GetMethodName(handlerInterface), requestType, responseType)); + + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, true)); + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, false)); + if (capability != null) + { + method = method.WithExpressionBody(GetPartialResultCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, capability)); + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, capability)); + } + } + + if (capability != null) + { + method = method.WithExpressionBody( + GetRequestCapabilityHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, capability)); + yield return MakeAction(CreateAsyncFunc(responseType, requestType, capability)); + } + } + else + { + var method = MethodDeclaration(registryType, methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithBody(GetRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions)); + if (responseType.Name == "Unit") + { + method = method.WithBody(GetVoidRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions)); + } + + var registrationParameter = Parameter(Identifier("registrationOptions")) + .WithType(IdentifierName(registrationOptions.Name)); + + MemberDeclarationSyntax MakeAction(TypeSyntax syntax) + { + return method + .WithParameterList(parameters.WithParameters(SeparatedList(parameters.Parameters.Concat( + new[] {Parameter(Identifier("handler")).WithType(syntax), registrationParameter})))) + .NormalizeWhitespace(); + } + + yield return MakeAction(CreateAsyncFunc(responseType, false, requestType)); + yield return MakeAction(CreateAsyncFunc(responseType, true, requestType)); + + if (partialItems != null) + { + var partialTypeSyntax = ResolveTypeName(partialItems); + var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax}))); + + method = method.WithBody(GetPartialResultsRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, partialTypeSyntax, + registrationOptions)); + + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, true)); + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, false)); + if (capability != null) + { + method = method.WithBody(GetPartialResultsRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, partialTypeSyntax, + registrationOptions, capability)); + yield return MakeAction(CreatePartialAction(requestType, partialItemsSyntax, capability)); + } + } + + if (partialItem != null) + { + var partialTypeSyntax = ResolveTypeName(partialItem); + + method = method.WithBody(GetPartialResultRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions)); + + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, true)); + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, false)); + if (capability != null) + { + method = method.WithBody(GetPartialResultRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions, + capability)); + yield return MakeAction(CreatePartialAction(requestType, partialTypeSyntax, capability)); + } + } + + if (capability != null) + { + method = method.WithBody( + GetRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, responseType, registrationOptions, capability)); + if (responseType.Name == "Unit") + { + method = method.WithBody(GetVoidRequestRegistrationHandlerExpression(GetMethodName(handlerInterface), requestType, registrationOptions, capability)); + } + + yield return MakeAction(CreateAsyncFunc(responseType, requestType, capability)); + } + } + } + + static IEnumerable GetRegistries( + AttributeData attributeData, + InterfaceDeclarationSyntax interfaceSyntax, + INamedTypeSymbol interfaceType, + TransformationContext context, + IProgress progress, + HashSet additionalUsings) + { + if (attributeData.ConstructorArguments[0].Values.Length > 0) + { + return attributeData.ConstructorArguments[0].Values.Select(z => z.Value).OfType() + .Select(ResolveTypeName); + } + + if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.LanguageServer.Protocol")) + { + var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute"); + if (attribute.ConstructorArguments.Length < 2) + { + progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation())); + return Enumerable.Empty(); + } + + var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value; + + /* + Unspecified = 0b0000, + ServerToClient = 0b0001, + ClientToServer = 0b0010, + Bidirectional = 0b0011 + */ + var maskedDirection = (0b0011 & direction); + + + if (maskedDirection == 1) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client"); + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities"); + return new[] {LanguageProtocolServerToClient}; + } + + if (maskedDirection == 2) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server"); + return new[] {LanguageProtocolClientToServer}; + } + + if (maskedDirection == 3) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client"); + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities"); + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server"); + return new[] {LanguageProtocolClientToServer, LanguageProtocolServerToClient}; + } + } + + if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol")) + { + var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute"); + if (attribute.ConstructorArguments.Length < 2) + { + progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation())); + return Enumerable.Empty(); + } + + var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value; + + /* + Unspecified = 0b0000, + ServerToClient = 0b0001, + ClientToServer = 0b0010, + Bidirectional = 0b0011 + */ + var maskedDirection = (0b0011 & direction); + additionalUsings.Add("OmniSharp.Extensions.DebugAdapter.Protocol"); + + if (maskedDirection == 1) + { + return new[] {DebugProtocolServerToClient}; + } + + if (maskedDirection == 2) + { + return new[] {DebugProtocolClientToServer}; + } + + if (maskedDirection == 3) + { + return new[] {DebugProtocolClientToServer, DebugProtocolServerToClient}; + } + } + + throw new NotImplementedException("Add inference logic here " + interfaceSyntax.Identifier.ToFullString()); + } + + private static NameSyntax LanguageProtocolServerToClient { get; } = + IdentifierName("ILanguageClientRegistry"); + + private static NameSyntax LanguageProtocolClientToServer { get; } = + IdentifierName("ILanguageServerRegistry"); + + private static NameSyntax DebugProtocolServerToClient { get; } = + IdentifierName("IDebugAdapterClientRegistry"); + + private static NameSyntax DebugProtocolClientToServer { get; } = + IdentifierName("IDebugAdapterServerRegistry"); + + public Task> GenerateAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken) => + throw new NotImplementedException(); + } + + /* + IJsonRpcNotificationHandler + IJsonRpcRequestHandler + IJsonRpcRequestHandler + */ +} diff --git a/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs b/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs new file mode 100644 index 000000000..9ba5ead07 --- /dev/null +++ b/src/JsonRpc.Generators/GenerateRequestMethodsGenerator.cs @@ -0,0 +1,338 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using CodeGeneration.Roslyn; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static OmniSharp.Extensions.JsonRpc.Generators.Helpers; + +namespace OmniSharp.Extensions.JsonRpc.Generators +{ + public class GenerateRequestMethodsGenerator : IRichCodeGenerator + { + private readonly AttributeData _attributeData; + + public GenerateRequestMethodsGenerator(AttributeData attributeData) + { + _attributeData = attributeData; + } + + public Task> GenerateAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public Task GenerateRichAsync(TransformationContext context, IProgress progress, CancellationToken cancellationToken) + { + if (!(context.ProcessingNode is InterfaceDeclarationSyntax handlerInterface)) + { + return Task.FromResult(new RichGenerationResult()); + } + + var methods = new List(); + var additionalUsings = new HashSet() { + "System", + "System.Collections.Generic", + "System.Threading", + "System.Threading.Tasks", + "MediatR", + "Microsoft.Extensions.DependencyInjection", + }; + var symbol = context.SemanticModel.GetDeclaredSymbol(handlerInterface); + + var className = GetExtensionClassName(symbol); + + var registries = GetProxies(_attributeData, handlerInterface, symbol, context, progress, additionalUsings); + + if (_attributeData.ConstructorArguments[0].Values.Length == 0 && !symbol.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol")) + { + progress.Report(Diagnostic.Create(GeneratorDiagnostics.NoResponseRouterProvided, handlerInterface.Identifier.GetLocation(), symbol.Name, + string.Join(", ", registries.Select(z => z.ToFullString())))); + } + + foreach (var registry in registries) + { + if (IsNotification(symbol)) + { + var requestType = GetRequestType(symbol); + methods.AddRange(HandleNotifications(handlerInterface, symbol, requestType, registry, additionalUsings)); + } + + if (IsRequest(symbol)) + { + var requestType = GetRequestType(symbol); + var responseType = GetResponseType(symbol); + methods.AddRange(HandleRequests(handlerInterface, symbol, requestType, responseType, registry, additionalUsings)); + } + } + + + var attributes = List(new[] { + AttributeList(SeparatedList(new[] { + Attribute(ParseName("System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")), + Attribute(ParseName("System.Runtime.CompilerServices.CompilerGeneratedAttribute")), + })) + }); + if (symbol.GetAttributes().Any(z => z.AttributeClass.Name == "GenerateRequestMethodsAttribute")) + { + attributes = List(); + } + + var existingUsings = context.CompilationUnitUsings + .Join(additionalUsings, z => z.Name.ToFullString(), z => z, (a, b) => b) + ; + + var newUsings = additionalUsings + .Except(existingUsings) + .Select(z => UsingDirective(IdentifierName(z))) + ; + return Task.FromResult(new RichGenerationResult() { + Usings = List(newUsings), + Members = List(new[] { + NamespaceDeclaration(ParseName(symbol.ContainingNamespace.ToDisplayString())) + .WithMembers(List(new MemberDeclarationSyntax[] { + ClassDeclaration(className) + .WithAttributeLists(attributes) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.PartialKeyword) + )) + .WithMembers(List(methods)) + .NormalizeWhitespace() + })) + }) + }); + } + + public static IEnumerable GetProxies( + AttributeData attributeData, + InterfaceDeclarationSyntax interfaceSyntax, + INamedTypeSymbol interfaceType, + TransformationContext context, + IProgress progress, + HashSet additionalUsings) + { + if (attributeData.ConstructorArguments[0].Values.Length > 0) + { + return attributeData.ConstructorArguments[0].Values.Select(z => z.Value).OfType() + .Select(ResolveTypeName); + } + + if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.LanguageServer.Protocol")) + { + var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute"); + if (attribute.ConstructorArguments.Length < 2) + { + progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation())); + return Enumerable.Empty(); + } + + var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value; + + /* + Unspecified = 0b0000, + ServerToClient = 0b0001, + ClientToServer = 0b0010, + Bidirectional = 0b0011 + */ + var maskedDirection = (0b0011 & direction); + + if (maskedDirection == 1) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server"); + return new[] {LanguageProtocolServerToClient}; + } + + if (maskedDirection == 2) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client"); + return new[] {LanguageProtocolClientToServer}; + } + + if (maskedDirection == 3) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Server"); + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Client"); + return new[] {LanguageProtocolClientToServer, LanguageProtocolServerToClient}; + } + } + + if (interfaceType.ContainingNamespace.ToDisplayString().StartsWith("OmniSharp.Extensions.DebugAdapter.Protocol")) + { + var attribute = interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute"); + if (attribute.ConstructorArguments.Length < 2) + { + progress.Report(Diagnostic.Create(GeneratorDiagnostics.MissingDirection, interfaceSyntax.Identifier.GetLocation())); + return Enumerable.Empty(); + } + + var direction = (int) interfaceType.GetAttributes().First(z => z.AttributeClass?.Name == "MethodAttribute").ConstructorArguments[1].Value; + + /* + Unspecified = 0b0000, + ServerToClient = 0b0001, + ClientToServer = 0b0010, + Bidirectional = 0b0011 + */ + var maskedDirection = (0b0011 & direction); + additionalUsings.Add("OmniSharp.Extensions.DebugAdapter.Protocol"); + + if (maskedDirection == 1) + { + return new[] {DebugProtocolServerToClient}; + } + + if (maskedDirection == 2) + { + return new[] {DebugProtocolClientToServer}; + } + + if (maskedDirection == 3) + { + return new[] {DebugProtocolClientToServer, DebugProtocolServerToClient}; + } + } + + throw new NotImplementedException("Add inference logic here " + interfaceSyntax.Identifier.ToFullString()); + } + + private static NameSyntax LanguageProtocolServerToClient { get; } = + ParseName("ILanguageServer"); + + private static NameSyntax LanguageProtocolClientToServer { get; } = + ParseName("ILanguageClient"); + + private static NameSyntax DebugProtocolServerToClient { get; } = + ParseName("IDebugAdapterServer"); + + private static NameSyntax DebugProtocolClientToServer { get; } = + ParseName("IDebugAdapterClient"); + + IEnumerable HandleNotifications( + InterfaceDeclarationSyntax handlerInterface, + INamedTypeSymbol interfaceType, + INamedTypeSymbol requestType, + NameSyntax registryType, + HashSet additionalUsings) + { + var methodName = GetSendMethodName(interfaceType, _attributeData); + var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithExpressionBody(GetNotificationInvokeExpression()) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + yield return method + .WithParameterList( + ParameterList(SeparatedList(new[] { + Parameter(Identifier("mediator")) + .WithType(registryType) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))), + Parameter(Identifier("@params")) + .WithType(IdentifierName(requestType.Name)) + }))) + .NormalizeWhitespace(); + } + + IEnumerable HandleRequests( + InterfaceDeclarationSyntax handlerInterface, + INamedTypeSymbol interfaceType, + INamedTypeSymbol requestType, + INamedTypeSymbol responseType, + NameSyntax registryType, + HashSet additionalUsings) + { + var methodName = GetSendMethodName(interfaceType, _attributeData); + var parameterList = ParameterList(SeparatedList(new[] { + Parameter(Identifier("mediator")) + .WithType(registryType) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))), + Parameter(Identifier("@params")) + .WithType(IdentifierName(requestType.Name)), + Parameter(Identifier("cancellationToken")) + .WithType(IdentifierName("CancellationToken")) + .WithDefault(EqualsValueClause( + LiteralExpression(SyntaxKind.DefaultLiteralExpression, Token(SyntaxKind.DefaultKeyword))) + ) + })); + var partialItem = GetPartialItem(requestType); + if (partialItem != null) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Progress"); + var partialTypeSyntax = ResolveTypeName(partialItem); + yield return MethodDeclaration( + GenericName( + Identifier("IRequestProgressObservable")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new TypeSyntax[] { + partialTypeSyntax, + ResolveTypeName(responseType) + }))), + Identifier(methodName) + ) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithParameterList(parameterList) + .WithExpressionBody(GetPartialInvokeExpression(ResolveTypeName(responseType))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .NormalizeWhitespace(); + yield break; + } + + var partialItems = GetPartialItems(requestType); + if (partialItems != null) + { + additionalUsings.Add("OmniSharp.Extensions.LanguageServer.Protocol.Progress"); + var partialTypeSyntax = ResolveTypeName(partialItems); + var partialItemsSyntax = GenericName("IEnumerable").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {partialTypeSyntax}))); + yield return MethodDeclaration( + GenericName( + Identifier("IRequestProgressObservable")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new TypeSyntax[] { + partialItemsSyntax, + ResolveTypeName(responseType) + }))), + Identifier(methodName) + ) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithParameterList(parameterList) + .WithExpressionBody(GetPartialInvokeExpression(ResolveTypeName(responseType))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .NormalizeWhitespace(); + ; + yield break; + } + + + var responseSyntax = responseType.Name == "Unit" + ? IdentifierName("Task") as NameSyntax + : GenericName("Task").WithTypeArgumentList(TypeArgumentList(SeparatedList(new[] {ResolveTypeName(responseType)}))); + yield return MethodDeclaration(responseSyntax, methodName) + .WithModifiers(TokenList( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)) + ) + .WithParameterList(parameterList) + .WithExpressionBody(GetRequestInvokeExpression()) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .NormalizeWhitespace(); + } + } +} diff --git a/src/JsonRpc.Generators/GeneratorDiagnostics.cs b/src/JsonRpc.Generators/GeneratorDiagnostics.cs new file mode 100644 index 000000000..e00ef03e1 --- /dev/null +++ b/src/JsonRpc.Generators/GeneratorDiagnostics.cs @@ -0,0 +1,16 @@ +using Microsoft.CodeAnalysis; + +namespace OmniSharp.Extensions.JsonRpc.Generators +{ + static class GeneratorDiagnostics + { + public static DiagnosticDescriptor MissingDirection { get; } = new DiagnosticDescriptor("LSP1000", "Missing Direction", + "No direction defined for Language Server Protocol Handler", "JsonRPC", DiagnosticSeverity.Warning, true); + + public static DiagnosticDescriptor NoHandlerRegistryProvided { get; } = new DiagnosticDescriptor("JRPC1000", "No Handler Registry Provided", + "No Handler Registry Provided for handler {0}.", "JsonRPC", DiagnosticSeverity.Warning, true); + + public static DiagnosticDescriptor NoResponseRouterProvided { get; } = new DiagnosticDescriptor("JRPC1001", "No Response Router Provided", + "No Response Router Provided for handler {0}, defaulting to {1}.", "JsonRPC", DiagnosticSeverity.Warning, true); + } +} diff --git a/src/JsonRpc.Generators/Helpers.cs b/src/JsonRpc.Generators/Helpers.cs new file mode 100644 index 000000000..401cb808f --- /dev/null +++ b/src/JsonRpc.Generators/Helpers.cs @@ -0,0 +1,868 @@ +#nullable enable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace OmniSharp.Extensions.JsonRpc.Generators +{ + static class Helpers + { + public static bool IsNotification(INamedTypeSymbol symbol) + { + return symbol.AllInterfaces.Any(z => z.Name == "IJsonRpcNotificationHandler"); + } + + public static bool IsRequest(INamedTypeSymbol symbol) + { + return symbol.AllInterfaces.Any(z => z.Name == "IJsonRpcRequestHandler"); + } + + public static ExpressionSyntax GetMethodName(InterfaceDeclarationSyntax interfaceSyntax) + { + var methodAttribute = interfaceSyntax.AttributeLists + .SelectMany(z => z.Attributes) + .First(z => z.Name.ToString() == "MethodAttribute" || z.Name.ToString() == "Method"); + + return methodAttribute.ArgumentList.Arguments[0].Expression; + } + + public static INamedTypeSymbol GetRequestType(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces.First(z => z.Name == "IRequestHandler" && z.TypeArguments.Length == 2); + return handlerInterface.TypeArguments[0] as INamedTypeSymbol; + } + + public static INamedTypeSymbol GetResponseType(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces.First(z => z.Name == "IRequestHandler" && z.TypeArguments.Length == 2); + return handlerInterface.TypeArguments[1] as INamedTypeSymbol; + } + + public static INamedTypeSymbol? GetCapability(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces + .FirstOrDefault(z => z.Name == "ICapability" && z.TypeArguments.Length == 1); + return handlerInterface?.TypeArguments[0] as INamedTypeSymbol; + } + + public static INamedTypeSymbol? GetRegistrationOptions(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces + .FirstOrDefault(z => z.Name == "IRegistration" && z.TypeArguments.Length == 1); + return handlerInterface?.TypeArguments[0] as INamedTypeSymbol; + } + + public static INamedTypeSymbol? GetPartialItems(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces + .FirstOrDefault(z => z.Name == "IPartialItems" && z.TypeArguments.Length == 1); + return handlerInterface?.TypeArguments[0] as INamedTypeSymbol; + } + + public static INamedTypeSymbol? GetPartialItem(INamedTypeSymbol symbol) + { + var handlerInterface = symbol.AllInterfaces + .FirstOrDefault(z => z.Name == "IPartialItem" && z.TypeArguments.Length == 1); + return handlerInterface?.TypeArguments[0] as INamedTypeSymbol; + } + + public static GenericNameSyntax CreateAction(bool withCancellationToken, params ITypeSymbol[] types) + { + var typeArguments = types.Select(ResolveTypeName).ToList(); + if (withCancellationToken) + { + typeArguments.Add(IdentifierName("CancellationToken")); + } + + return GenericName(Identifier("Action")) + .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments))); + } + + public static NameSyntax ResolveTypeName(ITypeSymbol symbol) + { + if (symbol is INamedTypeSymbol namedTypeSymbol) + { + if (namedTypeSymbol.IsGenericType) + { + // TODO: Fix for generic types + return ParseName(namedTypeSymbol.ToString()); + } + + // Assume that we're adding the correct namespaces. + return IdentifierName(namedTypeSymbol.Name); + } + + return IdentifierName(symbol.Name); + } + + public static GenericNameSyntax CreateAction(params ITypeSymbol[] types) => CreateAction(true, types); + + public static GenericNameSyntax CreateAsyncAction(params ITypeSymbol[] types) => CreateAsyncFunc(null, true, types); + + public static GenericNameSyntax CreateAsyncAction(bool withCancellationToken, params ITypeSymbol[] types) => CreateAsyncFunc(null, withCancellationToken, types); + + public static GenericNameSyntax CreateAsyncFunc(ITypeSymbol? responseType, params ITypeSymbol[] types) => CreateAsyncFunc(responseType, true, types); + + public static GenericNameSyntax CreateAsyncFunc(ITypeSymbol? responseType, bool withCancellationToken, params ITypeSymbol[] types) + { + var typeArguments = types.Select(ResolveTypeName).ToList(); + if (withCancellationToken) + { + typeArguments.Add(IdentifierName("CancellationToken")); + } + + if (responseType == null || responseType.Name == "Unit") + { + typeArguments.Add(IdentifierName("Task")); + } + else + { + typeArguments.Add(GenericName(Identifier("Task"), TypeArgumentList(SeparatedList(new TypeSyntax[] { + ResolveTypeName(responseType) + })))); + } + + return GenericName(Identifier("Func")) + .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments))); + } + + public static GenericNameSyntax CreatePartialAction(ITypeSymbol requestType, NameSyntax partialType, bool withCancellationToken, params ITypeSymbol[] types) + { + var typeArguments = new List() { + ResolveTypeName(requestType), + GenericName("IObserver").WithTypeArgumentList(TypeArgumentList(SeparatedList(new TypeSyntax[] {partialType}))), + }; + typeArguments.AddRange(types.Select(ResolveTypeName)); + if (withCancellationToken) + { + typeArguments.Add(IdentifierName("CancellationToken")); + } + + return GenericName(Identifier("Action")) + .WithTypeArgumentList(TypeArgumentList(SeparatedList(typeArguments))); + } + + public static GenericNameSyntax CreatePartialAction(ITypeSymbol requestType, NameSyntax partialType, params ITypeSymbol[] types) => + CreatePartialAction(requestType, partialType, true, types); + + private static ExpressionStatementSyntax EnsureRegistrationOptionsIsSet(NameSyntax registrationOptionsName, TypeSyntax registrationOptionsType) + { + return ExpressionStatement(AssignmentExpression( + SyntaxKind.CoalesceAssignmentExpression, + registrationOptionsName, + ObjectCreationExpression(registrationOptionsType) + .WithArgumentList(ArgumentList()))); + } + + private static InvocationExpressionSyntax AddHandler(params ArgumentSyntax[] arguments) => InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("registry"), + IdentifierName("AddHandler"))) + .WithArgumentList(ArgumentList(SeparatedList(arguments))); + + private static ArgumentListSyntax GetHandlerArgumentList() => + ArgumentList(SeparatedList(new[] { + Argument(IdentifierName("handler")) + })); + + private static ArgumentListSyntax GetRegistrationHandlerArgumentList(NameSyntax registrationOptionsName) => + ArgumentList(SeparatedList(new[] { + Argument(IdentifierName("handler")), + Argument(registrationOptionsName) + })); + + private static ArgumentListSyntax GetPartialResultArgumentList(NameSyntax responseName) => + ArgumentList( + SeparatedList( + new[] { + Argument(IdentifierName("handler")), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("_"), + GenericName(Identifier("GetService")) + .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager")))) + ))), + Argument(SimpleLambdaExpression(Parameter(Identifier("values")), + ObjectCreationExpression(responseName) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values"))))))) + })); + + private static ArgumentListSyntax GetPartialResultRegistrationArgumentList(NameSyntax registrationOptionsName, NameSyntax responseName) => + ArgumentList( + SeparatedList( + new[] { + Argument(IdentifierName("handler")), + Argument(registrationOptionsName), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("_"), + GenericName(Identifier("GetService")) + .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager")))) + ))), + Argument(SimpleLambdaExpression(Parameter(Identifier("values")), + ObjectCreationExpression(responseName) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values"))))))) + })); + + private static ArgumentListSyntax GetPartialItemsArgumentList(NameSyntax responseName) => + ArgumentList( + SeparatedList( + new[] { + Argument(IdentifierName("handler")), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("_"), + GenericName(Identifier("GetService")) + .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager")))) + ))), + Argument(SimpleLambdaExpression(Parameter(Identifier("values")), + ObjectCreationExpression(responseName) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values"))))))) + })); + + private static ArgumentListSyntax GetPartialItemsRegistrationArgumentList(NameSyntax registrationOptionsName, NameSyntax responseName) => + ArgumentList( + SeparatedList( + new[] { + Argument(IdentifierName("handler")), + Argument(registrationOptionsName), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("_"), + GenericName(Identifier("GetService")) + .WithTypeArgumentList(TypeArgumentList(SingletonSeparatedList(IdentifierName("IProgressManager")))) + ))), + Argument(SimpleLambdaExpression(Parameter(Identifier("values")), + ObjectCreationExpression(responseName) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("values"))))))) + })); + + private static ObjectCreationExpressionSyntax CreateHandlerArgument(NameSyntax className, string innerClassName, params TypeSyntax[] genericArguments) + { + return ObjectCreationExpression( + QualifiedName( + className, + GenericName(innerClassName).WithTypeArgumentList(TypeArgumentList(SeparatedList(genericArguments))) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetNotificationCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var capabilityName = ResolveTypeName(capability); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "NotificationCapability", + requestName, + capabilityName + ) + .WithArgumentList(GetHandlerArgumentList())) + ) + ); + } + + public static BlockSyntax GetNotificationRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions) + { + var requestName = ResolveTypeName(requestType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Notification", + requestName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static BlockSyntax GetNotificationRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + var capabilityName = ResolveTypeName(capability); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Notification", + requestName, + capabilityName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetRequestCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var capabilityName = ResolveTypeName(capability); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "RequestCapability", + requestName, + responseName, + capabilityName + ) + .WithArgumentList(GetHandlerArgumentList())) + ) + ); + } + + public static BlockSyntax GetRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol registrationOptions) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "RequestRegistration", + requestName, + responseName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static BlockSyntax GetVoidRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions) + { + var requestName = ResolveTypeName(requestType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "RequestRegistration", + requestName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static BlockSyntax GetRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol registrationOptions, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + var capabilityName = ResolveTypeName(capability); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Request", + requestName, + responseName, + capabilityName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static BlockSyntax GetVoidRequestRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol registrationOptions, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + var capabilityName = ResolveTypeName(capability); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Request", + requestName, + capabilityName, + registrationOptionsName + ) + .WithArgumentList(GetRegistrationHandlerArgumentList(IdentifierName("registrationOptions")))) + ) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Request", + requestName, + responseName + ) + .WithArgumentList(GetHandlerArgumentList())) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType) + { + var requestName = ResolveTypeName(requestType); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "Request", + requestName + ) + .WithArgumentList(GetHandlerArgumentList())) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetPartialResultCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var capabilityName = ResolveTypeName(capability); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResultCapability", + requestName, + responseName, + capabilityName + ) + .WithArgumentList(GetPartialResultArgumentList(responseName))) + ) + ); + } + + public static BlockSyntax GetPartialResultRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol registrationOptions) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResult", + requestName, + responseName, + registrationOptionsName + ) + .WithArgumentList(GetPartialResultRegistrationArgumentList(IdentifierName("registrationOptions"), responseName))) + ) + ) + ); + } + + public static BlockSyntax GetPartialResultRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + ITypeSymbol registrationOptions, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + var capabilityName = ResolveTypeName(capability); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResult", + requestName, + responseName, + capabilityName, + registrationOptionsName + ) + .WithArgumentList(GetPartialResultRegistrationArgumentList(IdentifierName("registrationOptions"), responseName))) + ) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetPartialResultHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResult", + requestName, + responseName + ) + .WithArgumentList(GetPartialResultArgumentList(responseName))) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetPartialResultsCapabilityHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + NameSyntax itemName, ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var capabilityName = ResolveTypeName(capability); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + SimpleLambdaExpression( + Parameter( + Identifier("_")), + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResultsCapability", + requestName, + responseName, + itemName, + capabilityName + ) + .WithArgumentList(GetPartialItemsArgumentList(responseName))) + ) + ) + ); + } + + public static BlockSyntax GetPartialResultsRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + NameSyntax itemName, ITypeSymbol registrationOptions) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + SimpleLambdaExpression( + Parameter( + Identifier("_")), + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResults", + requestName, + responseName, + itemName, + registrationOptionsName + ) + .WithArgumentList(GetPartialItemsRegistrationArgumentList(IdentifierName("registrationOptions"), responseName))) + ) + )) + ); + } + + public static BlockSyntax GetPartialResultsRegistrationHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, ITypeSymbol responseType, + NameSyntax itemName, ITypeSymbol registrationOptions, + ITypeSymbol capability) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + var registrationOptionsName = ResolveTypeName(registrationOptions); + var capabilityName = ResolveTypeName(capability); + return Block( + EnsureRegistrationOptionsIsSet(IdentifierName("registrationOptions"), registrationOptionsName), + ReturnStatement( + AddHandler( + Argument(nameExpression), + Argument( + SimpleLambdaExpression( + Parameter( + Identifier("_")), + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResults", + requestName, + responseName, + itemName, + capabilityName, + registrationOptionsName + ) + .WithArgumentList(GetPartialItemsRegistrationArgumentList(IdentifierName("registrationOptions"), responseName))) + ) + )) + ); + } + + public static ArrowExpressionClauseSyntax GetPartialResultsHandlerExpression(ExpressionSyntax nameExpression, ITypeSymbol requestType, NameSyntax itemName, + ITypeSymbol responseType) + { + var requestName = ResolveTypeName(requestType); + var responseName = ResolveTypeName(responseType); + return ArrowExpressionClause( + AddHandler( + Argument(nameExpression), + Argument( + SimpleLambdaExpression( + Parameter( + Identifier("_")), + CreateHandlerArgument( + IdentifierName("LanguageProtocolDelegatingHandlers"), + "PartialResults", + requestName, + responseName, + itemName + ) + .WithArgumentList(GetPartialItemsArgumentList(responseName))) + ) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetNotificationHandlerExpression(ExpressionSyntax nameExpression) + { + return ArrowExpressionClause( + InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("registry"), + IdentifierName("AddHandler") + )) + .WithArgumentList(ArgumentList(SeparatedList(new SyntaxNodeOrToken[] { + Argument(nameExpression), + Token(SyntaxKind.CommaToken), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("NotificationHandler"), + IdentifierName("For") + )) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("handler")))))) + }))) + ); + } + + public static ArrowExpressionClauseSyntax GetRequestHandlerExpression(ExpressionSyntax nameExpression) + { + return ArrowExpressionClause( + InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("registry"), + IdentifierName("AddHandler") + )) + .WithArgumentList(ArgumentList(SeparatedList(new SyntaxNodeOrToken[] { + Argument(nameExpression), + Token(SyntaxKind.CommaToken), + Argument(InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("RequestHandler"), + IdentifierName("For") + )) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("handler")))))) + }))) + ); + } + + public static ArrowExpressionClauseSyntax GetNotificationInvokeExpression() + { + return ArrowExpressionClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("mediator"), + IdentifierName("SendNotification"))) + .WithArgumentList( + ArgumentList( + SeparatedList(new[] { + Argument(IdentifierName(@"@params")) + })) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetRequestInvokeExpression() + { + return ArrowExpressionClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("mediator"), + IdentifierName("SendRequest"))) + .WithArgumentList( + ArgumentList( + SeparatedList(new[] { + Argument(IdentifierName(@"@params")), + Argument(IdentifierName("cancellationToken")) + })) + ) + ); + } + + public static ArrowExpressionClauseSyntax GetPartialInvokeExpression(NameSyntax responseType) + { + return ArrowExpressionClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("mediator"), + IdentifierName("ProgressManager")), + IdentifierName("MonitorUntil"))) + .WithArgumentList( + ArgumentList( + SeparatedList( + new [] { + Argument( + IdentifierName(@"@params")), + Argument( + SimpleLambdaExpression( + Parameter(Identifier("value")), + ObjectCreationExpression(responseType) + .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(IdentifierName("value"))))))), + Argument(IdentifierName("cancellationToken")) + })))); + } + + public static string GetExtensionClassName(INamedTypeSymbol symbol) + { + return SpecialCasedHandlerFullName(symbol).Split('.').Last() + "Extensions"; + ; + } + + private static string SpecialCasedHandlerFullName(INamedTypeSymbol symbol) + { + return new Regex(@"(\w+)$") + .Replace(symbol.ToDisplayString() ?? string.Empty, + symbol.Name.Substring(1, symbol.Name.IndexOf("Handler", StringComparison.Ordinal) - 1)) + ; + } + + public static string SpecialCasedHandlerName(INamedTypeSymbol symbol) + { + var name = SpecialCasedHandlerFullName(symbol); + return name.Substring(name.LastIndexOf('.') + 1); + } + + public static string GetOnMethodName(INamedTypeSymbol symbol, AttributeData attributeData) + { + var namedMethod = attributeData.NamedArguments + .Where(z => z.Key == "MethodName") + .Select(z => z.Value.Value) + .FirstOrDefault(); + if (namedMethod is string value) return value; + return "On" + SpecialCasedHandlerName(symbol); + } + + public static string GetSendMethodName(INamedTypeSymbol symbol, AttributeData attributeData) + { + var namedMethod = attributeData.NamedArguments + .Where(z => z.Key == "MethodName") + .Select(z => z.Value.Value) + .FirstOrDefault(); + if (namedMethod is string value) return value; + var name = SpecialCasedHandlerName(symbol); + if ( + name.StartsWith("Run") + // TODO: Change this next breaking change + // || name.StartsWith("Set") + // || name.StartsWith("Attach") + // || name.StartsWith("Read") + || name.StartsWith("Did") + || name.StartsWith("Log") + || name.StartsWith("Show") + || name.StartsWith("Register") + || name.StartsWith("Prepare") + || name.StartsWith("Publish") + || name.StartsWith("ApplyWorkspaceEdit") + || name.StartsWith("Unregister")) + { + return name; + } + + if (name.EndsWith("Resolve", StringComparison.Ordinal)) + { + return "Resolve" + name.Substring(0, name.IndexOf("Resolve", StringComparison.Ordinal)); + } + + return IsNotification(symbol) ? "Send" + name : "Request" + name; + } + + private static string HandlerName(INamedTypeSymbol symbol) + { + var name = HandlerFullName(symbol); + return name.Substring(name.LastIndexOf('.') + 1); + } + + private static string HandlerFullName(INamedTypeSymbol symbol) + { + return new Regex(@"(\w+)$") + .Replace(symbol.ToDisplayString() ?? string.Empty, + symbol.Name.Substring(1, symbol.Name.IndexOf("Handler", StringComparison.Ordinal) - 1)); + } + } +} diff --git a/src/JsonRpc.Generators/JsonRpc.Generators.csproj b/src/JsonRpc.Generators/JsonRpc.Generators.csproj new file mode 100644 index 000000000..3a1bf25f0 --- /dev/null +++ b/src/JsonRpc.Generators/JsonRpc.Generators.csproj @@ -0,0 +1,17 @@ + + + + + + netstandard2.0 + OmniSharp.Extensions.JsonRpc.Generators + OmniSharp.Extensions.JsonRpc.Generators + false + + + + + + + + diff --git a/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs b/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs new file mode 100644 index 000000000..5c909df09 --- /dev/null +++ b/src/JsonRpc/Generation/GenerateHandlerMethodsAttribute.cs @@ -0,0 +1,24 @@ +using System; +using System.Diagnostics; +using CodeGeneration.Roslyn; + +namespace OmniSharp.Extensions.JsonRpc.Generation +{ + /// + /// Allows generating OnXyz handler methods for a given IJsonRpcHandler + /// + /// + /// Efforts will be made to make this available for consumers once source generators land + /// + [AttributeUsage(AttributeTargets.Interface)] + [CodeGenerationAttribute("OmniSharp.Extensions.JsonRpc.Generators.GenerateHandlerMethodsGenerator, OmniSharp.Extensions.JsonRpc.Generators")] + [Conditional("CodeGeneration")] + public class GenerateHandlerMethodsAttribute : Attribute + { + public GenerateHandlerMethodsAttribute(params Type[] registryTypes) + { + } + + public string MethodName { get; set; } + } +} diff --git a/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs b/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs new file mode 100644 index 000000000..2a3905ef1 --- /dev/null +++ b/src/JsonRpc/Generation/GenerateRequestMethodsAttribute.cs @@ -0,0 +1,24 @@ +using System; +using System.Diagnostics; +using CodeGeneration.Roslyn; + +namespace OmniSharp.Extensions.JsonRpc.Generation +{ + /// + /// Allows generating SendXyz/RequestXyz methods for a given IJsonRpcHandler + /// + /// + /// Efforts will be made to make this available for consumers once source generators land + /// + [AttributeUsage(AttributeTargets.Interface)] + [CodeGenerationAttribute("OmniSharp.Extensions.JsonRpc.Generators.GenerateRequestMethodsGenerator, OmniSharp.Extensions.JsonRpc.Generators")] + [Conditional("CodeGeneration")] + public class GenerateRequestMethodsAttribute : Attribute + { + public GenerateRequestMethodsAttribute(params Type[] proxyTypes) + { + } + + public string MethodName { get; set; } + } +} diff --git a/src/JsonRpc/JsonRpc.csproj b/src/JsonRpc/JsonRpc.csproj index 6c8e67ac4..8cb348ece 100644 --- a/src/JsonRpc/JsonRpc.csproj +++ b/src/JsonRpc/JsonRpc.csproj @@ -18,8 +18,12 @@ + + + <_Parameter1>OmniSharp.Extensions.DebugAdapter, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f + <_Parameter1>OmniSharp.Extensions.LanguageProtocol, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f diff --git a/src/Protocol/Document/ITypeDefinitionHandler.cs b/src/Protocol/Document/ITypeDefinitionHandler.cs index 10d7d3021..9a807da24 100644 --- a/src/Protocol/Document/ITypeDefinitionHandler.cs +++ b/src/Protocol/Document/ITypeDefinitionHandler.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; using OmniSharp.Extensions.LanguageServer.Protocol.Client; using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; using OmniSharp.Extensions.LanguageServer.Protocol.Models; @@ -13,6 +14,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol.Document { [Parallel, Method(TextDocumentNames.TypeDefinition, Direction.ClientToServer)] + [GenerateHandlerMethods, GenerateRequestMethods(typeof(ITextDocumentLanguageClient), typeof(ILanguageClient))] public interface ITypeDefinitionHandler : IJsonRpcRequestHandler, IRegistration, ICapability { } public abstract class TypeDefinitionHandler : ITypeDefinitionHandler @@ -28,96 +30,4 @@ public TypeDefinitionHandler(TypeDefinitionRegistrationOptions registrationOptio public virtual void SetCapability(TypeDefinitionCapability capability) => Capability = capability; protected TypeDefinitionCapability Capability { get; private set; } } - - public static class TypeDefinitionExtensions - { -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Func> - handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - new LanguageProtocolDelegatingHandlers.Request(handler, registrationOptions)); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Func> handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Func> handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Action>, TypeDefinitionCapability, - CancellationToken> handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - _ => - new LanguageProtocolDelegatingHandlers.PartialResults(handler, - registrationOptions, _.GetService(), x => new LocationOrLocationLinks(x))); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Action>, TypeDefinitionCapability> - handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - _ => - new LanguageProtocolDelegatingHandlers.PartialResults(handler, - registrationOptions, _.GetService(), x => new LocationOrLocationLinks(x))); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Action>, CancellationToken> handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - _ => - new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, - _.GetService(), x => new LocationOrLocationLinks(x))); - } - -public static ILanguageServerRegistry OnTypeDefinition(this ILanguageServerRegistry registry, - Action>> handler, - TypeDefinitionRegistrationOptions registrationOptions) - { - registrationOptions ??= new TypeDefinitionRegistrationOptions(); - return registry.AddHandler(TextDocumentNames.TypeDefinition, - _ => - new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, - _.GetService(), x => new LocationOrLocationLinks(x))); - } - - public static IRequestProgressObservable, LocationOrLocationLinks> RequestTypeDefinition( - this ITextDocumentLanguageClient mediator, - TypeDefinitionParams @params, - CancellationToken cancellationToken = default) - { - return mediator.ProgressManager.MonitorUntil(@params, x => new LocationOrLocationLinks(x), cancellationToken); - } - } } diff --git a/src/Protocol/General/ILanguageProtocolInitializeHandler.cs b/src/Protocol/General/ILanguageProtocolInitializeHandler.cs index 14f081489..b6d7a6d8c 100644 --- a/src/Protocol/General/ILanguageProtocolInitializeHandler.cs +++ b/src/Protocol/General/ILanguageProtocolInitializeHandler.cs @@ -2,6 +2,7 @@ using System.Threading; using System.Threading.Tasks; using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; using OmniSharp.Extensions.LanguageServer.Protocol.Client; using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Server; @@ -12,6 +13,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol.General /// InitializeError /// [Serial, Method(GeneralNames.Initialize, Direction.ClientToServer)] + [GenerateHandlerMethods(typeof(ILanguageServerRegistry), MethodName = "OnLanguageProtocolInitialize"), GenerateRequestMethods(typeof(ILanguageClient), MethodName = "RequestLanguageProtocolInitialize")] public interface ILanguageProtocolInitializeHandler : IJsonRpcRequestHandler { } @@ -20,26 +22,4 @@ public abstract class LanguageProtocolInitializeHandler : ILanguageProtocolIniti { public abstract Task Handle(InitializeParams request, CancellationToken cancellationToken); } - - public static class LanguageProtocolInitializeExtensions - { - public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry, - Func> - handler) - { - return registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler)); - } - - public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry, - Func> handler) - { - return registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler)); - } - - public static Task RequestLanguageProtocolInitialize(this ILanguageClient mediator, InitializeParams @params, - CancellationToken cancellationToken = default) - { - return mediator.SendRequest(@params, cancellationToken); - } - } } diff --git a/src/Protocol/General/IShutdownHandler.cs b/src/Protocol/General/IShutdownHandler.cs index b3837006d..ca0696ddb 100644 --- a/src/Protocol/General/IShutdownHandler.cs +++ b/src/Protocol/General/IShutdownHandler.cs @@ -3,6 +3,7 @@ using System.Threading.Tasks; using MediatR; using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; using OmniSharp.Extensions.LanguageServer.Protocol.Client; using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Server; @@ -10,6 +11,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol.General { [Serial, Method(GeneralNames.Shutdown, Direction.ClientToServer)] + [GenerateHandlerMethods, GenerateRequestMethods(typeof(ITextDocumentLanguageClient), typeof(ILanguageClient))] public interface IShutdownHandler : IJsonRpcRequestHandler { } @@ -25,33 +27,8 @@ public virtual async Task Handle(ShutdownParams request, CancellationToken protected abstract Task Handle(CancellationToken cancellationToken); } - public static class ShutdownExtensions + public static partial class ShutdownExtensions { - public static ILanguageServerRegistry OnShutdown(this ILanguageServerRegistry registry, - Func - handler) - { - return registry.AddHandler(GeneralNames.Shutdown, - RequestHandler.For(async (_, ct) => { - await handler(_, ct); - return Unit.Value; - })); - } - - public static ILanguageServerRegistry OnShutdown(this ILanguageServerRegistry registry, Func handler) - { - return registry.AddHandler(GeneralNames.Shutdown, - RequestHandler.For(async (_, ct) => { - await handler(_); - return Unit.Value; - })); - } - - public static Task RequestShutdown(this ILanguageClient mediator, ShutdownParams @params, CancellationToken cancellationToken = default) - { - return mediator.SendRequest(@params, cancellationToken); - } - public static Task RequestShutdown(this ILanguageClient mediator, CancellationToken cancellationToken = default) { return mediator.SendRequest(ShutdownParams.Instance, cancellationToken); diff --git a/src/Protocol/IProgressHandler.cs b/src/Protocol/IProgressHandler.cs index 469a820cc..c61a0e034 100644 --- a/src/Protocol/IProgressHandler.cs +++ b/src/Protocol/IProgressHandler.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using MediatR; using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; using OmniSharp.Extensions.LanguageServer.Protocol.Client; using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Progress; @@ -12,6 +13,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol { [Parallel, Method(GeneralNames.Progress, Direction.Bidirectional)] + [GenerateHandlerMethods, GenerateRequestMethods(typeof(IGeneralLanguageClient), typeof(ILanguageClient), typeof(IGeneralLanguageServer), typeof(ILanguageServer))] public interface IProgressHandler : IJsonRpcNotificationHandler { } @@ -21,66 +23,8 @@ public abstract class ProgressHandler : IProgressHandler public abstract Task Handle(ProgressParams request, CancellationToken cancellationToken); } - public static class ProgressExtensions + public static partial class ProgressExtensions { -public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry, - Action handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry, - Func handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry, - Action handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageServerRegistry OnProgress(this ILanguageServerRegistry registry, - Func handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry, - Action handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry, - Func handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry, - Action handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - -public static ILanguageClientRegistry OnProgress(this ILanguageClientRegistry registry, - Func handler) - { - return registry.AddHandler(GeneralNames.Progress, NotificationHandler.For(handler)); - } - - public static void SendProgress(this IGeneralLanguageClient registry, ProgressParams @params) - { - registry.SendNotification(@params); - } - - public static void SendProgress(this IGeneralLanguageServer registry, ProgressParams @params) - { - registry.SendNotification(@params); - } - public static IRequestProgressObservable RequestProgress(this IClientProxy requestRouter, IPartialItemRequest @params, Func factory, CancellationToken cancellationToken = default) { var resultToken = new ProgressToken(Guid.NewGuid().ToString()); diff --git a/src/Protocol/Protocol.csproj b/src/Protocol/Protocol.csproj index 996e1cb95..1661083a9 100644 --- a/src/Protocol/Protocol.csproj +++ b/src/Protocol/Protocol.csproj @@ -8,6 +8,8 @@ + + <_Parameter1>OmniSharp.Extensions.LanguageServer, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f @@ -18,4 +20,7 @@ <_Parameter1>OmniSharp.Extensions.LanguageClient, PublicKey=0024000004800000940000000602000000240000525341310004000001000100391db875e68eb4bfef49ce14313b9e13f2cd3cc89eb273bbe6c11a55044c7d4f566cf092e1c77ef9e7c75b1496ae7f95d925938f5a01793dd8d9f99ae0a7595779b71b971287d7d7b5960d052078d14f5ce1a85ea5c9fb2f59ac735ff7bc215cab469b7c3486006860bad6f4c3b5204ea2f28dd4e1d05e2cca462cfd593b9f9f + + + diff --git a/test/Dap.Tests/Dap.Tests.csproj b/test/Dap.Tests/Dap.Tests.csproj index 2498bb013..ca60129bc 100644 --- a/test/Dap.Tests/Dap.Tests.csproj +++ b/test/Dap.Tests/Dap.Tests.csproj @@ -9,6 +9,7 @@ + diff --git a/test/Dap.Tests/FoundationTests.cs b/test/Dap.Tests/FoundationTests.cs index 7ea1bc34b..fc5a1df19 100644 --- a/test/Dap.Tests/FoundationTests.cs +++ b/test/Dap.Tests/FoundationTests.cs @@ -250,7 +250,9 @@ public void Match(string description, params Func[] matchers) new[] { registrySub, Substitute.For(new Type[] {method.GetParameters()[1].ParameterType}, Array.Empty()), }.Concat(method.GetParameters().Skip(2).Select(z => - !z.ParameterType.IsGenericType ? Activator.CreateInstance(z.ParameterType) : Substitute.For(new Type[] {z.ParameterType}, Array.Empty())) + !z.ParameterType.IsGenericType + ? Activator.CreateInstance(z.ParameterType) + : Substitute.For(new Type[] {z.ParameterType}, Array.Empty())) ) .ToArray()); @@ -335,12 +337,12 @@ public static Type GetHandlerInterface(Type type) } - public class TypeHandlerData : TheoryData { public TypeHandlerData() { - foreach (var type in typeof(CompletionsArguments).Assembly.ExportedTypes.Where(z => z.IsInterface && typeof(IJsonRpcHandler).IsAssignableFrom(z) && !z.IsGenericType)) + foreach (var type in typeof(CompletionsArguments).Assembly.ExportedTypes.Where( + z => z.IsInterface && typeof(IJsonRpcHandler).IsAssignableFrom(z) && !z.IsGenericType)) { Add(HandlerTypeDescriptorHelper.GetHandlerTypeDescriptor(type)); } @@ -414,8 +416,16 @@ private static string GetOnMethodName(IHandlerTypeDescriptor descriptor) private static string GetSendMethodName(IHandlerTypeDescriptor descriptor) { - var name = HandlerName(descriptor); - if (name.StartsWith("Run")) return name; + var name = SpecialCasedHandlerName(descriptor); + if (name.StartsWith("Run") + // TODO: Change this next breaking change + // || name.StartsWith("Set") + // || name.StartsWith("Attach") + // || name.StartsWith("Read") + ) + { + return name; + } return descriptor.IsNotification ? "Send" + name : "Request" + name; } diff --git a/test/Generation.Tests/Generation.Tests.csproj b/test/Generation.Tests/Generation.Tests.csproj new file mode 100644 index 000000000..a298c355f --- /dev/null +++ b/test/Generation.Tests/Generation.Tests.csproj @@ -0,0 +1,16 @@ + + + netcoreapp3.1;netcoreapp2.1 + true + AnyCPU + + + + + + + + + + + diff --git a/test/Generation.Tests/GenerationHelpers.cs b/test/Generation.Tests/GenerationHelpers.cs new file mode 100644 index 000000000..0047a76dc --- /dev/null +++ b/test/Generation.Tests/GenerationHelpers.cs @@ -0,0 +1,144 @@ +using System; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using CodeGeneration.Roslyn; +using CodeGeneration.Roslyn.Engine; +using FluentAssertions; +using MediatR; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Extensions.DependencyInjection; +using OmniSharp.Extensions.DebugAdapter.Protocol; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using Xunit; + +namespace Generation.Tests +{ + public static class GenerationHelpers + { + static GenerationHelpers() + { + // this "core assemblies hack" is from https://stackoverflow.com/a/47196516/4418060 + var coreAssemblyPath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + var coreAssemblyNames = new[] + { + "mscorlib.dll", + "netstandard.dll", + "System.dll", + "System.Core.dll", +#if NETCOREAPP + "System.Private.CoreLib.dll", +#endif + "System.Runtime.dll", + }; + var coreMetaReferences = + coreAssemblyNames.Select(x => MetadataReference.CreateFromFile(Path.Combine(coreAssemblyPath, x))); + var otherAssemblies = new[] + { + typeof(CSharpCompilation).Assembly, + typeof(CodeGenerationAttributeAttribute).Assembly, + typeof(GenerateHandlerMethodsAttribute).Assembly, + typeof(IDebugAdapterClientRegistry).Assembly, + typeof(Unit).Assembly, + typeof(ILanguageServerRegistry).Assembly, + }; + MetadataReferences = coreMetaReferences + .Concat(otherAssemblies.Distinct().Select(x => MetadataReference.CreateFromFile(x.Location))) + .ToImmutableArray(); + } + + internal const string CrLf = "\r\n"; + internal const string Lf = "\n"; + internal const string DefaultFilePathPrefix = "Test"; + internal const string CSharpDefaultFileExt = "cs"; + internal const string TestProjectName = "TestProject"; + + internal static readonly string NormalizedPreamble = NormalizeToLf(DocumentTransform.GeneratedByAToolPreamble + Lf); + + internal static readonly ImmutableArray MetadataReferences; + + public static async Task AssertGeneratedAsExpected(string source, string expected) + { + var generatedTree = await GenerateAsync(source); + // normalize line endings to just LF + var generatedText = NormalizeToLf(generatedTree.GetText().ToString()); + // and append preamble to the expected + var expectedText = NormalizedPreamble + NormalizeToLf(expected).Trim(); + generatedText.Should().Be(expectedText); + } + + public static async Task Generate(string source) + { + var generatedTree = await GenerateAsync(source); + // normalize line endings to just LF + var generatedText = NormalizeToLf(generatedTree.GetText().ToString()); + // and append preamble to the expected + return generatedText; + } + + public static string NormalizeToLf(string input) + { + return input.Replace(CrLf, Lf); + } + + public static async Task GenerateAsync(string source) + { + var document = CreateProject(source).Documents.Single(); + var tree = await document.GetSyntaxTreeAsync(); + if (tree is null) + { + throw new InvalidOperationException("Could not get the syntax tree of the sources"); + } + + var compilation = (CSharpCompilation?)await document.Project.GetCompilationAsync(); + if (compilation is null) + { + throw new InvalidOperationException("Could not compile the sources"); + } + + var diagnostics = compilation.GetDiagnostics(); + Assert.Empty(diagnostics.Where(x => x.Severity >= DiagnosticSeverity.Warning)); + var progress = new Progress(); + var result = await DocumentTransform.TransformAsync(compilation, tree, null, Assembly.Load, progress, CancellationToken.None); + return result; + } + + public static Project CreateProject(params string[] sources) + { + var projectId = ProjectId.CreateNewId(debugName: TestProjectName); + var solution = new AdhocWorkspace() + .CurrentSolution + .AddProject(projectId, TestProjectName, TestProjectName, LanguageNames.CSharp) + .WithProjectCompilationOptions( + projectId, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) + .WithProjectParseOptions( + projectId, + new CSharpParseOptions(preprocessorSymbols: new[] { "SOMETHING_ACTIVE" })) + .AddMetadataReferences(projectId, MetadataReferences); + + int count = 0; + foreach (var source in sources) + { + var newFileName = DefaultFilePathPrefix + count + "." + CSharpDefaultFileExt; + var documentId = DocumentId.CreateNewId(projectId, debugName: newFileName); + solution = solution.AddDocument(documentId, newFileName, SourceText.From(source)); + count++; + } + var project = solution.GetProject(projectId); + if (project is null) + { + throw new InvalidOperationException($"The ad hoc workspace does not contain a project with the id {projectId.Id}"); + } + + return project; + } + } +} diff --git a/test/Generation.Tests/JsonRpcGenerationTests.cs b/test/Generation.Tests/JsonRpcGenerationTests.cs new file mode 100644 index 000000000..a33d4deef --- /dev/null +++ b/test/Generation.Tests/JsonRpcGenerationTests.cs @@ -0,0 +1,602 @@ +using System.Threading.Tasks; +using Snapper; +using Snapper.Attributes; +using Snapper.Core; +using Xunit; +using static Generation.Tests.GenerationHelpers; + +namespace Generation.Tests +{ + public class JsonRpcGenerationTests + { + [Fact] + public async Task Supports_Generating_Notifications_And_Infers_Direction_ExitHandler() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [Serial, Method(GeneralNames.Exit, Direction.ClientToServer), GenerateHandlerMethods, GenerateRequestMethods] + public interface IExitHandler : IJsonRpcNotificationHandler + { + } +}"; + + var expected = @"using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class ExitExtensions + { + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Action handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Func handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Action handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Func handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + } +} + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + public static partial class ExitExtensions + { + public static void SendExit(this ILanguageClient mediator, ExitParams @params) => mediator.SendNotification(@params); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + [Fact] + public async Task Supports_Generating_Notifications_And_Infers_Direction_CapabilitiesHandler() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; + +namespace OmniSharp.Extensions.DebugAdapter.Protocol.Events.Test +{ + [Parallel, Method(EventNames.Capabilities, Direction.ServerToClient), GenerateHandlerMethods, GenerateRequestMethods] + public interface ICapabilitiesHandler : IJsonRpcNotificationHandler { } +}"; + + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; +using OmniSharp.Extensions.DebugAdapter.Protocol; + +namespace OmniSharp.Extensions.DebugAdapter.Protocol.Events.Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class CapabilitiesExtensions + { + public static IDebugAdapterClientRegistry OnCapabilities(this IDebugAdapterClientRegistry registry, Action handler) => registry.AddHandler(EventNames.Capabilities, NotificationHandler.For(handler)); + public static IDebugAdapterClientRegistry OnCapabilities(this IDebugAdapterClientRegistry registry, Func handler) => registry.AddHandler(EventNames.Capabilities, NotificationHandler.For(handler)); + public static IDebugAdapterClientRegistry OnCapabilities(this IDebugAdapterClientRegistry registry, Action handler) => registry.AddHandler(EventNames.Capabilities, NotificationHandler.For(handler)); + public static IDebugAdapterClientRegistry OnCapabilities(this IDebugAdapterClientRegistry registry, Func handler) => registry.AddHandler(EventNames.Capabilities, NotificationHandler.For(handler)); + } +} + +namespace OmniSharp.Extensions.DebugAdapter.Protocol.Events.Test +{ + public static partial class CapabilitiesExtensions + { + public static void SendCapabilities(this IDebugAdapterServer mediator, CapabilitiesEvent @params) => mediator.SendNotification(@params); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + + [Fact] + public async Task Supports_Generating_Notifications_ExitHandler() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; + +namespace Test +{ + [Serial, Method(GeneralNames.Exit, Direction.ClientToServer), GenerateHandlerMethods(typeof(ILanguageServerRegistry)), GenerateRequestMethods(typeof(ILanguageClient))] + public interface IExitHandler : IJsonRpcNotificationHandler + { + } +}"; + + var expected = @"using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; + +namespace Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class ExitExtensions + { + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Action handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Func handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Action handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + public static ILanguageServerRegistry OnExit(this ILanguageServerRegistry registry, Func handler) => registry.AddHandler(GeneralNames.Exit, NotificationHandler.For(handler)); + } +} + +namespace Test +{ + public static partial class ExitExtensions + { + public static void SendExit(this ILanguageClient mediator, ExitParams @params) => mediator.SendNotification(@params); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + [Fact] + public async Task Supports_Generating_Notifications_And_Infers_Direction_DidChangeTextHandler() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [Serial, Method(TextDocumentNames.DidChange, Direction.ClientToServer), GenerateHandlerMethods, GenerateRequestMethods] + public interface IDidChangeTextDocumentHandler : IJsonRpcNotificationHandler, + IRegistration, ICapability + { } +}"; + + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class DidChangeTextDocumentExtensions + { + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Action handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Func handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Action handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Func handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Action handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDidChangeTextDocument(this ILanguageServerRegistry registry, Func handler, TextDocumentChangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new TextDocumentChangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.DidChange, new LanguageProtocolDelegatingHandlers.Notification(handler, registrationOptions)); + } + } +} + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + public static partial class DidChangeTextDocumentExtensions + { + public static void DidChangeTextDocument(this ILanguageClient mediator, DidChangeTextDocumentParams @params) => mediator.SendNotification(@params); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + [Fact] + public async Task Supports_Generating_Notifications_And_Infers_Direction_FoldingRangeHandler() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [Parallel, Method(TextDocumentNames.FoldingRange, Direction.ClientToServer)] + [GenerateHandlerMethods, GenerateRequestMethods(typeof(ITextDocumentLanguageClient), typeof(ILanguageClient))] + public interface IFoldingRangeHandler : IJsonRpcRequestHandler>, + IRegistration, ICapability + { + } +}"; + + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; +using OmniSharp.Extensions.LanguageServer.Protocol.Progress; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class FoldingRangeExtensions + { + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Func>> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, new LanguageProtocolDelegatingHandlers.RequestRegistration, FoldingRangeRegistrationOptions>(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Func>> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, new LanguageProtocolDelegatingHandlers.RequestRegistration, FoldingRangeRegistrationOptions>(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Action>, CancellationToken> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, _ => new LanguageProtocolDelegatingHandlers.PartialResults, FoldingRange, FoldingRangeRegistrationOptions>(handler, registrationOptions, _.GetService(), values => new OmniSharp.Extensions.LanguageServer.Protocol.Models.Container(values))); + } + + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Action>> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, _ => new LanguageProtocolDelegatingHandlers.PartialResults, FoldingRange, FoldingRangeRegistrationOptions>(handler, registrationOptions, _.GetService(), values => new OmniSharp.Extensions.LanguageServer.Protocol.Models.Container(values))); + } + + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Action>, FoldingRangeCapability, CancellationToken> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, _ => new LanguageProtocolDelegatingHandlers.PartialResults, FoldingRange, FoldingRangeCapability, FoldingRangeRegistrationOptions>(handler, registrationOptions, _.GetService(), values => new OmniSharp.Extensions.LanguageServer.Protocol.Models.Container(values))); + } + + public static ILanguageServerRegistry OnFoldingRange(this ILanguageServerRegistry registry, Func>> handler, FoldingRangeRegistrationOptions registrationOptions) + { + registrationOptions ??= new FoldingRangeRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.FoldingRange, new LanguageProtocolDelegatingHandlers.Request, FoldingRangeCapability, FoldingRangeRegistrationOptions>(handler, registrationOptions)); + } + } +} + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + public static partial class FoldingRangeExtensions + { + public static IRequestProgressObservable, OmniSharp.Extensions.LanguageServer.Protocol.Models.Container> RequestFoldingRange(this ITextDocumentLanguageClient mediator, FoldingRangeRequestParam @params, CancellationToken cancellationToken = default) => mediator.ProgressManager.MonitorUntil(@params, value => new OmniSharp.Extensions.LanguageServer.Protocol.Models.Container(value), cancellationToken); + public static IRequestProgressObservable, OmniSharp.Extensions.LanguageServer.Protocol.Models.Container> RequestFoldingRange(this ILanguageClient mediator, FoldingRangeRequestParam @params, CancellationToken cancellationToken = default) => mediator.ProgressManager.MonitorUntil(@params, value => new OmniSharp.Extensions.LanguageServer.Protocol.Models.Container(value), cancellationToken); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + [Fact] + public async Task Supports_Generating_Requests_And_Infers_Direction() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [Parallel, Method(TextDocumentNames.Definition, Direction.ClientToServer), GenerateHandlerMethods, GenerateRequestMethods] + public interface IDefinitionHandler : IJsonRpcRequestHandler, IRegistration, ICapability { } +}"; + + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; +using OmniSharp.Extensions.LanguageServer.Protocol.Progress; + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class DefinitionExtensions + { + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>, CancellationToken> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>, DefinitionCapability, CancellationToken> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.Request(handler, registrationOptions)); + } + } +} + +namespace OmniSharp.Extensions.LanguageServer.Protocol.Test +{ + public static partial class DefinitionExtensions + { + public static IRequestProgressObservable, LocationOrLocationLinks> RequestDefinition(this ILanguageClient mediator, DefinitionParams @params, CancellationToken cancellationToken = default) => mediator.ProgressManager.MonitorUntil(@params, value => new LocationOrLocationLinks(value), cancellationToken); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + + [Fact] + public async Task Supports_Generating_Requests() + { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; + +namespace Test +{ + [Parallel, Method(TextDocumentNames.Definition, Direction.ClientToServer), GenerateHandlerMethods(typeof(ILanguageServerRegistry)), GenerateRequestMethods(typeof(ITextDocumentLanguageClient))] + public interface IDefinitionHandler : IJsonRpcRequestHandler, IRegistration, ICapability { } +}"; + + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using MediatR; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; +using OmniSharp.Extensions.LanguageServer.Protocol.Progress; + +namespace Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class DefinitionExtensions + { + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.RequestRegistration(handler, registrationOptions)); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>, CancellationToken> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Action>, DefinitionCapability, CancellationToken> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, _ => new LanguageProtocolDelegatingHandlers.PartialResults(handler, registrationOptions, _.GetService(), values => new LocationOrLocationLinks(values))); + } + + public static ILanguageServerRegistry OnDefinition(this ILanguageServerRegistry registry, Func> handler, DefinitionRegistrationOptions registrationOptions) + { + registrationOptions ??= new DefinitionRegistrationOptions(); + return registry.AddHandler(TextDocumentNames.Definition, new LanguageProtocolDelegatingHandlers.Request(handler, registrationOptions)); + } + } +} + +namespace Test +{ + public static partial class DefinitionExtensions + { + public static IRequestProgressObservable, LocationOrLocationLinks> RequestDefinition(this ITextDocumentLanguageClient mediator, DefinitionParams @params, CancellationToken cancellationToken = default) => mediator.ProgressManager.MonitorUntil(@params, value => new LocationOrLocationLinks(value), cancellationToken); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + + [Fact] + public async Task Supports_Custom_Method_Names() { + var source = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; + +namespace Test +{ + [Serial, Method(GeneralNames.Initialize, Direction.ClientToServer), GenerateHandlerMethods(typeof(ILanguageServerRegistry), MethodName = ""OnLanguageProtocolInitialize""), GenerateRequestMethods(typeof(ITextDocumentLanguageClient), MethodName = ""RequestLanguageProtocolInitialize"")] + public interface ILanguageProtocolInitializeHandler : IJsonRpcRequestHandler {} +}"; + var expected = @" +using System; +using System.Threading; +using System.Threading.Tasks; +using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.JsonRpc.Generation; +using OmniSharp.Extensions.LanguageServer.Protocol; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; +using OmniSharp.Extensions.LanguageServer.Protocol.Models; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; +using System.Collections.Generic; +using MediatR; +using Microsoft.Extensions.DependencyInjection; + +namespace Test +{ + [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute, System.Runtime.CompilerServices.CompilerGeneratedAttribute] + public static partial class LanguageProtocolInitializeExtensions + { + public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry, Func> handler) => registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler)); + public static ILanguageServerRegistry OnLanguageProtocolInitialize(this ILanguageServerRegistry registry, Func> handler) => registry.AddHandler(GeneralNames.Initialize, RequestHandler.For(handler)); + } +} + +namespace Test +{ + public static partial class LanguageProtocolInitializeExtensions + { + public static Task RequestLanguageProtocolInitialize(this ITextDocumentLanguageClient mediator, InitializeParams @params, CancellationToken cancellationToken = default) => mediator.SendRequest(@params, cancellationToken); + } +}"; + await AssertGeneratedAsExpected(source, expected); + } + } +} diff --git a/test/JsonRpc.Tests/JsonRpc.Tests.csproj b/test/JsonRpc.Tests/JsonRpc.Tests.csproj index 68ef89466..6b06d178a 100644 --- a/test/JsonRpc.Tests/JsonRpc.Tests.csproj +++ b/test/JsonRpc.Tests/JsonRpc.Tests.csproj @@ -5,6 +5,7 @@ AnyCPU + diff --git a/test/Lsp.Tests/FoundationTests.cs b/test/Lsp.Tests/FoundationTests.cs index 942030833..5c9af9128 100644 --- a/test/Lsp.Tests/FoundationTests.cs +++ b/test/Lsp.Tests/FoundationTests.cs @@ -12,10 +12,12 @@ using NSubstitute.Core; using NSubstitute.Extensions; using OmniSharp.Extensions.JsonRpc; +using OmniSharp.Extensions.LanguageServer.Protocol.Client; using OmniSharp.Extensions.LanguageServer.Protocol.Document; using OmniSharp.Extensions.LanguageServer.Protocol.Document.Proposals; using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Progress; +using OmniSharp.Extensions.LanguageServer.Protocol.Server; using OmniSharp.Extensions.LanguageServer.Protocol.Shared; using Xunit; using Xunit.Abstractions; @@ -111,8 +113,37 @@ public void HandlersShouldExtensionMethodClassWithMethods(ILspHandlerTypeDescrip .Distinct() .ToHashSet(); - registries.Should().HaveCount(descriptor.Direction == Direction.Bidirectional ? 4 : 2, - $"{descriptor.HandlerType.FullName} there should be methods for both handing the event and sending the event"); + if (descriptor.Direction == Direction.Bidirectional) + { + registries + .Where(z => typeof(IClientProxy).IsAssignableFrom(z) || typeof(ILanguageServerRegistry).IsAssignableFrom(z)) + .Should().HaveCountGreaterOrEqualTo(1, + $"{descriptor.HandlerType.FullName} there should be methods for both handing the event and sending the event"); + registries + .Where(z => typeof(IServerProxy).IsAssignableFrom(z) || typeof(ILanguageClientRegistry).IsAssignableFrom(z)) + .Should().HaveCountGreaterOrEqualTo(1, + $"{descriptor.HandlerType.FullName} there should be methods for both handing the event and sending the event"); + } + else if (descriptor.Direction == Direction.ServerToClient) + { + registries + .Where(z => typeof(IServerProxy).IsAssignableFrom(z) || typeof(ILanguageClientRegistry).IsAssignableFrom(z)) + .Should().HaveCountGreaterOrEqualTo(1, + $"{descriptor.HandlerType.FullName} there should be methods for both handing the event and sending the event"); + registries + .Where(z => typeof(IClientProxy).IsAssignableFrom(z) || typeof(ILanguageServerRegistry).IsAssignableFrom(z)) + .Should().HaveCount(0, $"{descriptor.HandlerType.FullName} must not cross the streams or be made bidirectional"); + } + else if (descriptor.Direction == Direction.ClientToServer) + { + registries + .Where(z => typeof(IClientProxy).IsAssignableFrom(z) || typeof(ILanguageServerRegistry).IsAssignableFrom(z)) + .Should().HaveCountGreaterOrEqualTo(1, + $"{descriptor.HandlerType.FullName} there should be methods for both handing the event and sending the event"); + registries + .Where(z => typeof(IServerProxy).IsAssignableFrom(z) || typeof(ILanguageClientRegistry).IsAssignableFrom(z)) + .Should().HaveCount(0, $"{descriptor.HandlerType.FullName} must not cross the streams or be made bidirectional"); + } } [Theory(DisplayName = "Handler all expected extensions methods based on method direction")] @@ -196,7 +227,7 @@ Func ForParameter(int index, Func m) var isFunc = ForAnyParameter(info => info.ParameterType.Name.StartsWith("Func")); var takesParameter = ForAnyParameter(info => info.ParameterType.GetGenericArguments().FirstOrDefault() == descriptor.ParamsType); var takesCapability = ForAnyParameter(info => info.ParameterType.GetGenericArguments().Skip(1).FirstOrDefault() == descriptor.CapabilityType); - var returnsTask =ForAnyParameter(info => info.ParameterType.GetGenericArguments().LastOrDefault() == typeof(Task)); + var returnsTask = ForAnyParameter(info => info.ParameterType.GetGenericArguments().LastOrDefault() == typeof(Task)); if (descriptor.IsRequest && TypeHandlerExtensionData.HandlersToSkip.All(z => descriptor.HandlerType != z)) { @@ -294,7 +325,8 @@ Func ForParameter(int index, Func m) if (descriptor.IsRequest && descriptor.HasPartialItems) { Func partialReturnType = info => - typeof(IRequestProgressObservable<,>).MakeGenericType(typeof(IEnumerable<>).MakeGenericType(descriptor.PartialItemsType), descriptor.ResponseType).IsAssignableFrom(info.ReturnType); + typeof(IRequestProgressObservable<,>).MakeGenericType(typeof(IEnumerable<>).MakeGenericType(descriptor.PartialItemsType), descriptor.ResponseType) + .IsAssignableFrom(info.ReturnType); matcher.Match( $"Func<{descriptor.ParamsType.Name}, CancellationToken, IProgressObservable, {descriptor.ResponseType.Name}>>", takesParameter, containsCancellationToken, partialReturnType); @@ -302,7 +334,7 @@ Func ForParameter(int index, Func m) else if (descriptor.IsRequest && descriptor.HasPartialItem) { Func partialReturnType = info => - typeof(IRequestProgressObservable<,>).MakeGenericType(descriptor.PartialItemType, descriptor.ResponseType).IsAssignableFrom(info.ReturnType) ; + typeof(IRequestProgressObservable<,>).MakeGenericType(descriptor.PartialItemType, descriptor.ResponseType).IsAssignableFrom(info.ReturnType); matcher.Match($"Func<{descriptor.ParamsType.Name}, CancellationToken, IProgressObservable<{descriptor.PartialItemType.Name}, {descriptor.ResponseType.Name}>>", takesParameter, containsCancellationToken, partialReturnType); } @@ -310,8 +342,7 @@ Func ForParameter(int index, Func m) { matcher.Match($"Func<{descriptor.ParamsType.Name}, CancellationToken, {returnType.Name}>", takesParameter, containsCancellationToken, returns); } - - if (descriptor.IsNotification) + else if (descriptor.IsNotification) { matcher.Match($"Action<{descriptor.ParamsType.Name}>", isAction, takesParameter); } @@ -368,7 +399,8 @@ public void Match(string description, params Func[] matchers) .ToArray()); registrySub.Received().ReceivedCalls() - .Any(z => z.GetMethodInfo().Name == nameof(IJsonRpcHandlerRegistry>.AddHandler) && z.GetArguments().Length == 3 && + .Any(z => z.GetMethodInfo().Name == nameof(IJsonRpcHandlerRegistry>.AddHandler) && + z.GetArguments().Length == 3 && z.GetArguments()[0].Equals(_descriptor.Method)) .Should().BeTrue($"{_descriptor.HandlerType.Name} {description} should have the correct method."); @@ -488,7 +520,7 @@ private static string HandlerFullName(ILspHandlerTypeDescriptor descriptor) descriptor.HandlerType.Name.Substring(1, descriptor.HandlerType.Name.IndexOf("Handler", StringComparison.Ordinal) - 1)); } - private static string SpecialCasedHandlerName(ILspHandlerTypeDescriptor descriptor) + public static string SpecialCasedHandlerName(ILspHandlerTypeDescriptor descriptor) { var name = SpecialCasedHandlerFullName(descriptor); return name.Substring(name.LastIndexOf('.') + 1); diff --git a/test/Lsp.Tests/Lsp.Tests.csproj b/test/Lsp.Tests/Lsp.Tests.csproj index 9e90b0f75..6911d62b9 100644 --- a/test/Lsp.Tests/Lsp.Tests.csproj +++ b/test/Lsp.Tests/Lsp.Tests.csproj @@ -9,6 +9,7 @@ +